Commit ed7d99d8 authored by terrymanu's avatar terrymanu
Browse files

for #2084, add PlaceholderIndexesAware

parent 7069ea46
Loading
Loading
Loading
Loading
+22 −2
Original line number Diff line number Diff line
@@ -18,14 +18,21 @@
package org.apache.shardingsphere.core.parse.antlr.extractor;

import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import org.antlr.v4.runtime.ParserRuleContext;
import org.apache.shardingsphere.core.parse.antlr.extractor.api.CollectionSQLSegmentExtractor;
import org.apache.shardingsphere.core.parse.antlr.extractor.api.OptionalSQLSegmentExtractor;
import org.apache.shardingsphere.core.parse.antlr.extractor.api.PlaceholderIndexesAware;
import org.apache.shardingsphere.core.parse.antlr.extractor.api.SQLSegmentExtractor;
import org.apache.shardingsphere.core.parse.antlr.extractor.util.ExtractorUtils;
import org.apache.shardingsphere.core.parse.antlr.extractor.util.RuleName;
import org.apache.shardingsphere.core.parse.antlr.parser.SQLAST;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.SQLSegment;

import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;

/**
 * SQL segments extractor engine.
@@ -42,17 +49,30 @@ public final class SQLSegmentsExtractorEngine {
     */
    public Collection<SQLSegment> extract(final SQLAST ast) {
        Collection<SQLSegment> result = new LinkedList<>();
        Preconditions.checkState(ast.getSQLStatementRule().isPresent());
        for (SQLSegmentExtractor each : ast.getSQLStatementRule().get().getExtractors()) {
            if (each instanceof PlaceholderIndexesAware) {
                ((PlaceholderIndexesAware) each).setPlaceholderIndexes(getPlaceholderIndexes(ast.getParserRuleContext()));
            }
            if (each instanceof OptionalSQLSegmentExtractor) {
                Optional<? extends SQLSegment> sqlSegment = ((OptionalSQLSegmentExtractor) each).extract(ast.getParserRuleContext());
                if (sqlSegment.isPresent()) {
                    result.add(sqlSegment.get());
                }
            }
            if (each instanceof CollectionSQLSegmentExtractor) {
            } else if (each instanceof CollectionSQLSegmentExtractor) {
                result.addAll(((CollectionSQLSegmentExtractor) each).extract(ast.getParserRuleContext()));
            }
        }
        return result;
    }
    
    private Map<ParserRuleContext, Integer> getPlaceholderIndexes(final ParserRuleContext rootNode) {
        Collection<ParserRuleContext> placeholderNodes = ExtractorUtils.getAllDescendantNodes(rootNode, RuleName.QUESTION);
        Map<ParserRuleContext, Integer> result = new HashMap<>(placeholderNodes.size(), 1);
        int index = 0;
        for (ParserRuleContext each : placeholderNodes) {
            result.put(each, index++);
        }
        return result;
    }
}
+37 −0
Original line number Diff line number Diff line
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.shardingsphere.core.parse.antlr.extractor.api;

import org.antlr.v4.runtime.ParserRuleContext;

import java.util.Map;

/**
 * Placeholder indexes aware.
 *
 * @author zhangliang
 */
public interface PlaceholderIndexesAware {
    
    /**
     * Set placeholder indexes.
     * 
     * @param placeholderIndexes placeholder indexes
     */
    void setPlaceholderIndexes(Map<ParserRuleContext, Integer> placeholderIndexes);
}
+6 −13
Original line number Diff line number Diff line
@@ -18,15 +18,16 @@
package org.apache.shardingsphere.core.parse.antlr.extractor.impl.dml;

import com.google.common.base.Optional;
import lombok.Setter;
import org.antlr.v4.runtime.ParserRuleContext;
import org.apache.shardingsphere.core.parse.antlr.extractor.api.OptionalSQLSegmentExtractor;
import org.apache.shardingsphere.core.parse.antlr.extractor.api.PlaceholderIndexesAware;
import org.apache.shardingsphere.core.parse.antlr.extractor.util.ExtractorUtils;
import org.apache.shardingsphere.core.parse.antlr.extractor.util.RuleName;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.assignment.SetAssignmentsSegment;

import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;

@@ -35,10 +36,13 @@ import java.util.Map;
 *
 * @author zhangliang
 */
public final class SetAssignmentsExtractor implements OptionalSQLSegmentExtractor {
@Setter
public final class SetAssignmentsExtractor implements OptionalSQLSegmentExtractor, PlaceholderIndexesAware {
    
    private final AssignmentExtractor assignmentExtractor = new AssignmentExtractor();
    
    private Map<ParserRuleContext, Integer> placeholderIndexes;
            
    @Override
    public Optional<SetAssignmentsSegment> extract(final ParserRuleContext ancestorNode) {
        Optional<ParserRuleContext> setAssignmentsClauseNode = ExtractorUtils.findFirstChildNode(ancestorNode, RuleName.SET_ASSIGNMENTS_CLAUSE);
@@ -46,7 +50,6 @@ public final class SetAssignmentsExtractor implements OptionalSQLSegmentExtracto
            return Optional.absent();
        }
        Collection<AssignmentSegment> assignmentSegments = new LinkedList<>();
        Map<ParserRuleContext, Integer> placeholderIndexes = getPlaceholderIndexes(ancestorNode);
        for (ParserRuleContext each : ExtractorUtils.getAllDescendantNodes(ancestorNode, RuleName.ASSIGNMENT)) {
            Optional<AssignmentSegment> assignmentSegment = assignmentExtractor.extract(placeholderIndexes, each);
            if (assignmentSegment.isPresent()) {
@@ -55,14 +58,4 @@ public final class SetAssignmentsExtractor implements OptionalSQLSegmentExtracto
        }
        return Optional.of(new SetAssignmentsSegment(setAssignmentsClauseNode.get().getStart().getStartIndex(), assignmentSegments));
    }
    
    private Map<ParserRuleContext, Integer> getPlaceholderIndexes(final ParserRuleContext rootNode) {
        Collection<ParserRuleContext> placeholderNodes = ExtractorUtils.getAllDescendantNodes(rootNode, RuleName.QUESTION);
        Map<ParserRuleContext, Integer> result = new HashMap<>(placeholderNodes.size(), 1);
        int index = 0;
        for (ParserRuleContext each : placeholderNodes) {
            result.put(each, index++);
        }
        return result;
    }
}
+6 −13
Original line number Diff line number Diff line
@@ -18,8 +18,10 @@
package org.apache.shardingsphere.core.parse.antlr.extractor.impl.dml.insert;

import com.google.common.base.Optional;
import lombok.Setter;
import org.antlr.v4.runtime.ParserRuleContext;
import org.apache.shardingsphere.core.parse.antlr.extractor.api.CollectionSQLSegmentExtractor;
import org.apache.shardingsphere.core.parse.antlr.extractor.api.PlaceholderIndexesAware;
import org.apache.shardingsphere.core.parse.antlr.extractor.impl.dml.ExpressionExtractor;
import org.apache.shardingsphere.core.parse.antlr.extractor.util.ExtractorUtils;
import org.apache.shardingsphere.core.parse.antlr.extractor.util.RuleName;
@@ -28,7 +30,6 @@ import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.CommonExp

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;

@@ -37,10 +38,13 @@ import java.util.Map;
 *
 * @author zhangliang
 */
public final class InsertValuesExtractor implements CollectionSQLSegmentExtractor {
@Setter
public final class InsertValuesExtractor implements CollectionSQLSegmentExtractor, PlaceholderIndexesAware {
    
    private final ExpressionExtractor expressionExtractor = new ExpressionExtractor();
    
    private Map<ParserRuleContext, Integer> placeholderIndexes;
    
    @Override
    public Collection<InsertValuesSegment> extract(final ParserRuleContext ancestorNode) {
        Optional<ParserRuleContext> insertValuesClauseNode = ExtractorUtils.findFirstChildNode(ancestorNode, RuleName.INSERT_VALUES_CLAUSE);
@@ -48,23 +52,12 @@ public final class InsertValuesExtractor implements CollectionSQLSegmentExtracto
            return Collections.emptyList();
        }
        Collection<InsertValuesSegment> result = new LinkedList<>();
        Map<ParserRuleContext, Integer> placeholderIndexes = getPlaceholderIndexes(ancestorNode);
        for (ParserRuleContext each : ExtractorUtils.getAllDescendantNodes(insertValuesClauseNode.get(), RuleName.ASSIGNMENT_VALUES)) {
            result.add(new InsertValuesSegment(extractCommonExpressionSegments(each, placeholderIndexes)));
        }
        return result;
    }
    
    private Map<ParserRuleContext, Integer> getPlaceholderIndexes(final ParserRuleContext rootNode) {
        Collection<ParserRuleContext> placeholderNodes = ExtractorUtils.getAllDescendantNodes(rootNode, RuleName.QUESTION);
        Map<ParserRuleContext, Integer> result = new HashMap<>(placeholderNodes.size(), 1);
        int index = 0;
        for (ParserRuleContext each : placeholderNodes) {
            result.put(each, index++);
        }
        return result;
    }
    
    private Collection<CommonExpressionSegment> extractCommonExpressionSegments(final ParserRuleContext assignmentValuesNode, final Map<ParserRuleContext, Integer> placeholderIndexes) {
        Collection<CommonExpressionSegment> result = new LinkedList<>();
        for (ParserRuleContext each : ExtractorUtils.getAllDescendantNodes(assignmentValuesNode, RuleName.ASSIGNMENT_VALUE)) {
+6 −13
Original line number Diff line number Diff line
@@ -18,8 +18,10 @@
package org.apache.shardingsphere.core.parse.antlr.extractor.impl.dml.select;

import com.google.common.base.Optional;
import lombok.Setter;
import org.antlr.v4.runtime.ParserRuleContext;
import org.apache.shardingsphere.core.parse.antlr.extractor.api.OptionalSQLSegmentExtractor;
import org.apache.shardingsphere.core.parse.antlr.extractor.api.PlaceholderIndexesAware;
import org.apache.shardingsphere.core.parse.antlr.extractor.impl.dml.PredicateExtractor;
import org.apache.shardingsphere.core.parse.antlr.extractor.util.ExtractorUtils;
import org.apache.shardingsphere.core.parse.antlr.extractor.util.RuleName;
@@ -27,7 +29,6 @@ import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.WhereSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.condition.OrConditionSegment;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

/**
@@ -35,10 +36,13 @@ import java.util.Map;
 *
 * @author duhongjun
 */
public abstract class AbstractWhereExtractor implements OptionalSQLSegmentExtractor {
@Setter
public abstract class AbstractWhereExtractor implements OptionalSQLSegmentExtractor, PlaceholderIndexesAware {
    
    private final PredicateExtractor predicateExtractor = new PredicateExtractor();
    
    private Map<ParserRuleContext, Integer> placeholderIndexes;
    
    @Override
    public final Optional<WhereSegment> extract(final ParserRuleContext ancestorNode) {
        return extract(ancestorNode, ancestorNode);
@@ -53,7 +57,6 @@ public abstract class AbstractWhereExtractor implements OptionalSQLSegmentExtrac
     */
    public Optional<WhereSegment> extract(final ParserRuleContext ancestorNode, final ParserRuleContext rootNode) {
        WhereSegment result = new WhereSegment();
        Map<ParserRuleContext, Integer> placeholderIndexes = getPlaceholderIndexes(rootNode);
        result.setParameterCount(placeholderIndexes.size());
        Optional<ParserRuleContext> whereNode = extractWhere(ancestorNode);
        if (whereNode.isPresent()) {
@@ -63,16 +66,6 @@ public abstract class AbstractWhereExtractor implements OptionalSQLSegmentExtrac
        return Optional.of(result);
    }
    
    private Map<ParserRuleContext, Integer> getPlaceholderIndexes(final ParserRuleContext rootNode) {
        Collection<ParserRuleContext> placeholderNodes = ExtractorUtils.getAllDescendantNodes(rootNode, RuleName.QUESTION);
        Map<ParserRuleContext, Integer> result = new HashMap<>(placeholderNodes.size(), 1);
        int index = 0;
        for (ParserRuleContext each : placeholderNodes) {
            result.put(each, index++);
        }
        return result;
    }
    
    protected abstract Optional<ParserRuleContext> extractWhere(ParserRuleContext ancestorNode);
    
    private void setPropertiesForRevert(final WhereSegment whereSegment, final Map<ParserRuleContext, Integer> placeholderIndexes, final ParserRuleContext whereNode) {
Loading