Commit 87619c14 authored by terrymanu's avatar terrymanu
Browse files

refactor GenerateKeys for PreparedSQLRouter

parent 57769207
Loading
Loading
Loading
Loading
+28 −0
Original line number Diff line number Diff line
@@ -18,14 +18,19 @@
package com.dangdang.ddframe.rdb.sharding.rewrite;

import com.dangdang.ddframe.rdb.sharding.api.rule.ShardingRule;
import com.dangdang.ddframe.rdb.sharding.api.rule.TableRule;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ConditionContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.GeneratedKeyContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.InsertSQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ShardingColumnContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.expr.SQLNumberExpr;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.expr.SQLPlaceholderExpr;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.token.ItemsToken;
import com.google.common.base.Optional;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

/**
@@ -91,4 +96,27 @@ public final class GenerateKeysUtils {
        }
        return false;
    }
    
    /**
     * 获取自增主键.
     * 
     * @param shardingRule 分片规则
     * @param insertSQLContext 解析结果
     * @return 自增主键集合
     */
    public static List<Number> generateKeys(final ShardingRule shardingRule, final InsertSQLContext insertSQLContext) {
        Optional<TableRule> tableRuleOptional = shardingRule.tryFindTableRule(insertSQLContext.getTables().iterator().next().getName());
        if (!tableRuleOptional.isPresent()) {
            return Collections.emptyList();
        }
        TableRule tableRule = tableRuleOptional.get();
        GeneratedKeyContext generatedKeyContext = insertSQLContext.getGeneratedKeyContext();
        List<Number> result = new ArrayList<>(generatedKeyContext.getColumns().size());
        for (String each : generatedKeyContext.getColumns()) {
            Number generatedId = tableRule.generateId(each);
            result.add(generatedId);
            generatedKeyContext.putValue(each, generatedId);
        }
        return result;
    }
}
+3 −27
Original line number Diff line number Diff line
@@ -18,15 +18,11 @@
package com.dangdang.ddframe.rdb.sharding.router;

import com.dangdang.ddframe.rdb.sharding.api.rule.ShardingRule;
import com.dangdang.ddframe.rdb.sharding.api.rule.TableRule;
import com.dangdang.ddframe.rdb.sharding.jdbc.ShardingContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.GeneratedKeyContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.InsertSQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.SQLContext;
import com.google.common.base.Optional;
import com.dangdang.ddframe.rdb.sharding.rewrite.GenerateKeysUtils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
@@ -60,30 +56,10 @@ public final class PreparedSQLRouter {
    public SQLRouteResult route(final List<Object> parameters) {
        if (null == sqlContext) {
            sqlContext = routeEngine.parse(logicSQL, parameters);
        } else {
            List<Number> generatedIds = generateId();
            parameters.addAll(generatedIds);
        } else if (sqlContext instanceof InsertSQLContext) {
            parameters.addAll(GenerateKeysUtils.generateKeys(shardingRule, (InsertSQLContext) sqlContext));
        }
        return routeEngine.route(logicSQL, sqlContext, parameters);
    }
    
    private List<Number> generateId() {
        if (!(sqlContext instanceof InsertSQLContext)) {
            return Collections.emptyList();
        }
        Optional<TableRule> tableRuleOptional = shardingRule.tryFindTableRule(sqlContext.getTables().iterator().next().getName());
        if (!tableRuleOptional.isPresent()) {
            return Collections.emptyList();
        }
        TableRule tableRule = tableRuleOptional.get();
        GeneratedKeyContext generatedKeyContext = ((InsertSQLContext) sqlContext).getGeneratedKeyContext();
        List<Number> result = new ArrayList<>(generatedKeyContext.getColumns().size());
        for (String each : generatedKeyContext.getColumns()) {
            Number generatedId = tableRule.generateId(each);
            result.add(generatedId);
            generatedKeyContext.putValue(each, generatedId);
        }
        return result;
    }
}