Commit 19f2a681 authored by terrymanu's avatar terrymanu
Browse files

refactor MySQLResponseHandler

parent 28b18d9d
Loading
Loading
Loading
Loading
+44 −40
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@ package io.shardingsphere.proxy.backend.netty.client.response.mysql;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.shardingsphere.core.exception.ShardingException;
import io.shardingsphere.core.metadata.datasource.DataSourceMetaData;
import io.shardingsphere.core.rule.DataSourceParameter;
import io.shardingsphere.proxy.backend.constant.AuthType;
@@ -69,16 +70,16 @@ public final class MySQLResponseHandler extends ResponseHandler {
        switch (authType) {
            case UN_AUTH:
                auth(context, payload);
                return;
                break;
            case AUTHING:
                authing(context, payload, header);
                return;
                break;
            case AUTH_SUCCESS:
                authSuccess(context, payload, header);
                return;
                break;
            case AUTH_FAILED:
                log.error("mysql auth failed, cannot handle channel read message");
                return;
                break;
            default:
                throw new UnsupportedOperationException(authType.name());
        }
@@ -110,6 +111,24 @@ public final class MySQLResponseHandler extends ResponseHandler {
        authType = AuthType.AUTHING;
    }
    
    private byte[] securePasswordAuthentication(final byte[] password, final byte[] authPluginData) {
        try {
            MessageDigest messageDigest = MessageDigest.getInstance("SHA-1");
            byte[] part1 = messageDigest.digest(password);
            messageDigest.reset();
            byte[] part2 = messageDigest.digest(part1);
            messageDigest.reset();
            messageDigest.update(authPluginData);
            byte[] result = messageDigest.digest(part2);
            for (int i = 0; i < result.length; i++) {
                result[i] = (byte) (result[i] ^ part1[i]);
            }
            return result;
        } catch (final NoSuchAlgorithmException ex) {
            throw new ShardingException(ex);
        }
    }
    
    private void authing(final ChannelHandlerContext context, final MySQLPacketPayload payload, final int header) {
        if (OKPacket.HEADER == header) {
            okPacket(context, payload);
@@ -121,13 +140,17 @@ public final class MySQLResponseHandler extends ResponseHandler {
    }
    
    private void authSuccess(final ChannelHandlerContext context, final MySQLPacketPayload payload, final int header) {
        if (EofPacket.HEADER == header) {
        switch (header) {
            case EofPacket.HEADER:
                eofPacket(context, payload);
        } else if (OKPacket.HEADER == header) {
                break;
            case OKPacket.HEADER:
                okPacket(context, payload);
        } else if (ErrPacket.HEADER == header) {
                break;
            case ErrPacket.HEADER:
                errPacket(context, payload);
        } else {
                break;
            default:
                commonPacket(context, payload);
        }
    }
@@ -174,11 +197,18 @@ public final class MySQLResponseHandler extends ResponseHandler {
        }
    }
    
    private void setResponse(final ChannelHandlerContext context) {
        int connectionId = ChannelRegistry.getInstance().getConnectionId(context.channel().id().asShortText());
        if (null != FutureRegistry.getInstance().get(connectionId)) {
            FutureRegistry.getInstance().get(connectionId).setResponse(resultMap.get(connectionId));
        }
    }
    
    @Override
    protected void commonPacket(final ChannelHandlerContext context, final MySQLPacketPayload payload) {
        int connectionId = ChannelRegistry.getInstance().getConnectionId(context.channel().id().asShortText());
        MySQLQueryResult mysqlQueryResult = resultMap.get(connectionId);
        if (mysqlQueryResult == null) {
        if (null == mysqlQueryResult) {
            mysqlQueryResult = new MySQLQueryResult(payload);
            resultMap.put(connectionId, mysqlQueryResult);
        } else if (mysqlQueryResult.needColumnDefinition()) {
@@ -188,35 +218,9 @@ public final class MySQLResponseHandler extends ResponseHandler {
        }
    }
    
    private byte[] securePasswordAuthentication(final byte[] password, final byte[] authPluginData) {
        try {
            MessageDigest messageDigest = MessageDigest.getInstance("SHA-1");
            byte[] part1 = messageDigest.digest(password);
            messageDigest.reset();
            byte[] part2 = messageDigest.digest(part1);
            messageDigest.reset();
            messageDigest.update(authPluginData);
            byte[] result = messageDigest.digest(part2);
            for (int i = 0; i < result.length; i++) {
                result[i] = (byte) (result[i] ^ part1[i]);
            }
            return result;
        } catch (final NoSuchAlgorithmException ex) {
            log.error(ex.getMessage(), ex);
        }
        return null;
    }
    
    private void setResponse(final ChannelHandlerContext context) {
        int connectionId = ChannelRegistry.getInstance().getConnectionId(context.channel().id().asShortText());
        if (FutureRegistry.getInstance().get(connectionId) != null) {
            FutureRegistry.getInstance().get(connectionId).setResponse(resultMap.get(connectionId));
        }
    }
    
    @Override
    public void channelInactive(final ChannelHandlerContext ctx) throws Exception {
        //TODO delete connection map.
        //TODO delete connection map
        super.channelInactive(ctx);
    }
}