Commit b784911c authored by terrymanu's avatar terrymanu
Browse files

for #2084, PredicateExtractor implements OptionalSQLSegmentExtractor

parent f904789a
Loading
Loading
Loading
Loading
+21 −26
Original line number Diff line number Diff line
@@ -23,6 +23,7 @@ import org.antlr.v4.runtime.ParserRuleContext;
import org.apache.shardingsphere.core.constant.ShardingOperator;
import org.apache.shardingsphere.core.parse.antlr.constant.LogicalOperator;
import org.apache.shardingsphere.core.parse.antlr.constant.Paren;
import org.apache.shardingsphere.core.parse.antlr.extractor.api.OptionalSQLSegmentExtractor;
import org.apache.shardingsphere.core.parse.antlr.extractor.impl.common.column.ColumnExtractor;
import org.apache.shardingsphere.core.parse.antlr.extractor.util.ExtractorUtils;
import org.apache.shardingsphere.core.parse.antlr.extractor.util.RuleName;
@@ -46,30 +47,24 @@ import java.util.Map;
 *
 * @author duhongjun
 */
public final class PredicateExtractor {
public final class PredicateExtractor implements OptionalSQLSegmentExtractor {
    
    private final ExpressionExtractor expressionExtractor = new ExpressionExtractor();
    
    private final ColumnExtractor columnExtractor = new ColumnExtractor();
    
    /**
     * Extract.
     *
     * @param parameterMarkerIndexes parameter marker indexes
     * @param exprNode expression node of AST
     * @return or condition
     */
    public Optional<OrPredicateSegment> extract(final Map<ParserRuleContext, Integer> parameterMarkerIndexes, final ParserRuleContext exprNode) {
        return extractConditionInternal(parameterMarkerIndexes, exprNode);
    @Override
    public Optional<OrPredicateSegment> extract(final ParserRuleContext exprNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes) {
        return extractConditionInternal(exprNode, parameterMarkerIndexes);
    }
    
    private Optional<OrPredicateSegment> extractConditionInternal(final Map<ParserRuleContext, Integer> parameterMarkerIndexes, final ParserRuleContext exprNode) {
    private Optional<OrPredicateSegment> extractConditionInternal(final ParserRuleContext exprNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes) {
        Optional<Integer> index = getLogicalOperatorIndex(exprNode);
        if (!index.isPresent()) {
            return extractConditionForParen(parameterMarkerIndexes, exprNode);
            return extractConditionForParen(exprNode, parameterMarkerIndexes);
        }
        Optional<OrPredicateSegment> leftOrCondition = extractConditionInternal(parameterMarkerIndexes, (ParserRuleContext) exprNode.getChild(index.get() - 1));
        Optional<OrPredicateSegment> rightOrCondition = extractConditionInternal(parameterMarkerIndexes, (ParserRuleContext) exprNode.getChild(index.get() + 1));
        Optional<OrPredicateSegment> leftOrCondition = extractConditionInternal((ParserRuleContext) exprNode.getChild(index.get() - 1), parameterMarkerIndexes);
        Optional<OrPredicateSegment> rightOrCondition = extractConditionInternal((ParserRuleContext) exprNode.getChild(index.get() + 1), parameterMarkerIndexes);
        if (leftOrCondition.isPresent() && rightOrCondition.isPresent()) {
            return Optional.of(mergePredicate(leftOrCondition.get(), rightOrCondition.get(), exprNode.getChild(index.get()).getText()));
        }
@@ -85,15 +80,15 @@ public final class PredicateExtractor {
        return Optional.absent();
    }
    
    private Optional<OrPredicateSegment> extractConditionForParen(final Map<ParserRuleContext, Integer> parameterMarkerIndexes, final ParserRuleContext exprNode) {
    private Optional<OrPredicateSegment> extractConditionForParen(final ParserRuleContext exprNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes) {
        Optional<Integer> index = getLeftParenIndex(exprNode);
        if (index.isPresent()) {
            if (RuleName.EXPR.getName().equals(exprNode.getChild(index.get() + 1).getClass().getSimpleName())) {
                return extractConditionInternal(parameterMarkerIndexes, (ParserRuleContext) exprNode.getChild(index.get() + 1));
                return extractConditionInternal((ParserRuleContext) exprNode.getChild(index.get() + 1), parameterMarkerIndexes);
            }
            return Optional.absent();
        }
        Optional<PredicateSegment> predicate = extractPredicate(parameterMarkerIndexes, exprNode);
        Optional<PredicateSegment> predicate = extractPredicate(exprNode, parameterMarkerIndexes);
        if (!predicate.isPresent()) {
            return Optional.absent();
        }
@@ -113,8 +108,8 @@ public final class PredicateExtractor {
        return Optional.absent();
    }
    
    private Optional<PredicateSegment> extractPredicate(final Map<ParserRuleContext, Integer> parameterMarkerIndexes, final ParserRuleContext exprNode) {
        Optional<PredicateSegment> result = extractComparisonPredicate(parameterMarkerIndexes, exprNode);
    private Optional<PredicateSegment> extractPredicate(final ParserRuleContext exprNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes) {
        Optional<PredicateSegment> result = extractComparisonPredicate(exprNode, parameterMarkerIndexes);
        if (result.isPresent()) {
            return result;
        }
@@ -127,13 +122,13 @@ public final class PredicateExtractor {
            return Optional.absent();
        }
        if (5 == predicateNode.get().getChildCount() && "BETWEEN".equalsIgnoreCase(predicateNode.get().getChild(1).getText())) {
            result = extractBetweenPredicate(parameterMarkerIndexes, predicateNode.get(), column.get());
            result = extractBetweenPredicate(predicateNode.get(), parameterMarkerIndexes, column.get());
            if (result.isPresent()) {
                return result;
            }
        }
        if (predicateNode.get().getChildCount() >= 5 && "IN".equalsIgnoreCase(predicateNode.get().getChild(1).getText())) {
            result = extractInPredicate(parameterMarkerIndexes, predicateNode.get(), column.get());
            result = extractInPredicate(predicateNode.get(), parameterMarkerIndexes, column.get());
            if (result.isPresent()) {
                return result;
            }
@@ -141,7 +136,7 @@ public final class PredicateExtractor {
        return Optional.absent();
    }
    
    private Optional<PredicateSegment> extractComparisonPredicate(final Map<ParserRuleContext, Integer> parameterMarkerIndexes, final ParserRuleContext exprNode) {
    private Optional<PredicateSegment> extractComparisonPredicate(final ParserRuleContext exprNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes) {
        Optional<ParserRuleContext> comparisonOperatorNode = ExtractorUtils.findFirstChildNode(exprNode, RuleName.COMPARISON_OPERATOR);
        if (!comparisonOperatorNode.isPresent()) {
            return Optional.absent();
@@ -162,7 +157,7 @@ public final class PredicateExtractor {
                new CompareValueExpressionSegment(sqlExpression.get(), compareOperator), booleanPrimaryNode.getStop().getStopIndex())) : Optional.<PredicateSegment>absent();
    }
    
    private Optional<PredicateSegment> extractBetweenPredicate(final Map<ParserRuleContext, Integer> parameterMarkerIndexes, final ParserRuleContext predicateNode, final ColumnSegment column) {
    private Optional<PredicateSegment> extractBetweenPredicate(final ParserRuleContext predicateNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes, final ColumnSegment column) {
        Optional<? extends ExpressionSegment> beginSQLExpression = expressionExtractor.extract(parameterMarkerIndexes, (ParserRuleContext) predicateNode.getChild(2));
        Optional<? extends ExpressionSegment> endSQLExpression = expressionExtractor.extract(parameterMarkerIndexes, (ParserRuleContext) predicateNode.getChild(4));
        return beginSQLExpression.isPresent() && endSQLExpression.isPresent()
@@ -171,13 +166,13 @@ public final class PredicateExtractor {
                : Optional.<PredicateSegment>absent();
    }
    
    private Optional<PredicateSegment> extractInPredicate(final Map<ParserRuleContext, Integer> parameterMarkerIndexes, final ParserRuleContext predicateNode, final ColumnSegment column) {
        Collection<ExpressionSegment> sqlExpressions = extractExpressionSegments(parameterMarkerIndexes, predicateNode);
    private Optional<PredicateSegment> extractInPredicate(final ParserRuleContext predicateNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes, final ColumnSegment column) {
        Collection<ExpressionSegment> sqlExpressions = extractExpressionSegments(predicateNode, parameterMarkerIndexes);
        return sqlExpressions.isEmpty() ? Optional.<PredicateSegment>absent()
                : Optional.of(new PredicateSegment(column, ShardingOperator.IN.name(), new InValueExpressionSegment(sqlExpressions), predicateNode.getStop().getStopIndex()));
    }
    
    private Collection<ExpressionSegment> extractExpressionSegments(final Map<ParserRuleContext, Integer> parameterMarkerIndexes, final ParserRuleContext predicateNode) {
    private Collection<ExpressionSegment> extractExpressionSegments(final ParserRuleContext predicateNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes) {
        List<ExpressionSegment> result = new LinkedList<>();
        for (int i = 3; i < predicateNode.getChildCount(); i++) {
            if (RuleName.EXPR.getName().equals(predicateNode.getChild(i).getClass().getSimpleName())) {
+5 −5
Original line number Diff line number Diff line
@@ -44,8 +44,8 @@ public final class WhereExtractor implements OptionalSQLSegmentExtractor {
        result.setParameterCount(parameterMarkerIndexes.size());
        Optional<ParserRuleContext> whereNode = ExtractorUtils.findFirstChildNodeNoneRecursive(ancestorNode, RuleName.WHERE_CLAUSE);
        if (whereNode.isPresent()) {
            setPropertiesForRevert(result, parameterMarkerIndexes, whereNode.get());
            Optional<OrPredicateSegment> orConditionSegment = extractOrConditionSegment(parameterMarkerIndexes, whereNode.get());
            setPropertiesForRevert(result, whereNode.get(), parameterMarkerIndexes);
            Optional<OrPredicateSegment> orConditionSegment = extractOrConditionSegment(whereNode.get(), parameterMarkerIndexes);
            if (orConditionSegment.isPresent()) {
                result.getOrPredicate().getAndPredicates().addAll(orConditionSegment.get().getAndPredicates());
            }
@@ -53,7 +53,7 @@ public final class WhereExtractor implements OptionalSQLSegmentExtractor {
        return Optional.of(result);
    }
    
    private void setPropertiesForRevert(final WhereSegment whereSegment, final Map<ParserRuleContext, Integer> parameterMarkerIndexes, final ParserRuleContext whereNode) {
    private void setPropertiesForRevert(final WhereSegment whereSegment, final ParserRuleContext whereNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes) {
        whereSegment.setWhereStartIndex(whereNode.getStart().getStartIndex());
        whereSegment.setWhereStopIndex(whereNode.getStop().getStopIndex());
        if (parameterMarkerIndexes.isEmpty()) {
@@ -68,8 +68,8 @@ public final class WhereExtractor implements OptionalSQLSegmentExtractor {
        whereSegment.setWhereParameterEndIndex(whereParameterStartIndex + questionNodes.size() - 1);
    }
    
    private Optional<OrPredicateSegment> extractOrConditionSegment(final Map<ParserRuleContext, Integer> parameterMarkerIndexes, final ParserRuleContext whereNode) {
    private Optional<OrPredicateSegment> extractOrConditionSegment(final ParserRuleContext whereNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes) {
        Optional<ParserRuleContext> exprNode = ExtractorUtils.findFirstChildNode((ParserRuleContext) whereNode.getChild(1), RuleName.EXPR);
        return exprNode.isPresent() ? predicateExtractor.extract(parameterMarkerIndexes, exprNode.get()) : Optional.<OrPredicateSegment>absent();
        return exprNode.isPresent() ? predicateExtractor.extract(exprNode.get(), parameterMarkerIndexes) : Optional.<OrPredicateSegment>absent();
    }
}