Loading sharding-proxy/src/main/java/io/shardingsphere/proxy/backend/constant/AuthType.java→sharding-proxy/src/main/java/io/shardingsphere/proxy/backend/netty/client/response/AuthType.java +1 −1 Original line number Diff line number Diff line Loading @@ -15,7 +15,7 @@ * </p> */ package io.shardingsphere.proxy.backend.constant; package io.shardingsphere.proxy.backend.netty.client.response; /** * Auth Type. Loading sharding-proxy/src/main/java/io/shardingsphere/proxy/backend/netty/client/response/ResponseHandler.java +6 −6 Original line number Diff line number Diff line Loading @@ -17,9 +17,9 @@ package io.shardingsphere.proxy.backend.netty.client.response; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.shardingsphere.proxy.transport.mysql.packet.MySQLPacketPayload; /** * SQL executed response handler. Loading @@ -29,13 +29,13 @@ import io.shardingsphere.proxy.transport.mysql.packet.MySQLPacketPayload; */ public abstract class ResponseHandler extends ChannelInboundHandlerAdapter { protected abstract void auth(ChannelHandlerContext context, MySQLPacketPayload payload); protected abstract void auth(ChannelHandlerContext context, ByteBuf byteBuf); protected abstract void eofPacket(ChannelHandlerContext context, MySQLPacketPayload payload); protected abstract void eofPacket(ChannelHandlerContext context, ByteBuf byteBuf); protected abstract void okPacket(ChannelHandlerContext context, MySQLPacketPayload payload); protected abstract void okPacket(ChannelHandlerContext context, ByteBuf byteBuf); protected abstract void errPacket(ChannelHandlerContext context, MySQLPacketPayload payload); protected abstract void errPacket(ChannelHandlerContext context, ByteBuf byteBuf); protected abstract void commonPacket(ChannelHandlerContext context, MySQLPacketPayload payload); protected abstract void commonPacket(ChannelHandlerContext context, ByteBuf byteBuf); } sharding-proxy/src/main/java/io/shardingsphere/proxy/backend/netty/client/response/mysql/MySQLResponseHandler.java +37 −33 Original line number Diff line number Diff line Loading @@ -22,7 +22,7 @@ import io.netty.channel.ChannelHandlerContext; import io.shardingsphere.core.exception.ShardingException; import io.shardingsphere.core.metadata.datasource.DataSourceMetaData; import io.shardingsphere.core.rule.DataSourceParameter; import io.shardingsphere.proxy.backend.constant.AuthType; import io.shardingsphere.proxy.backend.netty.client.response.AuthType; import io.shardingsphere.proxy.backend.netty.client.response.ResponseHandler; import io.shardingsphere.proxy.backend.netty.future.FutureRegistry; import io.shardingsphere.proxy.config.RuleRegistry; Loading Loading @@ -55,27 +55,34 @@ import java.util.Map; @RequiredArgsConstructor public final class MySQLResponseHandler extends ResponseHandler { private static final RuleRegistry RULE_REGISTRY = RuleRegistry.getInstance(); private final DataSourceParameter dataSourceParameter; private final String dataSourceName; private final DataSourceMetaData dataSourceMetaData; private final Map<Integer, MySQLQueryResult> resultMap = new HashMap<>(); private final Map<Integer, MySQLQueryResult> resultMap; private AuthType authType = AuthType.UN_AUTH; private AuthType authType; public MySQLResponseHandler(final String dataSourceName) { dataSourceParameter = RuleRegistry.getInstance().getDataSourceConfigurationMap().get(dataSourceName); dataSourceMetaData = RuleRegistry.getInstance().getMetaData().getDataSource().getActualDataSourceMetaData(dataSourceName); resultMap = new HashMap<>(); authType = AuthType.UN_AUTH; } @Override public void channelRead(final ChannelHandlerContext context, final Object message) { MySQLPacketPayload payload = new MySQLPacketPayload((ByteBuf) message); int header = getHeader(payload); ByteBuf byteBuf = (ByteBuf) message; int header = getHeader(byteBuf); switch (authType) { case UN_AUTH: auth(context, payload); auth(context, byteBuf); break; case AUTHING: authing(context, payload, header); authing(context, byteBuf, header); break; case AUTH_SUCCESS: authSuccess(context, payload, header); authSuccess(context, byteBuf, header); break; case AUTH_FAILED: log.error("mysql auth failed, cannot handle channel read message"); Loading @@ -85,7 +92,8 @@ public final class MySQLResponseHandler extends ResponseHandler { } } private int getHeader(final MySQLPacketPayload payload) { private int getHeader(final ByteBuf byteBuf) { MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf); payload.getByteBuf().markReaderIndex(); payload.readInt1(); int result = payload.readInt1(); Loading @@ -94,10 +102,8 @@ public final class MySQLResponseHandler extends ResponseHandler { } @Override protected void auth(final ChannelHandlerContext context, final MySQLPacketPayload payload) { try { DataSourceParameter dataSourceParameter = RULE_REGISTRY.getDataSourceConfigurationMap().get(dataSourceName); DataSourceMetaData dataSourceMetaData = RULE_REGISTRY.getMetaData().getDataSource().getActualDataSourceMetaData(dataSourceName); protected void auth(final ChannelHandlerContext context, final ByteBuf byteBuf) { try (MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf)) { HandshakePacket handshakePacket = new HandshakePacket(payload); byte[] authResponse = securePasswordAuthentication(dataSourceParameter.getPassword().getBytes(), handshakePacket.getAuthPluginData().getAuthPluginData()); HandshakeResponse41Packet handshakeResponse41Packet = new HandshakeResponse41Packet( Loading @@ -105,8 +111,6 @@ public final class MySQLResponseHandler extends ResponseHandler { dataSourceParameter.getUsername(), authResponse, dataSourceMetaData.getSchemeName()); ChannelRegistry.getInstance().putConnectionId(context.channel().id().asShortText(), handshakePacket.getConnectionId()); context.writeAndFlush(handshakeResponse41Packet); } finally { payload.close(); } authType = AuthType.AUTHING; } Loading @@ -129,64 +133,63 @@ public final class MySQLResponseHandler extends ResponseHandler { } } private void authing(final ChannelHandlerContext context, final MySQLPacketPayload payload, final int header) { private void authing(final ChannelHandlerContext context, final ByteBuf byteBuf, final int header) { if (OKPacket.HEADER == header) { okPacket(context, payload); okPacket(context, byteBuf); authType = AuthType.AUTH_SUCCESS; } else { errPacket(context, payload); errPacket(context, byteBuf); authType = AuthType.AUTH_FAILED; } } private void authSuccess(final ChannelHandlerContext context, final MySQLPacketPayload payload, final int header) { private void authSuccess(final ChannelHandlerContext context, final ByteBuf byteBuf, final int header) { switch (header) { case EofPacket.HEADER: eofPacket(context, payload); eofPacket(context, byteBuf); break; case OKPacket.HEADER: okPacket(context, payload); okPacket(context, byteBuf); break; case ErrPacket.HEADER: errPacket(context, payload); errPacket(context, byteBuf); break; default: commonPacket(context, payload); commonPacket(context, byteBuf); } } @Override protected void okPacket(final ChannelHandlerContext context, final MySQLPacketPayload payload) { protected void okPacket(final ChannelHandlerContext context, final ByteBuf byteBuf) { int connectionId = ChannelRegistry.getInstance().getConnectionId(context.channel().id().asShortText()); try { try (MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf)) { MySQLQueryResult mysqlQueryResult = new MySQLQueryResult(); mysqlQueryResult.setGenericResponse(new OKPacket(payload)); resultMap.put(connectionId, mysqlQueryResult); setResponse(context); } finally { resultMap.remove(connectionId); payload.close(); } } @Override protected void errPacket(final ChannelHandlerContext context, final MySQLPacketPayload payload) { protected void errPacket(final ChannelHandlerContext context, final ByteBuf byteBuf) { int connectionId = ChannelRegistry.getInstance().getConnectionId(context.channel().id().asShortText()); try { try (MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf)) { MySQLQueryResult mysqlQueryResult = new MySQLQueryResult(); mysqlQueryResult.setGenericResponse(new ErrPacket(payload)); resultMap.put(connectionId, mysqlQueryResult); setResponse(context); } finally { resultMap.remove(connectionId); payload.close(); } } @Override protected void eofPacket(final ChannelHandlerContext context, final MySQLPacketPayload payload) { protected void eofPacket(final ChannelHandlerContext context, final ByteBuf byteBuf) { int connectionId = ChannelRegistry.getInstance().getConnectionId(context.channel().id().asShortText()); MySQLQueryResult mysqlQueryResult = resultMap.get(connectionId); MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf); if (mysqlQueryResult.isColumnFinished()) { mysqlQueryResult.setRowFinished(new EofPacket(payload)); resultMap.remove(connectionId); Loading @@ -205,9 +208,10 @@ public final class MySQLResponseHandler extends ResponseHandler { } @Override protected void commonPacket(final ChannelHandlerContext context, final MySQLPacketPayload payload) { protected void commonPacket(final ChannelHandlerContext context, final ByteBuf byteBuf) { int connectionId = ChannelRegistry.getInstance().getConnectionId(context.channel().id().asShortText()); MySQLQueryResult mysqlQueryResult = resultMap.get(connectionId); MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf); if (null == mysqlQueryResult) { mysqlQueryResult = new MySQLQueryResult(payload); resultMap.put(connectionId, mysqlQueryResult); Loading Loading
sharding-proxy/src/main/java/io/shardingsphere/proxy/backend/constant/AuthType.java→sharding-proxy/src/main/java/io/shardingsphere/proxy/backend/netty/client/response/AuthType.java +1 −1 Original line number Diff line number Diff line Loading @@ -15,7 +15,7 @@ * </p> */ package io.shardingsphere.proxy.backend.constant; package io.shardingsphere.proxy.backend.netty.client.response; /** * Auth Type. Loading
sharding-proxy/src/main/java/io/shardingsphere/proxy/backend/netty/client/response/ResponseHandler.java +6 −6 Original line number Diff line number Diff line Loading @@ -17,9 +17,9 @@ package io.shardingsphere.proxy.backend.netty.client.response; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.shardingsphere.proxy.transport.mysql.packet.MySQLPacketPayload; /** * SQL executed response handler. Loading @@ -29,13 +29,13 @@ import io.shardingsphere.proxy.transport.mysql.packet.MySQLPacketPayload; */ public abstract class ResponseHandler extends ChannelInboundHandlerAdapter { protected abstract void auth(ChannelHandlerContext context, MySQLPacketPayload payload); protected abstract void auth(ChannelHandlerContext context, ByteBuf byteBuf); protected abstract void eofPacket(ChannelHandlerContext context, MySQLPacketPayload payload); protected abstract void eofPacket(ChannelHandlerContext context, ByteBuf byteBuf); protected abstract void okPacket(ChannelHandlerContext context, MySQLPacketPayload payload); protected abstract void okPacket(ChannelHandlerContext context, ByteBuf byteBuf); protected abstract void errPacket(ChannelHandlerContext context, MySQLPacketPayload payload); protected abstract void errPacket(ChannelHandlerContext context, ByteBuf byteBuf); protected abstract void commonPacket(ChannelHandlerContext context, MySQLPacketPayload payload); protected abstract void commonPacket(ChannelHandlerContext context, ByteBuf byteBuf); }
sharding-proxy/src/main/java/io/shardingsphere/proxy/backend/netty/client/response/mysql/MySQLResponseHandler.java +37 −33 Original line number Diff line number Diff line Loading @@ -22,7 +22,7 @@ import io.netty.channel.ChannelHandlerContext; import io.shardingsphere.core.exception.ShardingException; import io.shardingsphere.core.metadata.datasource.DataSourceMetaData; import io.shardingsphere.core.rule.DataSourceParameter; import io.shardingsphere.proxy.backend.constant.AuthType; import io.shardingsphere.proxy.backend.netty.client.response.AuthType; import io.shardingsphere.proxy.backend.netty.client.response.ResponseHandler; import io.shardingsphere.proxy.backend.netty.future.FutureRegistry; import io.shardingsphere.proxy.config.RuleRegistry; Loading Loading @@ -55,27 +55,34 @@ import java.util.Map; @RequiredArgsConstructor public final class MySQLResponseHandler extends ResponseHandler { private static final RuleRegistry RULE_REGISTRY = RuleRegistry.getInstance(); private final DataSourceParameter dataSourceParameter; private final String dataSourceName; private final DataSourceMetaData dataSourceMetaData; private final Map<Integer, MySQLQueryResult> resultMap = new HashMap<>(); private final Map<Integer, MySQLQueryResult> resultMap; private AuthType authType = AuthType.UN_AUTH; private AuthType authType; public MySQLResponseHandler(final String dataSourceName) { dataSourceParameter = RuleRegistry.getInstance().getDataSourceConfigurationMap().get(dataSourceName); dataSourceMetaData = RuleRegistry.getInstance().getMetaData().getDataSource().getActualDataSourceMetaData(dataSourceName); resultMap = new HashMap<>(); authType = AuthType.UN_AUTH; } @Override public void channelRead(final ChannelHandlerContext context, final Object message) { MySQLPacketPayload payload = new MySQLPacketPayload((ByteBuf) message); int header = getHeader(payload); ByteBuf byteBuf = (ByteBuf) message; int header = getHeader(byteBuf); switch (authType) { case UN_AUTH: auth(context, payload); auth(context, byteBuf); break; case AUTHING: authing(context, payload, header); authing(context, byteBuf, header); break; case AUTH_SUCCESS: authSuccess(context, payload, header); authSuccess(context, byteBuf, header); break; case AUTH_FAILED: log.error("mysql auth failed, cannot handle channel read message"); Loading @@ -85,7 +92,8 @@ public final class MySQLResponseHandler extends ResponseHandler { } } private int getHeader(final MySQLPacketPayload payload) { private int getHeader(final ByteBuf byteBuf) { MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf); payload.getByteBuf().markReaderIndex(); payload.readInt1(); int result = payload.readInt1(); Loading @@ -94,10 +102,8 @@ public final class MySQLResponseHandler extends ResponseHandler { } @Override protected void auth(final ChannelHandlerContext context, final MySQLPacketPayload payload) { try { DataSourceParameter dataSourceParameter = RULE_REGISTRY.getDataSourceConfigurationMap().get(dataSourceName); DataSourceMetaData dataSourceMetaData = RULE_REGISTRY.getMetaData().getDataSource().getActualDataSourceMetaData(dataSourceName); protected void auth(final ChannelHandlerContext context, final ByteBuf byteBuf) { try (MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf)) { HandshakePacket handshakePacket = new HandshakePacket(payload); byte[] authResponse = securePasswordAuthentication(dataSourceParameter.getPassword().getBytes(), handshakePacket.getAuthPluginData().getAuthPluginData()); HandshakeResponse41Packet handshakeResponse41Packet = new HandshakeResponse41Packet( Loading @@ -105,8 +111,6 @@ public final class MySQLResponseHandler extends ResponseHandler { dataSourceParameter.getUsername(), authResponse, dataSourceMetaData.getSchemeName()); ChannelRegistry.getInstance().putConnectionId(context.channel().id().asShortText(), handshakePacket.getConnectionId()); context.writeAndFlush(handshakeResponse41Packet); } finally { payload.close(); } authType = AuthType.AUTHING; } Loading @@ -129,64 +133,63 @@ public final class MySQLResponseHandler extends ResponseHandler { } } private void authing(final ChannelHandlerContext context, final MySQLPacketPayload payload, final int header) { private void authing(final ChannelHandlerContext context, final ByteBuf byteBuf, final int header) { if (OKPacket.HEADER == header) { okPacket(context, payload); okPacket(context, byteBuf); authType = AuthType.AUTH_SUCCESS; } else { errPacket(context, payload); errPacket(context, byteBuf); authType = AuthType.AUTH_FAILED; } } private void authSuccess(final ChannelHandlerContext context, final MySQLPacketPayload payload, final int header) { private void authSuccess(final ChannelHandlerContext context, final ByteBuf byteBuf, final int header) { switch (header) { case EofPacket.HEADER: eofPacket(context, payload); eofPacket(context, byteBuf); break; case OKPacket.HEADER: okPacket(context, payload); okPacket(context, byteBuf); break; case ErrPacket.HEADER: errPacket(context, payload); errPacket(context, byteBuf); break; default: commonPacket(context, payload); commonPacket(context, byteBuf); } } @Override protected void okPacket(final ChannelHandlerContext context, final MySQLPacketPayload payload) { protected void okPacket(final ChannelHandlerContext context, final ByteBuf byteBuf) { int connectionId = ChannelRegistry.getInstance().getConnectionId(context.channel().id().asShortText()); try { try (MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf)) { MySQLQueryResult mysqlQueryResult = new MySQLQueryResult(); mysqlQueryResult.setGenericResponse(new OKPacket(payload)); resultMap.put(connectionId, mysqlQueryResult); setResponse(context); } finally { resultMap.remove(connectionId); payload.close(); } } @Override protected void errPacket(final ChannelHandlerContext context, final MySQLPacketPayload payload) { protected void errPacket(final ChannelHandlerContext context, final ByteBuf byteBuf) { int connectionId = ChannelRegistry.getInstance().getConnectionId(context.channel().id().asShortText()); try { try (MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf)) { MySQLQueryResult mysqlQueryResult = new MySQLQueryResult(); mysqlQueryResult.setGenericResponse(new ErrPacket(payload)); resultMap.put(connectionId, mysqlQueryResult); setResponse(context); } finally { resultMap.remove(connectionId); payload.close(); } } @Override protected void eofPacket(final ChannelHandlerContext context, final MySQLPacketPayload payload) { protected void eofPacket(final ChannelHandlerContext context, final ByteBuf byteBuf) { int connectionId = ChannelRegistry.getInstance().getConnectionId(context.channel().id().asShortText()); MySQLQueryResult mysqlQueryResult = resultMap.get(connectionId); MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf); if (mysqlQueryResult.isColumnFinished()) { mysqlQueryResult.setRowFinished(new EofPacket(payload)); resultMap.remove(connectionId); Loading @@ -205,9 +208,10 @@ public final class MySQLResponseHandler extends ResponseHandler { } @Override protected void commonPacket(final ChannelHandlerContext context, final MySQLPacketPayload payload) { protected void commonPacket(final ChannelHandlerContext context, final ByteBuf byteBuf) { int connectionId = ChannelRegistry.getInstance().getConnectionId(context.channel().id().asShortText()); MySQLQueryResult mysqlQueryResult = resultMap.get(connectionId); MySQLPacketPayload payload = new MySQLPacketPayload(byteBuf); if (null == mysqlQueryResult) { mysqlQueryResult = new MySQLQueryResult(payload); resultMap.put(connectionId, mysqlQueryResult); Loading