Commit 4a3f1a58 authored by terrymanu's avatar terrymanu
Browse files

refactor JDBCBackendHandler

parent 367cac7a
Loading
Loading
Loading
Loading
+9 −4
Original line number Diff line number Diff line
@@ -17,22 +17,27 @@

package io.shardingsphere.proxy.backend.common;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;

import javax.sql.DataSource;
import java.sql.Connection;
import java.util.LinkedHashMap;
import java.util.HashMap;
import java.util.Map;

/**
 * Hold the connection when proxyMode is CONNECTION_STRICTLY.
 * Hold the connection when proxy mode is CONNECTION_STRICTLY.
 *
 * @author zhaojun
 */
public class ProxyConnectionHolder {
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class ProxyConnectionHolder {
    
    private static final ThreadLocal<Map<DataSource, Connection>> RESOURCE = new ThreadLocal<Map<DataSource, Connection>>() {
        
        @Override
        protected Map<DataSource, Connection> initialValue() {
            return new LinkedHashMap<>();
            return new HashMap<>();
        }
    };
    
+31 −27
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@

package io.shardingsphere.proxy.backend.common.jdbc;

import io.netty.channel.EventLoopGroup;
import io.shardingsphere.core.constant.SQLType;
import io.shardingsphere.core.constant.TransactionType;
import io.shardingsphere.core.exception.ShardingException;
@@ -91,6 +92,8 @@ public abstract class JDBCBackendHandler implements BackendHandler {
    
    private final RuleRegistry ruleRegistry;
    
    private final EventLoopGroup userGroup;
    
    public JDBCBackendHandler(final String sql, final BaseJDBCResource jdbcResource) {
        this.sql = sql;
        this.jdbcResource = jdbcResource;
@@ -98,13 +101,16 @@ public abstract class JDBCBackendHandler implements BackendHandler {
        hasMoreResultValueFlag = true;
        resultLists = new CopyOnWriteArrayList<>();
        ruleRegistry = RuleRegistry.getInstance();
        userGroup = ExecutorContext.getInstance().getUserGroup();
    }
    
    @Override
    public CommandResponsePackets execute() {
    public final CommandResponsePackets execute() {
        try {
            return execute(ruleRegistry.isMasterSlaveOnly() ? doMasterSlaveRoute() : doShardingRoute());
        } catch (final Exception ex) {
        } catch (final SQLException ex) {
            return new CommandResponsePackets(new ErrPacket(1, ex));
        } catch (final SystemException ex) {
            return new CommandResponsePackets(new ErrPacket(1, new SQLException(ex)));
        }
    }
@@ -119,8 +125,8 @@ public abstract class JDBCBackendHandler implements BackendHandler {
        }
        List<Future<CommandResponsePackets>> futureList = new ArrayList<>(1024);
        for (SQLExecutionUnit each : routeResult.getExecutionUnits()) {
            Statement statement = prepareResource(each.getDataSource(), each.getSqlUnit().getSql(), routeResult.getSqlStatement());
            futureList.add(ExecutorContext.getInstance().getUserGroup().submit(newSubmitTask(statement, routeResult.getSqlStatement(), each.getSqlUnit().getSql())));
            Statement statement = prepareResource(each, routeResult.getSqlStatement());
            futureList.add(userGroup.submit(newSubmitTask(statement, routeResult.getSqlStatement(), each.getSqlUnit().getSql())));
        }
        List<CommandResponsePackets> packets = buildCommandResponsePackets(futureList);
        CommandResponsePackets result = merge(routeResult.getSqlStatement(), packets);
@@ -135,18 +141,7 @@ public abstract class JDBCBackendHandler implements BackendHandler {
        return TransactionType.XA == ruleRegistry.getTransactionType() && SQLType.DDL == sqlType && Status.STATUS_NO_TRANSACTION != AtomikosUserTransaction.getInstance().getStatus();
    }
    
    private SQLRouteResult doMasterSlaveRoute() {
        SQLStatement sqlStatement = new SQLJudgeEngine(sql).judge();
        SQLRouteResult result = new SQLRouteResult(sqlStatement);
        for (String each : new MasterSlaveRouter(ruleRegistry.getMasterSlaveRule()).route(sqlStatement.getType())) {
            result.getExecutionUnits().add(new SQLExecutionUnit(each, new SQLUnit(sql, Collections.<List<Object>>emptyList())));
        }
        return result;
    }
    
    protected abstract SQLRouteResult doShardingRoute();
    
    protected abstract Statement prepareResource(String dataSourceName, String unitSQL, SQLStatement sqlStatement) throws SQLException;
    protected abstract Statement prepareResource(SQLExecutionUnit sqlExecutionUnit, SQLStatement sqlStatement) throws SQLException;
    
    protected abstract Callable<CommandResponsePackets> newSubmitTask(Statement statement, SQLStatement sqlStatement, String unitSQL);
    
@@ -225,8 +220,19 @@ public abstract class JDBCBackendHandler implements BackendHandler {
        return result;
    }
    
    private SQLRouteResult doMasterSlaveRoute() {
        SQLStatement sqlStatement = new SQLJudgeEngine(sql).judge();
        SQLRouteResult result = new SQLRouteResult(sqlStatement);
        for (String each : new MasterSlaveRouter(ruleRegistry.getMasterSlaveRule()).route(sqlStatement.getType())) {
            result.getExecutionUnits().add(new SQLExecutionUnit(each, new SQLUnit(sql, Collections.<List<Object>>emptyList())));
        }
        return result;
    }
    
    protected abstract SQLRouteResult doShardingRoute();
    
    @Override
    public boolean hasMoreResultValue() throws SQLException {
    public final boolean hasMoreResultValue() throws SQLException {
        if (!isMerged || !hasMoreResultValueFlag) {
            jdbcResource.clear();
            return false;
@@ -238,7 +244,7 @@ public abstract class JDBCBackendHandler implements BackendHandler {
    }
    
    @Override
    public DatabaseProtocolPacket getResultValue() {
    public final DatabaseProtocolPacket getResultValue() {
        if (!hasMoreResultValueFlag) {
            return new EofPacket(++currentSequenceId, 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue());
        }
@@ -255,17 +261,15 @@ public abstract class JDBCBackendHandler implements BackendHandler {
    
    protected abstract DatabaseProtocolPacket newDatabaseProtocolPacket(int sequenceId, List<Object> data);
    
    protected Connection getConnection(final DataSource dataSource) throws SQLException {
        Connection result;
        if (ProxyMode.CONNECTION_STRICTLY == ruleRegistry.getProxyMode()) {
            result = ProxyConnectionHolder.getConnection(dataSource);
    protected final Connection getConnection(final DataSource dataSource) throws SQLException {
        if (ProxyMode.MEMORY_STRICTLY == ruleRegistry.getProxyMode()) {
            return dataSource.getConnection();
        }
        Connection result = ProxyConnectionHolder.getConnection(dataSource);
        if (null == result) {
            result = dataSource.getConnection();
            ProxyConnectionHolder.setConnection(dataSource, result);
        }
        } else {
            result = dataSource.getConnection();
        }
        return result;
    }
}
+5 −3
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ import io.shardingsphere.core.merger.QueryResult;
import io.shardingsphere.core.parsing.parser.sql.SQLStatement;
import io.shardingsphere.core.parsing.parser.sql.dml.insert.InsertStatement;
import io.shardingsphere.core.routing.PreparedStatementRoutingEngine;
import io.shardingsphere.core.routing.SQLExecutionUnit;
import io.shardingsphere.core.routing.SQLRouteResult;
import io.shardingsphere.proxy.backend.common.ProxyMode;
import io.shardingsphere.proxy.backend.common.jdbc.JDBCBackendHandler;
@@ -93,10 +94,11 @@ public final class JDBCStatementBackendHandler extends JDBCBackendHandler {
    }
    
    @Override
    protected PreparedStatement prepareResource(final String dataSourceName, final String unitSQL, final SQLStatement sqlStatement) throws SQLException {
        DataSource dataSource = ruleRegistry.getDataSourceMap().get(dataSourceName);
    protected PreparedStatement prepareResource(final SQLExecutionUnit sqlExecutionUnit, final SQLStatement sqlStatement) throws SQLException {
        DataSource dataSource = ruleRegistry.getDataSourceMap().get(sqlExecutionUnit.getDataSource());
        Connection connection = getConnection(dataSource);
        PreparedStatement result = sqlStatement instanceof InsertStatement ? connection.prepareStatement(unitSQL, Statement.RETURN_GENERATED_KEYS) : connection.prepareStatement(unitSQL);
        PreparedStatement result = sqlStatement instanceof InsertStatement
                ? connection.prepareStatement(sqlExecutionUnit.getSqlUnit().getSql(), Statement.RETURN_GENERATED_KEYS) : connection.prepareStatement(sqlExecutionUnit.getSqlUnit().getSql());
        for (int i = 0; i < preparedStatementParameters.size(); i++) {
            result.setObject(i + 1, preparedStatementParameters.get(i).getValue());
        }
+3 −4
Original line number Diff line number Diff line
@@ -20,6 +20,7 @@ package io.shardingsphere.proxy.backend.common.jdbc.text;
import io.shardingsphere.core.constant.DatabaseType;
import io.shardingsphere.core.merger.QueryResult;
import io.shardingsphere.core.parsing.parser.sql.SQLStatement;
import io.shardingsphere.core.routing.SQLExecutionUnit;
import io.shardingsphere.core.routing.SQLRouteResult;
import io.shardingsphere.core.routing.StatementRoutingEngine;
import io.shardingsphere.proxy.backend.common.ProxyMode;
@@ -32,7 +33,6 @@ import io.shardingsphere.proxy.transport.common.packet.DatabaseProtocolPacket;
import io.shardingsphere.proxy.transport.mysql.packet.command.CommandResponsePackets;
import io.shardingsphere.proxy.transport.mysql.packet.command.text.query.TextResultSetRowPacket;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
@@ -66,9 +66,8 @@ public final class JDBCTextBackendHandler extends JDBCBackendHandler {
    }
    
    @Override
    protected Statement prepareResource(final String dataSourceName, final String unitSQL, final SQLStatement sqlStatement) throws SQLException {
        DataSource dataSource = ruleRegistry.getDataSourceMap().get(dataSourceName);
        Connection connection = getConnection(dataSource);
    protected Statement prepareResource(final SQLExecutionUnit sqlExecutionUnit, final SQLStatement sqlStatement) throws SQLException {
        Connection connection = getConnection(ruleRegistry.getDataSourceMap().get(sqlExecutionUnit.getDataSource()));
        Statement result = connection.createStatement();
        ProxyJDBCResource proxyJDBCResource = (ProxyJDBCResource) getJdbcResource();
        proxyJDBCResource.addConnection(connection);