Commit aaf5b7a4 authored by terrymanu's avatar terrymanu
Browse files

refactor SubqueryExtractor

parent 4cd5a8a3
Loading
Loading
Loading
Loading
+24 −19
Original line number Diff line number Diff line
@@ -62,18 +62,23 @@ public final class FromWhereExtractor implements OptionalSQLSegmentExtractor {
        predicateSegmentExtractor = new PredicateExtractor(result.getTableAliases());
        Collection<ParserRuleContext> questionNodes = ExtractorUtils.getAllDescendantNodes(ancestorNode, RuleName.QUESTION);
        result.setParameterCount(questionNodes.size());
        Map<ParserRuleContext, Integer> questionNodeIndexMap = new HashMap<>();
        int index = 0;
        for (ParserRuleContext each : questionNodes) {
            questionNodeIndexMap.put(each, index++);
        }
        Map<ParserRuleContext, Integer> questionNodeIndexMap = getQuestionNodeIndexMap(questionNodes);
        extractAndFillTableSegment(result, tableReferenceNodes, questionNodeIndexMap);
        extractAndFillWhere(result, questionNodeIndexMap, fromNode.get().getParent());
        return Optional.of(result);
    }
    
    private void extractAndFillTableSegment(final FromWhereSegment fromWhereSegment, final Collection<ParserRuleContext> tableReferenceNodes,
                                            final Map<ParserRuleContext, Integer> questionNodeIndexMap) {
    private Map<ParserRuleContext, Integer> getQuestionNodeIndexMap(final Collection<ParserRuleContext> questionNodes) {
        Map<ParserRuleContext, Integer> result = new HashMap<>(questionNodes.size(), 1);
        int index = 0;
        for (ParserRuleContext each : questionNodes) {
            result.put(each, index++);
        }
        return result;
    }
    
    private void extractAndFillTableSegment(final FromWhereSegment fromWhereSegment, 
                                            final Collection<ParserRuleContext> tableReferenceNodes, final Map<ParserRuleContext, Integer> questionNodeIndexMap) {
        for (ParserRuleContext each : tableReferenceNodes) {
            for (int i = 0; i < each.getChildCount(); i++) {
                if (each.getChild(i) instanceof TerminalNode) {
@@ -81,13 +86,13 @@ public final class FromWhereExtractor implements OptionalSQLSegmentExtractor {
                }
                ParserRuleContext childNode = (ParserRuleContext) each.getChild(i);
                if (RuleName.TABLE_REFERENCES.getName().equals(childNode.getClass().getSimpleName())) {
                    final Collection<ParserRuleContext> subTableReferenceNodes = ExtractorUtils.getAllDescendantNodes(childNode, RuleName.TABLE_REFERENCE);
                    Collection<ParserRuleContext> subTableReferenceNodes = ExtractorUtils.getAllDescendantNodes(childNode, RuleName.TABLE_REFERENCE);
                    if (!subTableReferenceNodes.isEmpty()) {
                        extractAndFillTableSegment(fromWhereSegment, subTableReferenceNodes, questionNodeIndexMap);
                    }
                    continue;
                }
                if (RuleName.TABLE_FACTOR.getName().equals(childNode.getClass().getSimpleName()) && fillSubquery(fromWhereSegment, childNode)) {
                if (RuleName.TABLE_FACTOR.getName().equals(childNode.getClass().getSimpleName()) && fillSubQuery(fromWhereSegment, childNode)) {
                    continue;
                }
                fillTable(fromWhereSegment, childNode, questionNodeIndexMap);
@@ -95,14 +100,14 @@ public final class FromWhereExtractor implements OptionalSQLSegmentExtractor {
        }
    }
    
    private boolean fillSubquery(final FromWhereSegment fromWhereSegment, final ParserRuleContext tableFactorNode) {
    private boolean fillSubQuery(final FromWhereSegment fromWhereSegment, final ParserRuleContext tableFactorNode) {
        Optional<ParserRuleContext> subqueryNode = ExtractorUtils.findFirstChildNode(tableFactorNode, RuleName.SUBQUERY);
        if (!subqueryNode.isPresent()) {
            return false;
        }
        Optional<SubquerySegment> result = new SubqueryExtractor().extract(subqueryNode.get());
        if (result.isPresent()) {
            fromWhereSegment.getSubquerys().add(result.get());
            fromWhereSegment.getSubQuerys().add(result.get());
        }
        return true;
    }
@@ -134,14 +139,6 @@ public final class FromWhereExtractor implements OptionalSQLSegmentExtractor {
        fromWhereSegment.getTableAliases().put(alias, tableSegment.getName());
    }
    
    private Optional<OrConditionSegment> buildCondition(final ParserRuleContext node, final Map<ParserRuleContext, Integer> questionNodeIndexMap, final Map<String, String> tableAliases) {
        Optional<ParserRuleContext> exprNode = ExtractorUtils.findFirstChildNode(node, RuleName.EXPR);
        if (exprNode.isPresent()) {
            return predicateSegmentExtractor.extractCondition(questionNodeIndexMap, exprNode.get());
        }
        return Optional.absent();
    }
    
    private void extractAndFillWhere(final FromWhereSegment fromWhereSegment, final Map<ParserRuleContext, Integer> questionNodeIndexMap, final ParserRuleContext ancestorNode) {
        Optional<ParserRuleContext> whereNode = ExtractorUtils.findFirstChildNodeNoneRecursive(ancestorNode, RuleName.WHERE_CLAUSE);
        if (!whereNode.isPresent()) {
@@ -152,4 +149,12 @@ public final class FromWhereExtractor implements OptionalSQLSegmentExtractor {
            fromWhereSegment.getConditions().getAndConditions().addAll(conditions.get().getAndConditions());
        }
    }
    
    private Optional<OrConditionSegment> buildCondition(final ParserRuleContext node, final Map<ParserRuleContext, Integer> questionNodeIndexMap, final Map<String, String> tableAliases) {
        Optional<ParserRuleContext> exprNode = ExtractorUtils.findFirstChildNode(node, RuleName.EXPR);
        if (exprNode.isPresent()) {
            return predicateSegmentExtractor.extractCondition(questionNodeIndexMap, exprNode.get());
        }
        return Optional.absent();
    }
}
+24 −6
Original line number Diff line number Diff line
@@ -35,6 +35,12 @@ import org.antlr.v4.runtime.ParserRuleContext;
 */
public final class SubqueryExtractor implements OptionalSQLSegmentExtractor {
    
    private final FromWhereExtractor fromWhereExtractor = new FromWhereExtractor();
    
    private final GroupByExtractor groupByExtractor = new GroupByExtractor();
    
    private final OrderByExtractor orderByExtractor = new OrderByExtractor();
    
    @Override
    public Optional<SubquerySegment> extract(final ParserRuleContext subqueryNode) {
        if (!RuleName.SUBQUERY.getName().endsWith(subqueryNode.getClass().getSimpleName())) {
@@ -49,15 +55,27 @@ public final class SubqueryExtractor implements OptionalSQLSegmentExtractor {
            }
            parentNode = parentNode.getParent();
        }
        SubquerySegment result = new SubquerySegment(subqueryInFrom);
        Optional<SelectClauseSegment> selectClauseSegment = new SelectClauseExtractor().extract(subqueryNode);
        Optional<FromWhereSegment> fromWhereSegment = new FromWhereExtractor().extract(subqueryNode);
        Optional<GroupBySegment> groupBySegment = new GroupByExtractor().extract(subqueryNode);
        Optional<OrderBySegment> orderBySegment = new OrderByExtractor().extract(subqueryNode);
        if (selectClauseSegment.isPresent()) {
            result.setSelectClauseSegment(selectClauseSegment.get());
        }
        Optional<FromWhereSegment> fromWhereSegment = fromWhereExtractor.extract(subqueryNode);
        if (fromWhereSegment.isPresent()) {
            result.setFromWhereSegment(fromWhereSegment.get());
        }
        Optional<GroupBySegment> groupBySegment = groupByExtractor.extract(subqueryNode);
        if (groupBySegment.isPresent()) {
            result.setGroupBySegment(groupBySegment.get());
        }
        Optional<OrderBySegment> orderBySegment = orderByExtractor.extract(subqueryNode);
        if (orderBySegment.isPresent()) {
            result.setOrderBySegment(orderBySegment.get());
        }
        Optional<ParserRuleContext> aliasNode = ExtractorUtils.findFirstChildNode(subqueryNode.getParent(), RuleName.ALIAS);
        Optional<String> alias = Optional.absent();
        if (aliasNode.isPresent()) {
            alias = Optional.of(aliasNode.get().getText());
            result.setAlias(aliasNode.get().getText());
        }
        return Optional.of(new SubquerySegment(selectClauseSegment, fromWhereSegment, groupBySegment, orderBySegment, alias, subqueryInFrom));
        return Optional.of(result);
    }
}
+2 −2
Original line number Diff line number Diff line
@@ -67,8 +67,8 @@ public final class FromWhereFiller implements SQLStatementFiller<FromWhereSegmen
            OrCondition orCondition = filterShardingCondition(sqlStatement, sqlSegment.getConditions(), sql, shardingRule, columnNameToTable, columnNameCount, shardingTableMetaData);
            sqlStatement.getConditions().getOrCondition().getAndConditions().addAll(orCondition.getAndConditions());
        }
        if (!sqlSegment.getSubquerys().isEmpty()) {
            for (SubquerySegment each : sqlSegment.getSubquerys()) {
        if (!sqlSegment.getSubQuerys().isEmpty()) {
            for (SubquerySegment each : sqlSegment.getSubQuerys()) {
                new SubqueryFiller().fill(each, sqlStatement, sql, shardingRule, shardingTableMetaData);
            }
        }
+6 −6
Original line number Diff line number Diff line
@@ -17,16 +17,16 @@

package io.shardingsphere.core.parsing.antlr.sql.segment;

import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;

import io.shardingsphere.core.parsing.antlr.sql.segment.condition.OrConditionSegment;
import io.shardingsphere.core.parsing.antlr.sql.segment.expr.SubquerySegment;
import lombok.Getter;
import lombok.Setter;

import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;

/**
 * From where segment.
 * 
@@ -39,7 +39,7 @@ public final class FromWhereSegment implements SQLSegment {
    
    private final OrConditionSegment conditions = new OrConditionSegment();
    
    private final Collection<SubquerySegment> subquerys = new LinkedList<>();
    private final Collection<SubquerySegment> subQuerys = new LinkedList<>();
    
    @Setter
    private Integer parameterCount;
+43 −8
Original line number Diff line number Diff line
@@ -18,13 +18,13 @@
package io.shardingsphere.core.parsing.antlr.sql.segment.expr;

import com.google.common.base.Optional;

import io.shardingsphere.core.parsing.antlr.sql.segment.FromWhereSegment;
import io.shardingsphere.core.parsing.antlr.sql.segment.SelectClauseSegment;
import io.shardingsphere.core.parsing.antlr.sql.segment.order.GroupBySegment;
import io.shardingsphere.core.parsing.antlr.sql.segment.order.OrderBySegment;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;

/**
 * Subquery expression segment.
@@ -33,17 +33,52 @@ import lombok.RequiredArgsConstructor;
 */
@RequiredArgsConstructor
@Getter
public final class SubquerySegment implements ExpressionSegment {
@Setter
public final class SubquerySegment extends ExpressionWithAliasSegment {
    
    private final Optional<SelectClauseSegment> selectClauseSegment;
    private final boolean subqueryInFrom;
    
    private final Optional<FromWhereSegment> fromWhereSegment;
    private SelectClauseSegment selectClauseSegment;
    
    private final Optional<GroupBySegment> groupBySegment;
    private FromWhereSegment fromWhereSegment;
    
    private final Optional<OrderBySegment> orderBySegment;
    private GroupBySegment groupBySegment;
    
    private final Optional<String> alias;
    private OrderBySegment orderBySegment;
    
    private final boolean subqueryInFrom;
    /**
     * Get select clause segment.
     * 
     * @return select clause segment
     */
    public Optional<SelectClauseSegment> getSelectClauseSegment() {
        return Optional.fromNullable(selectClauseSegment);
    }
    
    /**
     * Get from where segment.
     *
     * @return from where segment
     */
    public Optional<FromWhereSegment> getFromWhereSegment() {
        return Optional.fromNullable(fromWhereSegment);
    }
    
    /**
     * Get group by segment.
     *
     * @return group by segment
     */
    public Optional<GroupBySegment> getGroupBySegment() {
        return Optional.fromNullable(groupBySegment);
    }
    
    /**
     * Get order by segment.
     *
     * @return order by segment
     */
    public Optional<OrderBySegment> getOrderBySegment() {
        return Optional.fromNullable(orderBySegment);
    }
}