Commit c0fd00bc authored by terrymanu's avatar terrymanu
Browse files

refactor ResponseHandler

parent 19f2a681
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -15,7 +15,7 @@
 * </p>
 */

package io.shardingsphere.proxy.backend.constant;
package io.shardingsphere.proxy.backend.netty.client.response;

/**
 * Auth Type.
+6 −6
Original line number Diff line number Diff line
@@ -17,9 +17,9 @@

package io.shardingsphere.proxy.backend.netty.client.response;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.shardingsphere.proxy.transport.mysql.packet.MySQLPacketPayload;

/**
 * SQL executed response handler.
@@ -29,13 +29,13 @@ import io.shardingsphere.proxy.transport.mysql.packet.MySQLPacketPayload;
 */
public abstract class ResponseHandler extends ChannelInboundHandlerAdapter {
    
    protected abstract void auth(ChannelHandlerContext context, MySQLPacketPayload payload);
    protected abstract void auth(ChannelHandlerContext context, ByteBuf byteBuf);
    
    protected abstract void eofPacket(ChannelHandlerContext context, MySQLPacketPayload payload);
    protected abstract void eofPacket(ChannelHandlerContext context, ByteBuf byteBuf);
    
    protected abstract void okPacket(ChannelHandlerContext context, MySQLPacketPayload payload);
    protected abstract void okPacket(ChannelHandlerContext context, ByteBuf byteBuf);
    
    protected abstract void errPacket(ChannelHandlerContext context, MySQLPacketPayload payload);
    protected abstract void errPacket(ChannelHandlerContext context, ByteBuf byteBuf);
    
    protected abstract void commonPacket(ChannelHandlerContext context, MySQLPacketPayload payload);
    protected abstract void commonPacket(ChannelHandlerContext context, ByteBuf byteBuf);
}
+37 −33
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ import io.netty.channel.ChannelHandlerContext;
import io.shardingsphere.core.exception.ShardingException;
import io.shardingsphere.core.metadata.datasource.DataSourceMetaData;
import io.shardingsphere.core.rule.DataSourceParameter;
import io.shardingsphere.proxy.backend.constant.AuthType;
import io.shardingsphere.proxy.backend.netty.client.response.AuthType;
import io.shardingsphere.proxy.backend.netty.client.response.ResponseHandler;
import io.shardingsphere.proxy.backend.netty.future.FutureRegistry;
import io.shardingsphere.proxy.config.RuleRegistry;
@@ -55,27 +55,34 @@ import java.util.Map;
@RequiredArgsConstructor
public final class MySQLResponseHandler extends ResponseHandler {
    
    private static final RuleRegistry RULE_REGISTRY = RuleRegistry.getInstance();
    private final DataSourceParameter dataSourceParameter;
    
    private final String dataSourceName;
    private final DataSourceMetaData dataSourceMetaData;
    
    private final Map<Integer, MySQLQueryResult> resultMap = new HashMap<>();
    private final Map<Integer, MySQLQueryResult> resultMap;
    
    private AuthType authType = AuthType.UN_AUTH;
    private AuthType authType;
    
    public MySQLResponseHandler(final String dataSourceName) {
        dataSourceParameter = RuleRegistry.getInstance().getDataSourceConfigurationMap().get(dataSourceName);
        dataSourceMetaData = RuleRegistry.getInstance().getMetaData().getDataSource().getActualDataSourceMetaData(dataSourceName);
        resultMap = new HashMap<>();
        authType = AuthType.UN_AUTH;
    }
    
    @Override
    public void channelRead(final ChannelHandlerContext context, final Object message) {
        MySQLPacketPayload payload = new MySQLPacketPayload((ByteBuf) message);
        int header = getHeader(payload);
        ByteBuf byteBuf = (ByteBuf) message;
        int header = getHeader(byteBuf);
        switch (authType) {
            case UN_AUTH:
                auth(context, payload);
                auth(context, byteBuf);
                break;
            case AUTHING:
                authing(context, payload, header);
                authing(context, byteBuf, header);
                break;
            case AUTH_SUCCESS:
                authSuccess(context, payload, header);
                authSuccess(context, byteBuf, header);
                break;
            case AUTH_FAILED:
                log.error("mysql auth failed, cannot handle channel read message");
@@ -85,7 +92,8 @@ public final class MySQLResponseHandler extends ResponseHandler {
        }
    }
    
    private int getHeader(final MySQLPacketPayload payload) {
    private int getHeader(final ByteBuf byteBuf) {
        MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf);
        payload.getByteBuf().markReaderIndex();
        payload.readInt1();
        int result = payload.readInt1();
@@ -94,10 +102,8 @@ public final class MySQLResponseHandler extends ResponseHandler {
    }
    
    @Override
    protected void auth(final ChannelHandlerContext context, final MySQLPacketPayload payload) {
        try {
            DataSourceParameter dataSourceParameter = RULE_REGISTRY.getDataSourceConfigurationMap().get(dataSourceName);
            DataSourceMetaData dataSourceMetaData = RULE_REGISTRY.getMetaData().getDataSource().getActualDataSourceMetaData(dataSourceName);
    protected void auth(final ChannelHandlerContext context, final ByteBuf byteBuf) {
        try (MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf)) {
            HandshakePacket handshakePacket = new HandshakePacket(payload);
            byte[] authResponse = securePasswordAuthentication(dataSourceParameter.getPassword().getBytes(), handshakePacket.getAuthPluginData().getAuthPluginData());
            HandshakeResponse41Packet handshakeResponse41Packet = new HandshakeResponse41Packet(
@@ -105,8 +111,6 @@ public final class MySQLResponseHandler extends ResponseHandler {
                    dataSourceParameter.getUsername(), authResponse, dataSourceMetaData.getSchemeName());
            ChannelRegistry.getInstance().putConnectionId(context.channel().id().asShortText(), handshakePacket.getConnectionId());
            context.writeAndFlush(handshakeResponse41Packet);
        } finally {
            payload.close();
        }
        authType = AuthType.AUTHING;
    }
@@ -129,64 +133,63 @@ public final class MySQLResponseHandler extends ResponseHandler {
        }
    }
    
    private void authing(final ChannelHandlerContext context, final MySQLPacketPayload payload, final int header) {
    private void authing(final ChannelHandlerContext context, final ByteBuf byteBuf, final int header) {
        if (OKPacket.HEADER == header) {
            okPacket(context, payload);
            okPacket(context, byteBuf);
            authType = AuthType.AUTH_SUCCESS;
        } else {
            errPacket(context, payload);
            errPacket(context, byteBuf);
            authType = AuthType.AUTH_FAILED;
        }
    }
    
    private void authSuccess(final ChannelHandlerContext context, final MySQLPacketPayload payload, final int header) {
    private void authSuccess(final ChannelHandlerContext context, final ByteBuf byteBuf, final int header) {
        switch (header) {
            case EofPacket.HEADER:
                eofPacket(context, payload);
                eofPacket(context, byteBuf);
                break;
            case OKPacket.HEADER:
                okPacket(context, payload);
                okPacket(context, byteBuf);
                break;
            case ErrPacket.HEADER:
                errPacket(context, payload);
                errPacket(context, byteBuf);
                break;
            default:
                commonPacket(context, payload);
                commonPacket(context, byteBuf);
        }
    }
    
    @Override
    protected void okPacket(final ChannelHandlerContext context, final MySQLPacketPayload payload) {
    protected void okPacket(final ChannelHandlerContext context, final ByteBuf byteBuf) {
        int connectionId = ChannelRegistry.getInstance().getConnectionId(context.channel().id().asShortText());
        try {
        try (MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf)) {
            MySQLQueryResult mysqlQueryResult = new MySQLQueryResult();
            mysqlQueryResult.setGenericResponse(new OKPacket(payload));
            resultMap.put(connectionId, mysqlQueryResult);
            setResponse(context);
        } finally {
            resultMap.remove(connectionId);
            payload.close();
        }
    }
    
    @Override
    protected void errPacket(final ChannelHandlerContext context, final MySQLPacketPayload payload) {
    protected void errPacket(final ChannelHandlerContext context, final ByteBuf byteBuf) {
        int connectionId = ChannelRegistry.getInstance().getConnectionId(context.channel().id().asShortText());
        try {
        try (MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf)) {
            MySQLQueryResult mysqlQueryResult = new MySQLQueryResult();
            mysqlQueryResult.setGenericResponse(new ErrPacket(payload));
            resultMap.put(connectionId, mysqlQueryResult);
            setResponse(context);
        } finally {
            resultMap.remove(connectionId);
            payload.close();
        }
    }
    
    @Override
    protected void eofPacket(final ChannelHandlerContext context, final MySQLPacketPayload payload) {
    protected void eofPacket(final ChannelHandlerContext context, final ByteBuf byteBuf) {
        int connectionId = ChannelRegistry.getInstance().getConnectionId(context.channel().id().asShortText());
        MySQLQueryResult mysqlQueryResult = resultMap.get(connectionId);
        MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf);
        if (mysqlQueryResult.isColumnFinished()) {
            mysqlQueryResult.setRowFinished(new EofPacket(payload));
            resultMap.remove(connectionId);
@@ -205,9 +208,10 @@ public final class MySQLResponseHandler extends ResponseHandler {
    }
    
    @Override
    protected void commonPacket(final ChannelHandlerContext context, final MySQLPacketPayload payload) {
    protected void commonPacket(final ChannelHandlerContext context, final ByteBuf byteBuf) {
        int connectionId = ChannelRegistry.getInstance().getConnectionId(context.channel().id().asShortText());
        MySQLQueryResult mysqlQueryResult = resultMap.get(connectionId);
        MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf);
        if (null == mysqlQueryResult) {
            mysqlQueryResult = new MySQLQueryResult(payload);
            resultMap.put(connectionId, mysqlQueryResult);