Commit 0dababf0 authored by zhenwu's avatar zhenwu
Browse files

add pq cases

parent 8e4c5833
Loading
Loading
Loading
Loading
+123 −95
Original line number Diff line number Diff line
@@ -497,6 +497,7 @@ class TestIndexBase:
        status, ids = connect.add_vectors(table, vectors)
        for i in range(2):
            status = connect.create_index(table, index_params)

            assert status.OK()
            status, result = connect.describe_index(table)
            logging.getLogger().info(result)
@@ -584,6 +585,9 @@ class TestIndexIP:
        status = connect.create_partition(ip_table, partition_name, tag)
        status, ids = connect.add_vectors(ip_table, vectors, partition_tag=tag)
        status = connect.create_index(partition_name, index_params)
        if index_params["index_type"] == IndexType.IVF_PQ:
            assert not status.OK()
        else:
            assert status.OK()

    @pytest.mark.level(2)
@@ -609,6 +613,9 @@ class TestIndexIP:
        logging.getLogger().info(index_params)
        status, ids = connect.add_vectors(ip_table, vectors)
        status = connect.create_index(ip_table, index_params)
        if index_params["index_type"] == IndexType.IVF_PQ:
            assert not status.OK()
        else:
            assert status.OK()
            logging.getLogger().info(connect.describe_index(ip_table))
            query_vecs = [vectors[0], vectors[1], vectors[2]]
@@ -943,6 +950,9 @@ class TestIndexIP:
        index_params = get_index_params
        status, ids = connect.add_vectors(ip_table, vectors)
        status = connect.create_index(ip_table, index_params)
        if index_params["index_type"] == IndexType.IVF_PQ:
            assert not status.OK()
        else:
            assert status.OK()        
            status, result = connect.describe_index(ip_table)
            logging.getLogger().info(result)
@@ -965,6 +975,9 @@ class TestIndexIP:
        status = connect.create_partition(ip_table, partition_name, tag)
        status, ids = connect.add_vectors(ip_table, vectors, partition_tag=tag)
        status = connect.create_index(ip_table, index_params)
        if index_params["index_type"] == IndexType.IVF_PQ:
            assert not status.OK()
        else:
            assert status.OK()
            status, result = connect.describe_index(ip_table)
            logging.getLogger().info(result)
@@ -987,6 +1000,9 @@ class TestIndexIP:
        status = connect.create_partition(ip_table, partition_name, tag)
        status, ids = connect.add_vectors(ip_table, vectors, partition_tag=tag)
        status = connect.create_index(partition_name, index_params)
        if index_params["index_type"] == IndexType.IVF_PQ:
            assert not status.OK()
        else:
            assert status.OK()
            status = connect.drop_index(ip_table)
            assert status.OK()
@@ -1012,6 +1028,9 @@ class TestIndexIP:
        status = connect.create_partition(ip_table, partition_name, tag)
        status, ids = connect.add_vectors(ip_table, vectors, partition_tag=tag)
        status = connect.create_index(partition_name, index_params)
        if index_params["index_type"] == IndexType.IVF_PQ:
            assert not status.OK()
        else:
            assert status.OK()
            status = connect.drop_index(partition_name)
            assert status.OK()
@@ -1040,6 +1059,9 @@ class TestIndexIP:
        status = connect.create_partition(ip_table, new_partition_name, new_tag)
        status, ids = connect.add_vectors(ip_table, vectors)
        status = connect.create_index(ip_table, index_params)
        if index_params["index_type"] == IndexType.IVF_PQ:
            assert not status.OK()
        else:
            assert status.OK()
            status = connect.drop_index(new_partition_name)
            assert status.OK()
@@ -1068,6 +1090,9 @@ class TestIndexIP:
        index_params = get_simple_index_params
        status, ids = connect.add_vectors(ip_table, vectors)
        status = connect.create_index(ip_table, index_params)
        if index_params["index_type"] == IndexType.IVF_PQ:
            assert not status.OK()
        else:
            assert status.OK()
            status, result = connect.describe_index(ip_table)
            logging.getLogger().info(result)
@@ -1120,6 +1145,9 @@ class TestIndexIP:
        status, ids = connect.add_vectors(ip_table, vectors)
        for i in range(2):
            status = connect.create_index(ip_table, index_params)
            if index_params["index_type"] == IndexType.IVF_PQ:
                assert not status.OK()
            else:
                assert status.OK()
                status, result = connect.describe_index(ip_table)
                logging.getLogger().info(result)
+2 −2
Original line number Diff line number Diff line
@@ -437,7 +437,7 @@ def gen_invalid_index_params():

def gen_index_params():
    index_params = []
    index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]
    index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H, IndexType.IVF_PQ]
    nlists = [1, 16384, 50000]

    def gen_params(index_types, nlists):
@@ -450,7 +450,7 @@ def gen_index_params():

def gen_simple_index_params():
    index_params = []
    index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]
    index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H, IndexType.IVF_PQ]
    nlists = [1024]

    def gen_params(index_types, nlists):