Commit fc8dc428 authored by gaohongtao's avatar gaohongtao
Browse files

add to set sharding value using hint . Use HintShardingValueManger to set hint

parent 62407bd3
Loading
Loading
Loading
Loading
+166 −0
Original line number Diff line number Diff line
/**
 * Copyright 1999-2015 dangdang.com.
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * </p>
 */

package com.dangdang.ddframe.rdb.sharding.api;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import com.dangdang.ddframe.rdb.sharding.exception.ShardingJdbcException;
import com.dangdang.ddframe.rdb.sharding.parser.result.router.Condition;
import com.dangdang.ddframe.rdb.sharding.router.single.SingleRouterUtil;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;

/**
 * 通过线索传递分片值的管理器.
 *
 * @author gaohongtao
 */
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public final class HintShardingValueManager {
    
    private static final ThreadLocal<ShardingValueContainer> SHARING_VALUE_CONTAINER = new ThreadLocal<>();
    
    /**
     * 初始化容器.
     */
    public static void init() {
        if (null != SHARING_VALUE_CONTAINER.get()) {
            throw new ShardingJdbcException("CAN NOT register sharding value repeatedly");
        }
        SHARING_VALUE_CONTAINER.set(new ShardingValueContainer());
    }
    
    /**
     * 注册分库分片值.
     * 
     * @param logicTable 逻辑表明
     * @param shardingColumn 分库键
     * @param values 分库值
     */
    public static void registerShardingValueOfDatabase(final String logicTable, final String shardingColumn, final Comparable<?>... values) {
        registerShardingValueOfDatabase(logicTable, shardingColumn, Condition.BinaryOperator.EQUAL, values);
    }
    
    /**
     * 注册分库分片值.
     * 
     * @param logicTable 逻辑表明
     * @param shardingColumn 分库键
     * @param binaryOperator 分库操作符
     * @param values 分库值
     */
    public static void registerShardingValueOfDatabase(final String logicTable, final String shardingColumn, final Condition.BinaryOperator binaryOperator, final Comparable<?>... values) {
        if (null == SHARING_VALUE_CONTAINER.get()) {
            throw new ShardingJdbcException("Please first invoke HintShardingValueManager.init()");
        }
        registerShardingValue(SHARING_VALUE_CONTAINER.get().databaseShardingValues, logicTable, shardingColumn, binaryOperator, values);
    }
    
    /**
     * 获取分库键值.
     * 
     * @param logicTable 逻辑表名
     * @return 分库键值
     */
    public static Optional<List<ShardingValue<?>>> getShardingValueOfDatabase(final String logicTable) {
        if (null == SHARING_VALUE_CONTAINER.get()) {
            return Optional.absent();
        }
        return Optional.fromNullable(SHARING_VALUE_CONTAINER.get().databaseShardingValues.get(logicTable));
    }
    
    /**
     * 注册分表分片值.
     * 
     * @param logicTable 逻辑表明
     * @param shardingColumn 分库键
     * @param values 分库值
     */
    public static void registerShardingValueOfTable(final String logicTable, final String shardingColumn, final Comparable<?>... values) {
        registerShardingValueOfTable(logicTable, shardingColumn, Condition.BinaryOperator.EQUAL, values);
    }
    
    /**
     * 注册分表分片值.
     * 
     * @param logicTable 逻辑表明
     * @param shardingColumn 分库键
     * @param binaryOperator 分库操作符
     * @param values 分库值
     */
    public static void registerShardingValueOfTable(final String logicTable, final String shardingColumn, final Condition.BinaryOperator binaryOperator, final Comparable<?>... values) {
        if (null == SHARING_VALUE_CONTAINER.get()) {
            throw new ShardingJdbcException("Please first invoke HintShardingValueManager.init()");
        }
        registerShardingValue(SHARING_VALUE_CONTAINER.get().tableShardingValues, logicTable, shardingColumn, binaryOperator, values);
    }
    
    /**
     * 获取分表键值.
     * 
     * @param logicTable 逻辑表名
     * @return 分库键值
     */
    public static Optional<List<ShardingValue<?>>> getShardingValueOfTable(final String logicTable) {
        if (null == SHARING_VALUE_CONTAINER.get()) {
            return Optional.absent();
        }
        return Optional.fromNullable(SHARING_VALUE_CONTAINER.get().tableShardingValues.get(logicTable));
    }
    
    private static void registerShardingValue(final Map<String, List<ShardingValue<?>>> container, final String logicTable, 
                                              final String shardingColumn, final Condition.BinaryOperator binaryOperator, final Comparable<?>... values) {
        Preconditions.checkArgument(!Strings.isNullOrEmpty(logicTable));
        Preconditions.checkArgument(!Strings.isNullOrEmpty(shardingColumn));
        Preconditions.checkArgument(null != values && values.length > 0);
        
        List<ShardingValue<?>> shardingValues;
        if (container.containsKey(logicTable)) {
            shardingValues = container.get(logicTable);
        } else {
            shardingValues = new ArrayList<>();
            container.put(logicTable, shardingValues);
        }
        Condition condition = new Condition(new Condition.Column(shardingColumn, logicTable), binaryOperator);
        condition.getValues().addAll(Arrays.asList(values));
        shardingValues.add(SingleRouterUtil.convertConditionToShardingValue(condition));
    }
    
    /**
     * 清理容器.
     * 
     */
    public static void clear() {
        SHARING_VALUE_CONTAINER.remove();
    }
    
    private static class ShardingValueContainer {
        
        private final Map<String, List<ShardingValue<?>>> databaseShardingValues = new HashMap<>();
        
        private final Map<String, List<ShardingValue<?>>> tableShardingValues = new HashMap<>();
    }
    
}
+13 −5
Original line number Diff line number Diff line
@@ -20,11 +20,12 @@ package com.dangdang.ddframe.rdb.sharding.api.strategy.common;
import java.util.Arrays;
import java.util.Collection;

import com.dangdang.ddframe.rdb.sharding.api.ShardingValue;
import com.dangdang.ddframe.rdb.sharding.exception.ShardingJdbcException;
import com.dangdang.ddframe.rdb.sharding.parser.result.router.SQLStatementType;
import lombok.Getter;
import lombok.RequiredArgsConstructor;

import com.dangdang.ddframe.rdb.sharding.api.ShardingValue;

/**
 * 分片策略.
 * 
@@ -45,15 +46,22 @@ public class ShardingStrategy {
    /**
     * 根据分片值计算数据源名称集合.
     *
     *
     * @param sqlStatementType
     * @param availableTargetNames 所有的可用数据源名称集合
     * @param shardingValues 分库片值集合
     * @return 分库后指向的数据源名称集合
     */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    public Collection<String> doSharding(final Collection<String> availableTargetNames, final Collection<ShardingValue<? extends Comparable<?>>> shardingValues) {
    public Collection<String> doSharding(final SQLStatementType sqlStatementType, final Collection<String> availableTargetNames, 
                                         final Collection<ShardingValue<? extends Comparable<?>>> shardingValues) {
        if (shardingValues.isEmpty()) {
            if (SQLStatementType.INSERT.equals(sqlStatementType)) {
                throw new ShardingJdbcException("INSERT statement must contains sharding value");
            } else {
                return availableTargetNames;
            }
        }
        if (shardingAlgorithm instanceof SingleKeyShardingAlgorithm) {
            SingleKeyShardingAlgorithm<?> singleKeyShardingAlgorithm = (SingleKeyShardingAlgorithm<?>) shardingAlgorithm;
            ShardingValue shardingValue = shardingValues.iterator().next();
+10 −2
Original line number Diff line number Diff line
@@ -28,12 +28,14 @@ import java.util.Map;
import javax.sql.DataSource;

import com.codahale.metrics.Timer.Context;
import com.dangdang.ddframe.rdb.sharding.api.rule.DataSourceRule;
import com.dangdang.ddframe.rdb.sharding.exception.ShardingJdbcException;
import com.dangdang.ddframe.rdb.sharding.jdbc.adapter.AbstractConnectionAdapter;
import com.dangdang.ddframe.rdb.sharding.metrics.MetricsContext;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Collections2;
import com.google.common.collect.Lists;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
@@ -73,9 +75,11 @@ public final class ShardingConnection extends AbstractConnectionAdapter {
    @Override
    public DatabaseMetaData getMetaData() throws SQLException {
        if (connectionMap.isEmpty()) {
            return getDatabaseMetaDataFromDataSource(shardingContext.getShardingRule().getDataSourceRule().getDataSources());
            DataSourceRule dataSourceRule = shardingContext.getShardingRule().getDataSourceRule();
            String dsName = dataSourceRule.getDataSourceNames().iterator().next();
            connectionMap.put(dsName, dataSourceRule.getDataSource(dsName).getConnection());
        }
        return getDatabaseMetaDataFromConnection(connectionMap.values());
        return getDatabaseMetaDataFromConnection(connectionMap.values().iterator().next());
    }
    
    public static DatabaseMetaData getDatabaseMetaDataFromDataSource(final Collection<DataSource> dataSources) {
@@ -105,6 +109,10 @@ public final class ShardingConnection extends AbstractConnectionAdapter {
        }
    }
    
    private static DatabaseMetaData getDatabaseMetaDataFromConnection(final Connection connection) {
        return getDatabaseMetaDataFromConnection(Lists.newArrayList(connection));
    }
    
    private static DatabaseMetaData getDatabaseMetaDataFromConnection(final Collection<Connection> connections) {
        String databaseProductName = null;
        DatabaseMetaData result = null;
+23 −1
Original line number Diff line number Diff line
@@ -21,12 +21,17 @@ import java.util.Collection;
import java.util.List;

import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.statement.SQLDeleteStatement;
import com.alibaba.druid.sql.ast.statement.SQLInsertStatement;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLUpdateStatement;
import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor;
import com.dangdang.ddframe.rdb.sharding.exception.SQLParserException;
import com.dangdang.ddframe.rdb.sharding.parser.result.SQLParsedResult;
import com.dangdang.ddframe.rdb.sharding.parser.result.router.SQLStatementType;
import com.dangdang.ddframe.rdb.sharding.parser.visitor.SQLVisitor;
import com.dangdang.ddframe.rdb.sharding.parser.visitor.or.OrParser;
import com.google.common.base.Preconditions;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

@@ -68,6 +73,23 @@ public final class SQLParseEngine {
        log.debug("Parsed SQL result: {}", result);
        log.debug("Parsed SQL: {}", sqlVisitor.getSQLBuilder());
        result.getRouteContext().setSqlBuilder(sqlVisitor.getSQLBuilder());
        result.getRouteContext().setSqlStatementType(getType());
        return result;
    }
    
    private SQLStatementType getType() {
        if (sqlStatement instanceof SQLSelectStatement) {
            return SQLStatementType.SELECT;
        }
        if (sqlStatement instanceof SQLInsertStatement) {
            return SQLStatementType.INSERT;
        }
        if (sqlStatement instanceof SQLUpdateStatement) {
            return SQLStatementType.UPDATE;
        }
        if (sqlStatement instanceof SQLDeleteStatement) {
            return SQLStatementType.DELETE;
        }
        throw new SQLParserException("Unsupported SQL statement: [%s]", sqlStatement);
    }
}
+4 −0
Original line number Diff line number Diff line
@@ -21,6 +21,7 @@ import java.util.Collection;
import java.util.LinkedHashSet;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;

@@ -32,9 +33,12 @@ import lombok.ToString;
@Getter
@Setter
@ToString
@RequiredArgsConstructor
public final class RouteContext {
    
    private final Collection<Table> tables = new LinkedHashSet<>();
    
    private SQLStatementType sqlStatementType;
    
    private SQLBuilder sqlBuilder;
}
Loading