Commit ac440497 authored by terrymanu's avatar terrymanu
Browse files

for #2084, add ParameterMarkerExpressionSegment

parent 89b30f41
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -27,7 +27,8 @@ import org.apache.shardingsphere.core.parse.antlr.extractor.util.ExtractorUtils;
import org.apache.shardingsphere.core.parse.antlr.extractor.util.RuleName;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.LiteralExpressionSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.ParameterMarkerExpressionSegment;

import java.util.Map;

@@ -51,7 +52,10 @@ public final class AssignmentExtractor implements OptionalSQLSegmentExtractor {
        }
        Optional<ColumnSegment> columnSegment = columnExtractor.extract((ParserRuleContext) assignmentNode.get().getChild(0), parameterMarkerIndexes);
        Preconditions.checkState(columnSegment.isPresent());
        LiteralExpressionSegment expressionSegment = expressionExtractor.extractLiteralExpressionSegment((ParserRuleContext) assignmentNode.get().getChild(2), parameterMarkerIndexes);
        Optional<ParameterMarkerExpressionSegment> parameterMarkerExpressionSegment = expressionExtractor.extractParameterMarkerExpressionSegment(
                (ParserRuleContext) assignmentNode.get().getChild(2), parameterMarkerIndexes);
        ExpressionSegment expressionSegment = parameterMarkerExpressionSegment.isPresent()
                ? parameterMarkerExpressionSegment.get() : expressionExtractor.extractLiteralExpressionSegment((ParserRuleContext) assignmentNode.get().getChild(2));
        return Optional.of(new AssignmentSegment(columnSegment.get(), expressionSegment));
    }
}
+22 −10
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@
package org.apache.shardingsphere.core.parse.antlr.extractor.impl.dml;

import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import org.antlr.v4.runtime.ParserRuleContext;
import org.apache.shardingsphere.core.parse.antlr.extractor.api.OptionalSQLSegmentExtractor;
import org.apache.shardingsphere.core.parse.antlr.extractor.impl.dml.select.SubqueryExtractor;
@@ -26,6 +27,7 @@ import org.apache.shardingsphere.core.parse.antlr.extractor.util.RuleName;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.CommonExpressionSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.LiteralExpressionSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.core.util.NumberUtil;

import java.util.Map;
@@ -46,26 +48,36 @@ public final class ExpressionExtractor implements OptionalSQLSegmentExtractor {
    private ExpressionSegment extractExpression(final ParserRuleContext expressionNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes) {
        Optional<ParserRuleContext> parameterMarkerNode = ExtractorUtils.findSingleNodeFromFirstDescendant(expressionNode, RuleName.PARAMETER_MARKER);
        if (parameterMarkerNode.isPresent()) {
            return extractLiteralExpressionSegment(parameterMarkerNode.get(), parameterMarkerIndexes);
            Optional<ParameterMarkerExpressionSegment> result = extractParameterMarkerExpressionSegment(parameterMarkerNode.get(), parameterMarkerIndexes);
            Preconditions.checkState(result.isPresent());
            return result.get();
        }
        Optional<ParserRuleContext> literalsNode = ExtractorUtils.findSingleNodeFromFirstDescendant(expressionNode, RuleName.LITERALS);
        return literalsNode.isPresent() ? extractLiteralExpressionSegment(literalsNode.get(), parameterMarkerIndexes) : extractCommonExpressionSegment(expressionNode);
        return literalsNode.isPresent() ? extractLiteralExpressionSegment(literalsNode.get()) : extractCommonExpressionSegment(expressionNode);
    }
    
    /**
     * Extract literal expression segment.
     * Extract parameter marker expression segment.
     *
     * @param parameterMarkerIndexes parameter marker indexes
     * @param expressionNode expression node
     * @return common expression segment
     * @return parameter marker expression segment
     */
    public LiteralExpressionSegment extractLiteralExpressionSegment(final ParserRuleContext expressionNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes) {
        LiteralExpressionSegment result = new LiteralExpressionSegment(expressionNode.getStart().getStartIndex(), expressionNode.getStop().getStopIndex());
    public Optional<ParameterMarkerExpressionSegment> extractParameterMarkerExpressionSegment(final ParserRuleContext expressionNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes) {
        Optional<ParserRuleContext> parameterMarkerNode = ExtractorUtils.findFirstChildNode(expressionNode, RuleName.PARAMETER_MARKER);
        if (parameterMarkerNode.isPresent()) {
            result.setParameterMarkerIndex(parameterMarkerIndexes.get(parameterMarkerNode.get()));
            return result;
        return parameterMarkerNode.isPresent()
                ? Optional.of(new ParameterMarkerExpressionSegment(expressionNode.getStart().getStartIndex(), expressionNode.getStop().getStopIndex(),
                parameterMarkerIndexes.get(parameterMarkerNode.get()))) : Optional.<ParameterMarkerExpressionSegment>absent();
    }
    
    /**
     * Extract literal expression segment.
     *
     * @param expressionNode expression node
     * @return literal expression segment
     */
    public LiteralExpressionSegment extractLiteralExpressionSegment(final ParserRuleContext expressionNode) {
        LiteralExpressionSegment result = new LiteralExpressionSegment(expressionNode.getStart().getStartIndex(), expressionNode.getStop().getStopIndex());
        Optional<ParserRuleContext> numberLiteralsNode = ExtractorUtils.findFirstChildNode(expressionNode, RuleName.NUMBER_LITERALS);
        if (numberLiteralsNode.isPresent()) {
            result.setLiterals(NumberUtil.getExactlyNumber(numberLiteralsNode.get().getText(), 10));
+11 −5
Original line number Diff line number Diff line
@@ -24,7 +24,8 @@ import org.apache.shardingsphere.core.parse.antlr.extractor.impl.dml.ExpressionE
import org.apache.shardingsphere.core.parse.antlr.extractor.util.ExtractorUtils;
import org.apache.shardingsphere.core.parse.antlr.extractor.util.RuleName;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.InsertValuesSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.LiteralExpressionSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.ParameterMarkerExpressionSegment;

import java.util.Collection;
import java.util.Collections;
@@ -48,15 +49,20 @@ public final class InsertValuesExtractor implements CollectionSQLSegmentExtracto
        }
        Collection<InsertValuesSegment> result = new LinkedList<>();
        for (ParserRuleContext each : ExtractorUtils.getAllDescendantNodes(insertValuesClauseNode.get(), RuleName.ASSIGNMENT_VALUES)) {
            result.add(new InsertValuesSegment(extractCommonExpressionSegments(each, parameterMarkerIndexes)));
            result.add(new InsertValuesSegment(extractExpressionSegments(each, parameterMarkerIndexes)));
        }
        return result;
    }
    
    private Collection<LiteralExpressionSegment> extractCommonExpressionSegments(final ParserRuleContext assignmentValuesNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes) {
        Collection<LiteralExpressionSegment> result = new LinkedList<>();
    private Collection<ExpressionSegment> extractExpressionSegments(final ParserRuleContext assignmentValuesNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes) {
        Collection<ExpressionSegment> result = new LinkedList<>();
        for (ParserRuleContext each : ExtractorUtils.getAllDescendantNodes(assignmentValuesNode, RuleName.ASSIGNMENT_VALUE)) {
            result.add(expressionExtractor.extractLiteralExpressionSegment(each, parameterMarkerIndexes));
            Optional<ParameterMarkerExpressionSegment> parameterMarkerExpressionSegment = expressionExtractor.extractParameterMarkerExpressionSegment(each, parameterMarkerIndexes);
            if (parameterMarkerExpressionSegment.isPresent()) {
                result.add(parameterMarkerExpressionSegment.get());
            } else {
                result.add(expressionExtractor.extractLiteralExpressionSegment(each));
            }
        }
        return result;
    }
+2 −2
Original line number Diff line number Diff line
@@ -19,7 +19,7 @@ package org.apache.shardingsphere.core.parse.antlr.filler.encrypt.dml.insert;

import org.apache.shardingsphere.core.parse.antlr.filler.api.SQLSegmentFiller;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.InsertValuesSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.LiteralExpressionSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.statement.SQLStatement;
import org.apache.shardingsphere.core.parse.antlr.sql.statement.dml.InsertStatement;
import org.apache.shardingsphere.core.parse.old.parser.context.insertvalue.InsertValue;
@@ -45,7 +45,7 @@ public final class EncryptInsertValuesFiller implements SQLSegmentFiller<InsertV
    
    private InsertValue getInsertValue(final InsertValuesSegment sqlSegment, final String sql) {
        List<SQLExpression> columnValues = new LinkedList<>();
        for (LiteralExpressionSegment each : sqlSegment.getValues()) {
        for (ExpressionSegment each : sqlSegment.getValues()) {
            SQLExpression sqlExpression = each.getSQLExpression(sql);
            columnValues.add(sqlExpression);
        }
+2 −2
Original line number Diff line number Diff line
@@ -25,7 +25,7 @@ import org.apache.shardingsphere.core.parse.antlr.filler.api.ShardingTableMetaDa
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.assignment.SetAssignmentsSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.LiteralExpressionSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.statement.SQLStatement;
import org.apache.shardingsphere.core.parse.antlr.sql.statement.dml.InsertStatement;
import org.apache.shardingsphere.core.parse.antlr.sql.statement.dml.UpdateStatement;
@@ -107,7 +107,7 @@ public final class ShardingSetAssignmentsFiller implements SQLSegmentFiller<SetA
        return insertStatement.getColumnNames().size() - assistedQueryColumnCount;
    }
    
    private SQLExpression getColumnValue(final InsertStatement insertStatement, final AndCondition andCondition, final String columnName, final LiteralExpressionSegment expressionSegment) {
    private SQLExpression getColumnValue(final InsertStatement insertStatement, final AndCondition andCondition, final String columnName, final ExpressionSegment expressionSegment) {
        SQLExpression result = expressionSegment.getSQLExpression(insertStatement.getLogicSQL());
        String tableName = insertStatement.getTables().getSingleTableName();
        fillShardingCondition(andCondition, columnName, tableName, result);
Loading