Unverified Commit 52c48f74 authored by 乔占卫's avatar 乔占卫 Committed by GitHub
Browse files

Merge pull request #97 from Baoqi/pre_post_statement

close #85 Add Pre/Post Statement support in SQL Task
parents 2aca5ff5 ae6ae039
Loading
Loading
Loading
Loading
+42 −0
Original line number Diff line number Diff line
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package cn.escheduler.common.task.sql;

import cn.escheduler.common.process.Property;

import java.util.Map;

/**
 * Used to contains both prepared sql string and its to-be-bind parameters
 */
public class SqlBinds {
    private final String sql;
    private final Map<Integer, Property> paramsMap;

    public SqlBinds(String sql, Map<Integer, Property> paramsMap) {
        this.sql = sql;
        this.paramsMap = paramsMap;
    }

    public String getSql() {
        return sql;
    }

    public Map<Integer, Property> getParamsMap() {
        return paramsMap;
    }
}
+25 −0
Original line number Diff line number Diff line
@@ -64,6 +64,14 @@ public class SqlParameters extends AbstractParameters {
     * SQL connection parameters
     */
    private String connParams;
    /**
     * Pre Statements
     */
    private List<String> preStatements;
    /**
     * Post Statements
     */
    private List<String> postStatements;

    public String getType() {
        return type;
@@ -121,6 +129,21 @@ public class SqlParameters extends AbstractParameters {
        this.connParams = connParams;
    }

    public List<String> getPreStatements() {
        return preStatements;
    }

    public void setPreStatements(List<String> preStatements) {
        this.preStatements = preStatements;
    }

    public List<String> getPostStatements() {
        return postStatements;
    }

    public void setPostStatements(List<String> postStatements) {
        this.postStatements = postStatements;
    }

    @Override
    public boolean checkParameters() {
@@ -142,6 +165,8 @@ public class SqlParameters extends AbstractParameters {
                ", udfs='" + udfs + '\'' +
                ", showType='" + showType + '\'' +
                ", connParams='" + connParams + '\'' +
                ", preStatements=" + preStatements +
                ", postStatements=" + postStatements +
                '}';
    }
}
+86 −51
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@ import cn.escheduler.common.enums.UdfType;
import cn.escheduler.common.job.db.*;
import cn.escheduler.common.process.Property;
import cn.escheduler.common.task.AbstractParameters;
import cn.escheduler.common.task.sql.SqlBinds;
import cn.escheduler.common.task.sql.SqlParameters;
import cn.escheduler.common.task.sql.SqlType;
import cn.escheduler.common.utils.CollectionUtils;
@@ -48,6 +49,7 @@ import java.sql.*;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 *  sql task
@@ -131,11 +133,17 @@ public class SqlTask extends AbstractTask {
                        Class.forName(Constants.JDBC_SQLSERVER_CLASS_NAME);
                    }

                    Map<Integer,Property> sqlParamMap =  new HashMap<Integer,Property>();
                    StringBuilder sqlBuilder = new StringBuilder();

                    // ready to execute SQL and parameter entity Map
                    setSqlAndSqlParamsMap(sqlBuilder,sqlParamMap);
                    SqlBinds mainSqlBinds = getSqlAndSqlParamsMap(sqlParameters.getSql());
                    List<SqlBinds> preStatementSqlBinds = Optional.ofNullable(sqlParameters.getPreStatements()).orElse(new ArrayList<>())
                            .stream()
                            .map(this::getSqlAndSqlParamsMap)
                            .collect(Collectors.toList());
                    List<SqlBinds> postStatementSqlBinds = Optional.ofNullable(sqlParameters.getPostStatements()).orElse(new ArrayList<>())
                            .stream()
                            .map(this::getSqlAndSqlParamsMap)
                            .collect(Collectors.toList());

                    if(EnumUtils.isValidEnum(UdfType.class, sqlParameters.getType()) && StringUtils.isNotEmpty(sqlParameters.getUdfs())){
                        List<UdfFunc> udfFuncList = processDao.queryUdfFunListByids(sqlParameters.getUdfs());
@@ -143,7 +151,7 @@ public class SqlTask extends AbstractTask {
                    }

                    // execute sql task
                    con = executeFuncAndSql(baseDataSource,sqlBuilder.toString(),sqlParamMap,createFuncs);
                    con = executeFuncAndSql(baseDataSource, mainSqlBinds, preStatementSqlBinds, postStatementSqlBinds, createFuncs);

                } finally {
                    if (con != null) {
@@ -162,9 +170,9 @@ public class SqlTask extends AbstractTask {
     *  ready to execute SQL and parameter entity Map
     * @return
     */
    private void setSqlAndSqlParamsMap(StringBuilder sqlBuilder,Map<Integer,Property> sqlParamsMap) {

        String sql =  sqlParameters.getSql();
    private SqlBinds getSqlAndSqlParamsMap(String sql) {
        Map<Integer,Property> sqlParamsMap =  new HashMap<>();
        StringBuilder sqlBuilder = new StringBuilder();

        // find process instance by task id
        ProcessInstance processInstance = processDao.findProcessInstanceByTaskId(taskProps.getTaskInstId());
@@ -178,7 +186,7 @@ public class SqlTask extends AbstractTask {
        // spell SQL according to the final user-defined variable
        if(paramsMap == null){
            sqlBuilder.append(sql);
            return;
            return new SqlBinds(sqlBuilder.toString(), sqlParamsMap);
        }

        // special characters need to be escaped, ${} needs to be escaped
@@ -191,6 +199,7 @@ public class SqlTask extends AbstractTask {

        // print repalce sql
        printReplacedSql(sql,formatSql,rgex,sqlParamsMap);
        return new SqlBinds(sqlBuilder.toString(), sqlParamsMap);
    }

    @Override
@@ -201,10 +210,16 @@ public class SqlTask extends AbstractTask {
    /**
     *  execute sql
     * @param baseDataSource
     * @param sql
     * @param params
     * @param mainSqlBinds
     * @param preStatementsBinds
     * @param postStatementsBinds
     * @param createFuncs
     */
    public Connection executeFuncAndSql(BaseDataSource baseDataSource, String sql, Map<Integer,Property> params, List<String> createFuncs){
    public Connection executeFuncAndSql(BaseDataSource baseDataSource,
                                        SqlBinds mainSqlBinds,
                                        List<SqlBinds> preStatementsBinds,
                                        List<SqlBinds> postStatementsBinds,
                                        List<String> createFuncs){
        Connection connection = null;
        try {

@@ -223,25 +238,24 @@ public class SqlTask extends AbstractTask {
                        baseDataSource.getUser(), baseDataSource.getPassword());
            }

            Statement  funcStmt = connection.createStatement();
            // create temp function
            if (createFuncs != null) {
            if (CollectionUtils.isNotEmpty(createFuncs)) {
                try (Statement  funcStmt = connection.createStatement()) {
                    for (String createFunc : createFuncs) {
                        logger.info("hive create function sql: {}", createFunc);
                        funcStmt.execute(createFunc);
                    }
                }

            PreparedStatement  stmt = connection.prepareStatement(sql);
            if(taskProps.getTaskTimeoutStrategy() == TaskTimeoutStrategy.FAILED || taskProps.getTaskTimeoutStrategy() == TaskTimeoutStrategy.WARNFAILED){
                stmt.setQueryTimeout(taskProps.getTaskTimeout());
            }
            if(params != null){
                for(Integer key : params.keySet()){
                    Property prop = params.get(key);
                    ParameterUtils.setInParameter(key,stmt,prop.getType(),prop.getValue());

            for (SqlBinds sqlBind: preStatementsBinds) {
                try (PreparedStatement stmt = prepareStatementAndBind(connection, sqlBind)) {
                    int result = stmt.executeUpdate();
                    logger.info("pre statement execute result: " + result + ", for sql: "  + sqlBind.getSql());
                }
            }

            try (PreparedStatement  stmt = prepareStatementAndBind(connection, mainSqlBinds)) {
                // decide whether to executeQuery or executeUpdate based on sqlType
                if (sqlParameters.getSqlType() == SqlType.QUERY.ordinal()) {
                    // query statements need to be convert to JsonArray and inserted into Alert to send
@@ -276,13 +290,34 @@ public class SqlTask extends AbstractTask {
                    int result = stmt.executeUpdate();
                    exitStatusCode = 0;
                }
            }

            for (SqlBinds sqlBind: postStatementsBinds) {
                try (PreparedStatement stmt = prepareStatementAndBind(connection, sqlBind)) {
                    int result = stmt.executeUpdate();
                    logger.info("post statement execute result: " + result + ", for sql: "  + sqlBind.getSql());
                }
            }
        } catch (Exception e) {
            logger.error(e.getMessage(),e);
        }
        return connection;
    }

    private PreparedStatement prepareStatementAndBind(Connection connection, SqlBinds sqlBinds) throws Exception {
        PreparedStatement  stmt = connection.prepareStatement(sqlBinds.getSql());
        if(taskProps.getTaskTimeoutStrategy() == TaskTimeoutStrategy.FAILED || taskProps.getTaskTimeoutStrategy() == TaskTimeoutStrategy.WARNFAILED){
            stmt.setQueryTimeout(taskProps.getTaskTimeout());
        }
        Map<Integer, Property> params = sqlBinds.getParamsMap();
        if(params != null){
            for(Integer key : params.keySet()){
                Property prop = params.get(key);
                ParameterUtils.setInParameter(key,stmt,prop.getType(),prop.getValue());
            }
        }
        return stmt;
    }

    /**
     *  send mail as an attachment
+143 −0
Original line number Diff line number Diff line
<template>
  <div class="statement-list-model">
    <div class="select-listpp"
         v-for="(item,$index) in localStatementList"
         :key="item.id"
         @click="_getIndex($index)">
      <x-input
        :disabled="isDetails"
        type="textarea"
        resize="none"
        :autosize="{minRows:1}"
        v-model="localStatementList[$index]"
        @on-blur="_verifProp()"
        style="width: 525px;">
      </x-input>
      <span class="lt-add">
        <a href="javascript:" style="color:red;" @click="!isDetails && _removeStatement($index)" >
          <i class="iconfont" :class="_isDetails" data-toggle="tooltip" :title="$t('delete')" >&#xe611;</i>
        </a>
      </span>
      <span class="add" v-if="$index === (localStatementList.length - 1)">
        <a href="javascript:" @click="!isDetails && _addStatement()" >
          <i class="iconfont" :class="_isDetails" data-toggle="tooltip" :title="$t('Add')">&#xe636;</i>
        </a>
      </span>
    </div>
    <span class="add" v-if="!localStatementList.length">
      <a href="javascript:" @click="!isDetails && _addStatement()" >
        <i class="iconfont" :class="_isDetails" data-toggle="tooltip" :title="$t('Add')">&#xe636;</i>
      </a>
    </span>
  </div>
</template>
<script>
  import _ from 'lodash'
  import i18n from '@/module/i18n'
  import disabledState from '@/module/mixin/disabledState'
  export default {
    name: 'user-def-statements',
    data () {
      return {
        // Increased data
        localStatementList: [],
        // Current execution index
        localStatementIndex: null
      }
    },
    mixins: [disabledState],
    props: {
      statementList: Array
    },
    methods: {
      /**
       * Current index
       */
      _getIndex (index) {
        this.localStatementIndex = index
      },
      /**
       * delete item
       */
      _removeStatement (index) {
        this.localStatementList.splice(index, 1)
        this._verifProp('value')
      },
      /**
       * add
       */
      _addStatement () {
        this.localStatementList.push('')
      },
      /**
       * blur verification
       */
      _handleValue () {
        this._verifProp('value')
      },
      /**
       * Verify that the value exists or is empty
       */
      _verifProp (type) {
        let arr = []
        let flag = true
        _.map(this.localStatementList, v => {
          arr.push(v)
          if (!v) {
            flag = false
          }
        })
        if (!flag) {
          if (!type) {
            this.$message.warning(`${i18n.$t('Statement cannot be empty')}`)
          }
          return false
        }

        this.$emit('on-statement-list', _.cloneDeep(this.localStatementList))
        return true
      }
    },
    watch: {
      // Monitor data changes
      statementList () {
        this.localStatementList = this.statementList
      }
    },
    created () {
      this.localStatementList = this.statementList
    },
    mounted () {
    },
    components: { }
  }
</script>

<style lang="scss" rel="stylesheet/scss">
  .statement-list-model {
    .select-listpp {
      margin-bottom: 6px;
      .lt-add {
        padding-left: 4px;
        a {
          .iconfont {
            font-size: 18px;
            vertical-align: middle;
            margin-bottom: -2px;
            display: inline-block;
          }
        }
      }
    }
    .add {
      a {
        .iconfont {
          font-size: 18px;
          vertical-align: middle;
          display: inline-block;
          margin-top: 1px;
        }
      }
    }
  }
</style>
+54 −3
Original line number Diff line number Diff line
@@ -72,6 +72,26 @@
        </m-local-params>
      </div>
    </m-list-box>
    <m-list-box>
      <div slot="text">{{$t('Pre Statement')}}</div>
      <div slot="content">
        <m-statement-list
          ref="refPreStatements"
          @on-statement-list="_onPreStatements"
          :statement-list="preStatements">
        </m-statement-list>
      </div>
    </m-list-box>
    <m-list-box>
      <div slot="text">{{$t('Post Statement')}}</div>
      <div slot="content">
        <m-statement-list
          ref="refPostStatements"
          @on-statement-list="_onPostStatements"
          :statement-list="postStatements">
        </m-statement-list>
      </div>
    </m-list-box>
  </div>
</template>
<script>
@@ -82,6 +102,7 @@
  import mSqlType from './_source/sqlType'
  import mDatasource from './_source/datasource'
  import mLocalParams from './_source/localParams'
  import mStatementList from './_source/statementList'
  import disabledState from '@/module/mixin/disabledState'
  import codemirror from '@/conf/home/pages/resource/pages/file/pages/_source/codemirror'

@@ -108,7 +129,11 @@
        // Form/attachment
        showType: ['TABLE'],
        // Sql parameter
        connParams: ''
        connParams: '',
        // Pre statements
        preStatements: [],
        // Post statements
        postStatements: []
      }
    },
    mixins: [disabledState],
@@ -141,6 +166,18 @@
        this.type = o.type
        this.rtDatasource = o.datasource
      },
      /**
       * return pre statements
       */
      _onPreStatements (a) {
        this.preStatements = a
      },
      /**
       * return post statements
       */
      _onPostStatements (a) {
        this.postStatements = a
      },
      /**
       * verification
       */
@@ -167,6 +204,16 @@
          return false
        }

        // preStatements Subcomponent verification
        if (!this.$refs.refPreStatements._verifProp()) {
          return false
        }

        // postStatements Subcomponent verification
        if (!this.$refs.refPostStatements._verifProp()) {
          return false
        }

        // storage
        this.$emit('on-params', {
          type: this.type,
@@ -187,7 +234,9 @@
            }
          })(),
          localParams: this.localParams,
          connParams: this.connParams
          connParams: this.connParams,
          preStatements: this.preStatements,
          postStatements: this.postStatements
        })
        return true
      },
@@ -245,6 +294,8 @@
        this.connParams = o.params.connParams || ''
        this.localParams = o.params.localParams || []
        this.showType = o.params.showType.split(',') || []
        this.preStatements = o.params.preStatements || []
        this.postStatements = o.params.postStatements || []
      }
    },
    mounted () {
@@ -262,6 +313,6 @@
      }
    },
    computed: {},
    components: { mListBox, mDatasource, mLocalParams, mUdfs, mSqlType }
    components: { mListBox, mDatasource, mLocalParams, mUdfs, mSqlType, mStatementList }
  }
</script>
Loading