Unverified Commit c8f1e70b authored by 张亮's avatar 张亮 Committed by GitHub
Browse files

Merge pull request #656 from haocao/dev-sharding-proxy

Add real affected rows and last insert id into OKPacket.
parents 624dbad8 2992373e
Loading
Loading
Loading
Loading
+26 −4
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@ package io.shardingjdbc.proxy.transport.mysql.packet.command.query;

import io.shardingjdbc.core.parsing.SQLJudgeEngine;
import io.shardingjdbc.core.parsing.parser.sql.SQLStatement;
import io.shardingjdbc.core.parsing.parser.sql.dml.insert.InsertStatement;
import io.shardingjdbc.proxy.backend.DataSourceManager;
import io.shardingjdbc.proxy.constant.ColumnType;
import io.shardingjdbc.proxy.constant.StatusFlag;
@@ -66,13 +67,20 @@ public final class ComQueryPacket extends CommandPacket {
                Statement statement = conn.createStatement()) {
            SQLStatement sqlStatement = new SQLJudgeEngine(sql).judge();
            ResultSet resultSet;
            int affectedRows = 0;
            long lastInsertId = 0;
            switch (sqlStatement.getType()) {
                case DQL:
                    resultSet = statement.executeQuery(sql);
                    break;
                case DML:
                case DDL:
                    statement.executeUpdate(sql);
                    if (isNeedGeneratedKey(sqlStatement)) {
                        affectedRows = statement.executeUpdate(sql, Statement.RETURN_GENERATED_KEYS);
                        lastInsertId = getGeneratedKey(statement);
                    } else {
                        affectedRows = statement.executeUpdate(sql);
                    }
                    resultSet = statement.getResultSet();
                    break;
                default:
@@ -81,7 +89,7 @@ public final class ComQueryPacket extends CommandPacket {
                    break;
            }
            if (null == resultSet) {
                result.add(new OKPacket(++currentSequenceId, 0, 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue(), 0, ""));
                result.add(new OKPacket(++currentSequenceId, affectedRows, lastInsertId, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue(), 0, ""));
                return result;
            }
            ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
@@ -120,4 +128,18 @@ public final class ComQueryPacket extends CommandPacket {
        }
        return result;
    }
    
    private boolean isNeedGeneratedKey(final SQLStatement statement) {
        // TODO justify based on the request protocol
        return statement instanceof InsertStatement;
    }

    private long getGeneratedKey(final Statement statement) throws SQLException {
        long result = -1;
        ResultSet resultSet = statement.getGeneratedKeys();
        if (resultSet.next()) {
            result = resultSet.getLong(1);
        }
        return result;
    }
}