Commit 650bf54f authored by gaohongtao's avatar gaohongtao
Browse files

fix #11

parent 5d03d899
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -52,6 +52,8 @@ public final class AggregationInvokeHandler extends AbstractMergerInvokeHandler<
                return Optional.of(each);
            } else if (each.getAlias().isPresent() && each.getAlias().get().equals(resultSetQueryIndex.getQueryName())) {
                return Optional.of(each);
            } else if (each.getExpression().equalsIgnoreCase(resultSetQueryIndex.getQueryName())) {
                return Optional.of(each);
            }
        }
        return Optional.absent();
+103 −54
Original line number Diff line number Diff line
@@ -17,91 +17,140 @@

package com.dangdang.ddframe.rdb.sharding.merger.aggregation;

import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;

import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Test;

import com.dangdang.ddframe.rdb.sharding.merger.ResultSetFactory;
import com.dangdang.ddframe.rdb.sharding.merger.fixture.MockResultSet;
import com.dangdang.ddframe.rdb.sharding.parser.result.merger.AggregationColumn;
import com.dangdang.ddframe.rdb.sharding.parser.result.merger.AggregationColumn.AggregationType;
import com.dangdang.ddframe.rdb.sharding.parser.result.merger.MergeContext;
import com.google.common.base.Optional;
import lombok.RequiredArgsConstructor;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;

@RunWith(Parameterized.class)
@RequiredArgsConstructor
public final class AggregationResultSetTest {
    
    @Test
    public void assertNextForSum() throws SQLException {
        ResultSet resultSet = ResultSetFactory.getResultSet(Arrays.<ResultSet>asList(
                new MockResultSet<Integer>(6), new MockResultSet<Integer>(2), new MockResultSet<Integer>()), createMergeContext(AggregationType.SUM));
        assertTrue(resultSet.next());
        assertThat(resultSet.getInt(1), is(8));
        assertFalse(resultSet.next());
    }
    private final TestTarget type;
    
    private final AggregationType aggregationType;
    
    private final List<String> columns;
    
    private final List<Integer> resultSet1;
    
    private final List<Integer> resultSet2;
    
    private final Optional<String> nameOfGetResult;
    
    private final Class<? extends Number> resultClass;
    
    private final Number result;
    
    @Parameterized.Parameters
    public static Collection init() {
        
        return Arrays.asList(new Object[][]{
                {TestTarget.INDEX, AggregationType.SUM, Arrays.asList(""), Arrays.asList(6), Arrays.asList(2), Optional.absent(), Integer.class, 8},
                {TestTarget.COLUMN_NAME, AggregationType.SUM, Arrays.asList("SUM(0)"), Arrays.asList(6), Arrays.asList(2), Optional.of("SUM(0)"), Integer.class, 8},
                {TestTarget.ALIAS, AggregationType.SUM, Arrays.asList("SUM_RESULT"), Arrays.asList(6), Arrays.asList(2), Optional.of("SUM_RESULT"), Integer.class, 8},
                {TestTarget.INDEX, AggregationType.COUNT, Arrays.asList(""), Arrays.asList(6), Arrays.asList(2), Optional.absent(), Integer.class, 8},
                {TestTarget.COLUMN_NAME, AggregationType.COUNT, Arrays.asList("COUNT(0)"), Arrays.asList(6), Arrays.asList(2), Optional.of("COUNT(0)"), Integer.class, 8},
                {TestTarget.ALIAS, AggregationType.COUNT, Arrays.asList("COUNT_RESULT"), Arrays.asList(6), Arrays.asList(2), Optional.of("COUNT_RESULT"), Integer.class, 8},
                {TestTarget.INDEX, AggregationType.MAX, Arrays.asList(""), Arrays.asList(6), Arrays.asList(2), Optional.absent(), Integer.class, 6},
                {TestTarget.COLUMN_NAME, AggregationType.MAX, Arrays.asList("MAX(0)"), Arrays.asList(6), Arrays.asList(2), Optional.of("MAX(0)"), Integer.class, 6},
                {TestTarget.ALIAS, AggregationType.MAX, Arrays.asList("MAX_RESULT"), Arrays.asList(6), Arrays.asList(2), Optional.of("MAX_RESULT"), Integer.class, 6},
                {TestTarget.INDEX, AggregationType.MIN, Arrays.asList(""), Arrays.asList(6), Arrays.asList(2), Optional.absent(), Integer.class, 2},
                {TestTarget.COLUMN_NAME, AggregationType.MIN, Arrays.asList("MIN(0)"), Arrays.asList(6), Arrays.asList(2), Optional.of("MIN(0)"), Integer.class, 2},
                {TestTarget.ALIAS, AggregationType.MIN, Arrays.asList("MIN_RESULT"), Arrays.asList(6), Arrays.asList(2), Optional.of("MIN_RESULT"), Integer.class, 2},
                {TestTarget.INDEX, AggregationType.AVG, Arrays.asList("sharding_gen_1", "sharding_gen_2"), Arrays.asList(5, 10), Arrays.asList(10, 100), Optional.absent(), Double.class, 7.3333D},
                {TestTarget.COLUMN_NAME, AggregationType.AVG, Arrays.asList("sharding_gen_1", "sharding_gen_2"), Arrays.asList(5, 10), Arrays.asList(10, 100), Optional.of("AVG(*)"), Double.class, 7.3333D},
                {TestTarget.ALIAS, AggregationType.AVG, Arrays.asList("sharding_gen_1", "sharding_gen_2"), Arrays.asList(5, 10), Arrays.asList(10, 100), Optional.of("AVG_RESULT"), Double.class, 7.3333D},
        });
        
    @Test
    public void assertNextForCount() throws SQLException {
        ResultSet resultSet = ResultSetFactory.getResultSet(Arrays.<ResultSet>asList(
                new MockResultSet<Integer>(6), new MockResultSet<Integer>(2), new MockResultSet<Integer>()), createMergeContext(AggregationType.COUNT));
        assertTrue(resultSet.next());
        assertThat(resultSet.getInt(1), is(8));
        assertFalse(resultSet.next());
    }
    
    @Test
    public void assertNextForMax() throws SQLException {
        ResultSet resultSet = ResultSetFactory.getResultSet(Arrays.<ResultSet>asList(
                new MockResultSet<Integer>(6), new MockResultSet<Integer>(2), new MockResultSet<Integer>()), createMergeContext(AggregationType.MAX));
        assertTrue(resultSet.next());
        assertThat(resultSet.getInt(1), is(6));
        assertFalse(resultSet.next());
    public void assertNext() throws SQLException {
        MergeContext mergeContext;
        switch (type) {
            case INDEX:
                mergeContext = createMergeContext(1, "column", null, aggregationType);
                break;
            case COLUMN_NAME:
                mergeContext = createMergeContext(1, nameOfGetResult.get(), null, aggregationType);
                break;
            case ALIAS:
                mergeContext = createMergeContext(1, "column", nameOfGetResult.get(), aggregationType);
                break;
            default:
                throw new RuntimeException();
        }
        
    @Test
    public void assertNextForMin() throws SQLException {
        ResultSet resultSet = ResultSetFactory.getResultSet(Arrays.<ResultSet>asList(
                new MockResultSet<Integer>(6), new MockResultSet<Integer>(2), new MockResultSet<Integer>()), createMergeContext(AggregationType.MIN));
                createMock(columns, resultSet1), createMock(columns, resultSet2), new MockResultSet<Integer>()),
                mergeContext);
        assertTrue(resultSet.next());
        assertThat(resultSet.getInt(1), is(2));
    
        Number actual;
        switch (type) {
            case INDEX:
                if (Integer.class.equals(resultClass)) {
                    actual = resultSet.getInt(1);
                } else if (Double.class.equals(resultClass)) {
                    actual = resultSet.getDouble(1);
                } else {
                    throw new RuntimeException();
                }
                break;
            default:
                if (Integer.class.equals(resultClass)) {
                    actual = resultSet.getInt(this.nameOfGetResult.get());
                } else if (Double.class.equals(resultClass)) {
                    actual = resultSet.getDouble(this.nameOfGetResult.get());
                } else {
                    throw new RuntimeException();
                }
                break;
        }
        assertThat(actual, is(result));
        assertFalse(resultSet.next());
    }
    
    @Test
    public void assertNextForAvg() throws SQLException {
        Map<String, Integer> map1 = new LinkedHashMap<>(2);
        map1.put("sharding_gen_1", 5);
        map1.put("sharding_gen_2", 10);
        Map<String, Integer> map2 = new LinkedHashMap<>(2);
        map2.put("sharding_gen_1", 10);
        map2.put("sharding_gen_2", 100);
        ResultSet resultSet = ResultSetFactory.getResultSet(Arrays.<ResultSet>asList(
                new MockResultSet<Integer>(Arrays.asList(map1)), new MockResultSet<Integer>(Arrays.asList(map2)), new MockResultSet<Integer>()), 
                createMergeContext(AggregationType.AVG, createDerivedColumn(1, AggregationType.COUNT), createDerivedColumn(2, AggregationType.SUM)));
        assertTrue(resultSet.next());
        assertThat(resultSet.getDouble(1), is(7.3333D));
        assertFalse(resultSet.next());
    private MockResultSet<Integer> createMock(List<String> columns, List<Integer> values) {
        Map<String, Integer> result = new HashMap<>();
        for (int i = 0; i < columns.size(); i++) {
            result.put(columns.get(i), values.get(i));            
        }
        return new MockResultSet<>(Arrays.asList(result));
    }
    
    private MergeContext createMergeContext(final AggregationType aggregationType, final AggregationColumn... derivedColumns) {
        AggregationColumn column = new AggregationColumn("column", aggregationType, Optional.<String>absent(), Optional.<String>absent(), 1);
        for (AggregationColumn each : derivedColumns) {
            column.getDerivedColumns().add(each);
    private MergeContext createMergeContext(final int index, final String name, final String alias, final AggregationType aggregationType) {
        AggregationColumn column = new AggregationColumn(name, aggregationType, Optional.fromNullable(alias), Optional.<String>absent(), index);
        if (AggregationType.AVG.equals(aggregationType)) {
            column.getDerivedColumns().add(new AggregationColumn("column", AggregationType.COUNT, Optional.of("sharding_gen_1"), Optional.<String>absent()));
            column.getDerivedColumns().add(new AggregationColumn("column", AggregationType.COUNT, Optional.of("sharding_gen_2"), Optional.<String>absent()));
        }
        MergeContext result = new MergeContext();
        result.getAggregationColumns().add(column);
        return result;
    }
    
    private AggregationColumn createDerivedColumn(final int index, final AggregationType aggregationType) {
        return new AggregationColumn("column", aggregationType, Optional.of("sharding_gen_" + index), Optional.<String>absent());
    private enum TestTarget {
        INDEX, COLUMN_NAME, ALIAS
    }
}