Unverified commit d993bcc9, authored by del-zhenwu, committed by GitHub
Browse files

remove todo case (#3296)



* remove todo case

Signed-off-by: zw <zw@milvus.io>

* remove todo case

Signed-off-by: zw <zw@milvus.io>

Co-authored-by: zw <zw@milvus.io>
parent 2fc5dc3b
Loading
Loading
Loading
Loading
+8 −8
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@ collection_id = "count_collection"
add_interval_time = 3
segment_row_count = 5000
default_fields = gen_default_fields()
default_binary_fields = gen_binary_default_fields()
entities = gen_entities(nb)
raw_vectors, binary_entities = gen_binary_entities(nb)
field_name = "fload_vector"
@@ -493,8 +494,8 @@ class TestCollectionMultiCollections:
            res = connect.count_entities(collection_list[i])
            assert res == insert_count

    # TODO:
    def _test_collection_count_multi_collections_binary(self, connect, binary_collection, insert_count):
    @pytest.mark.level(2)
    def test_collection_count_multi_collections_binary(self, connect, binary_collection, insert_count):
        '''
        target: test collection rows_count is correct or not with multiple collections of JACCARD
        method: create collection and add entities in it,
@@ -503,21 +504,20 @@ class TestCollectionMultiCollections:
        '''
        raw_vectors, entities = gen_binary_entities(insert_count)
        res = connect.insert(binary_collection, entities)
        # logging.getLogger().info(entities)
        collection_list = []
        collection_num = 20
        for i in range(collection_num):
            collection_name = gen_unique_str(collection_id)
            collection_list.append(collection_name)
            connect.create_collection(collection_name, default_fields)
            connect.create_collection(collection_name, default_binary_fields)
            res = connect.insert(collection_name, entities)
        connect.flush(collection_list)
        for i in range(collection_num):
            res = connect.count_entities(collection_list[i])
            assert res == insert_count

    # TODO:
    def _test_collection_count_multi_collections_mix(self, connect):
    @pytest.mark.level(2)
    def test_collection_count_multi_collections_mix(self, connect):
        '''
        target: test collection rows_count is correct or not with multiple collections of JACCARD
        method: create collection and add entities in it,
@@ -534,7 +534,7 @@ class TestCollectionMultiCollections:
        for i in range(int(collection_num / 2), collection_num):
            collection_name = gen_unique_str(collection_id)
            collection_list.append(collection_name)
            connect.create_collection(collection_name, default_fields)
            connect.create_collection(collection_name, default_binary_fields)
            res = connect.insert(collection_name, binary_entities)
        connect.flush(collection_list)
        for i in range(collection_num):
+29 −24
Original line number Diff line number Diff line
@@ -134,9 +134,8 @@ class TestStatsBase:
        connect.flush([collection])
        stats = connect.get_collection_stats(collection)
        assert stats["row_count"] == nb - 2
        assert stats["partitions"][0]["row_count"] == nb -2
        assert stats["partitions"][0]["segments"][0]["data_size"] > 0
        # TODO
        # assert stats["partitions"][0]["segments"][0]["index_type"] == "FLAT"

    def test_get_collection_stats_after_compact_parts(self, connect, collection):
        '''
@@ -228,10 +227,11 @@ class TestStatsBase:
        connect.flush([collection])
        connect.create_index(collection, field_name, get_simple_index)
        stats = connect.get_collection_stats(collection)
        logging.getLogger().info(stats)
        assert stats["partitions"][0]["segments"][0]["row_count"] == nb
        # TODO
        # assert stats["partitions"][0]["segments"][0]["index_name"] == get_simple_index["index_type"]
        assert stats["row_count"] == nb
        for file in stats["partitions"][0]["segments"][0]["files"]:
            if file["field"] == field_name and file["name"] != "_raw":
                assert file["data_size"] > 0
                assert file["index_type"] == get_simple_index["index_type"]

    def test_get_collection_stats_after_index_created_ip(self, connect, collection, get_simple_index):
        '''
@@ -245,10 +245,11 @@ class TestStatsBase:
        get_simple_index.update({"metric_type": "IP"})
        connect.create_index(collection, field_name, get_simple_index)
        stats = connect.get_collection_stats(collection)
        logging.getLogger().info(stats)
        assert stats["partitions"][0]["segments"][0]["row_count"] == nb
        # TODO
        # assert stats["partitions"][0]["segments"][0]["index_name"] == get_simple_index["index_type"]
        assert stats["row_count"] == nb
        for file in stats["partitions"][0]["segments"][0]["files"]:
            if file["field"] == field_name and file["name"] != "_raw":
                assert file["data_size"] > 0
                assert file["index_type"] == get_simple_index["index_type"]

    def test_get_collection_stats_after_index_created_jac(self, connect, binary_collection, get_jaccard_index):
        '''
@@ -260,10 +261,11 @@ class TestStatsBase:
        connect.flush([binary_collection])
        connect.create_index(binary_collection, "binary_vector", get_jaccard_index)
        stats = connect.get_collection_stats(binary_collection)
        logging.getLogger().info(stats)
        assert stats["partitions"][0]["segments"][0]["row_count"] == nb
        # TODO
        # assert stats["partitions"][0]["segments"][0]["index_name"] == get_jaccard_index["index_type"]
        assert stats["row_count"] == nb
        for file in stats["partitions"][0]["segments"][0]["files"]:
            if file["field"] == field_name and file["name"] != "_raw":
                assert file["data_size"] > 0
                assert file["index_type"] == get_simple_index["index_type"]

    def test_get_collection_stats_after_create_different_index(self, connect, collection):
        '''
@@ -276,10 +278,11 @@ class TestStatsBase:
        for index_type in ["IVF_FLAT", "IVF_SQ8"]:
            connect.create_index(collection, field_name, {"index_type": index_type, "params":{"nlist": 1024}, "metric_type": "L2"})
            stats = connect.get_collection_stats(collection)
            logging.getLogger().info(stats)
            # TODO
            # assert stats["partitions"][0]["segments"][0]["index_name"] == index_type
            assert stats["partitions"][0]["segments"][0]["row_count"] == nb
            assert stats["row_count"] == nb
            for file in stats["partitions"][0]["segments"][0]["files"]:
                if file["field"] == field_name and file["name"] != "_raw":
                    assert file["data_size"] > 0
                    assert file["index_type"] == index_type

    def test_collection_count_multi_collections(self, connect):
        '''
@@ -323,10 +326,12 @@ class TestStatsBase:
                connect.create_index(collection_name, field_name, {"index_type": "IVF_FLAT","params":{ "nlist": 1024}, "metric_type": "L2"})
        for i in range(collection_num):
            stats = connect.get_collection_stats(collection_list[i])
            assert stats["partitions"][0]["segments"][0]["row_count"] == nb
            # TODO
            # if i % 2:
            #     assert stats["partitions"][0]["segments"][0]["index_name"] == "IVF_SQ8"
            # else:
            #     assert stats["partitions"][0]["segments"][0]["index_name"] == "IVF_FLAT"
            if i % 2:
                for file in stats["partitions"][0]["segments"][0]["files"]:
                    if file["field"] == field_name and file["name"] != "_raw":
                        assert file["index_type"] == "IVF_SQ8"
            else:
                for file in stats["partitions"][0]["segments"][0]["files"]:
                    if file["field"] == field_name and file["name"] != "_raw":
                        assert file["index_type"] == "IVF_FLAT"
            connect.drop_collection(collection_list[i])
+0 −15
Original line number Diff line number Diff line
@@ -65,7 +65,6 @@ class TestCreateCollection:
        connect.create_collection(collection_name, fields)
        assert connect.has_collection(collection_name)

    # TODO
    def test_create_collection_fields_create_index(self, connect, get_filter_field, get_vector_field):
        '''
        target: test create normal collection with different fields
@@ -298,20 +297,6 @@ class TestCreateCollectionInvalid(object):
        logging.getLogger().info(res)
        assert res["segment_row_count"] == default_segment_row_count

    # def _test_create_collection_no_metric_type(self, connect):
    #     '''
    #     target: test create collection with no metric_type params
    #     method: create collection with corrent params
    #     expected: use default L2
    #     '''
    #     collection_name = gen_unique_str(collection_id)
    #     fields = copy.deepcopy(default_fields)
    #     fields["fields"][-1]["params"].pop("metric_type")
    #     connect.create_collection(collection_name, fields)
    #     res = connect.get_collection_info(collection_name)
    #     logging.getLogger().info(res)
    #     assert res["metric_type"] == "L2"

    # TODO: assert exception
    def test_create_collection_limit_fields(self, connect):
        collection_name = gen_unique_str(collection_id)
+31 −18
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ import threading
from multiprocessing import Process
from utils import *

nb = 1000
collection_id = "info"
default_fields = gen_default_fields() 
segment_row_count = 5000
@@ -53,7 +54,6 @@ class TestInfoBase:
    ******************************************************************
    """
  
    # TODO
    def test_info_collection_fields(self, connect, get_filter_field, get_vector_field):
        '''
        target: test create normal collection with different fields, check info returned
@@ -69,13 +69,16 @@ class TestInfoBase:
        }
        connect.create_collection(collection_name, fields)
        res = connect.get_collection_info(collection_name)
        # assert field_name
        # assert field_type
        # assert vector field params
        # assert metric type
        # assert dimension
        assert res['auto_id'] == True
        assert res['segment_row_count'] == segment_row_count
        assert len(res["fields"]) == 3
        for field in res["fields"]:
            if field["type"] == filter_field:
                assert field["name"] == filter_field["name"]
            elif field["type"] == vector_field:
                assert field["name"] == vector_field["name"]
                assert field["params"] == vector_field["params"]

    # TODO
    def test_create_collection_segment_row_count(self, connect, get_segment_row_count):
        '''
        target: test create normal collection with different fields
@@ -86,7 +89,9 @@ class TestInfoBase:
        fields = copy.deepcopy(default_fields)
        fields["segment_row_count"] = get_segment_row_count
        connect.create_collection(collection_name, fields)
        # assert segment size
        # assert segment row count
        res = connect.get_collection_info(collection_name)
        assert res['segment_row_count'] == get_segment_row_count

    def test_get_collection_info_after_index_created(self, connect, collection, get_simple_index):
        connect.create_index(collection, field_name, get_simple_index)
@@ -148,7 +153,6 @@ class TestInfoBase:
    ******************************************************************
    """

    # TODO
    def test_info_collection_fields_after_insert(self, connect, get_filter_field, get_vector_field):
        '''
        target: test create normal collection with different fields, check info returned
@@ -163,15 +167,20 @@ class TestInfoBase:
                "segment_row_count": segment_row_count
        }
        connect.create_collection(collection_name, fields)
        # insert
        entities = gen_entities_by_fields(fields["fields"], nb, vector_field["params"]["dim"])
        res_ids = connect.insert(collection_name, entities)
        connect.flush([collection_name])
        res = connect.get_collection_info(collection_name)
        # assert field_name
        # assert field_type
        # assert vector field params
        # assert metric type
        # assert dimension
        assert res['auto_id'] == True
        assert res['segment_row_count'] == segment_row_count
        assert len(res["fields"]) == 3
        for field in res["fields"]:
            if field["type"] == filter_field:
                assert field["name"] == filter_field["name"]
            elif field["type"] == vector_field:
                assert field["name"] == vector_field["name"]
                assert field["params"] == vector_field["params"]

    # TODO
    def test_create_collection_segment_row_count_after_insert(self, connect, get_segment_row_count):
        '''
        target: test create normal collection with different fields
@@ -182,8 +191,12 @@ class TestInfoBase:
        fields = copy.deepcopy(default_fields)
        fields["segment_row_count"] = get_segment_row_count
        connect.create_collection(collection_name, fields)
        # insert
        # assert segment size
        entities = gen_entities_by_fields(fields["fields"], nb, fields["fields"][-1]["params"]["dim"])
        res_ids = connect.insert(collection_name, entities)
        connect.flush([collection_name])
        res = connect.get_collection_info(collection_name)
        assert res['auto_id'] == True
        assert res['segment_row_count'] == get_segment_row_count


class TestInfoInvalid(object):
+26 −17
Original line number Diff line number Diff line
@@ -10,10 +10,12 @@ collection_id = "load_collection"
nb = 6000
default_fields = gen_default_fields() 
entities = gen_entities(nb)
field_name = "float_vector"
field_name = default_float_vec_field_name
binary_field_name = default_binary_vec_field_name
raw_vectors, binary_entities = gen_binary_entities(nb)


class TestLoadCollection:
class TestLoadBase:

    """
    ******************************************************************
@@ -30,11 +32,22 @@ class TestLoadCollection:
                pytest.skip("sq8h not support in cpu mode")
        return request.param

    @pytest.fixture(
        scope="function",
        params=gen_binary_index()
    )
    def get_binary_index(self, request, connect):
        logging.getLogger().info(request.param)
        if request.param["index_type"] in binary_support():
            return request.param
        else:
            pytest.skip("Skip index Temporary")

    def test_load_collection_after_index(self, connect, collection, get_simple_index):
        '''
        target: test load collection, after index created
        method: insert and create index, load collection with correct params
        expected: describe raise exception
        expected: no error raised
        ''' 
        connect.insert(collection, entities)
        connect.flush([collection])
@@ -42,20 +55,20 @@ class TestLoadCollection:
        connect.create_index(collection, field_name, get_simple_index)
        connect.load_collection(collection)

    # TODO:
    @pytest.mark.level(1)
    def test_load_collection_after_index_binary(self, connect, binary_collection):
    @pytest.mark.level(2)
    def test_load_collection_after_index_binary(self, connect, binary_collection, get_binary_index):
        '''
        target: test load binary_collection, after index created
        method: insert and create index, load binary_collection with correct params
        expected: describe raise exception
        expected: no error raised
        ''' 
        # connect.insert(binary_collection, entities)
        # connect.flush([binary_collection])
        # logging.getLogger().info(get_simple_index)
        # connect.create_index(binary_collection, field_name, get_simple_index)
        # connect.load_collection(binary_collection)
        pass
        connect.insert(binary_collection, binary_entities)
        connect.flush([binary_collection])
        for metric_type in binary_metrics():
            logging.getLogger().info(metric_type)
            get_binary_index["metric_type"] = metric_type
            connect.create_index(binary_collection, binary_field_name, get_binary_index)
            connect.load_collection(binary_collection)

    def load_empty_collection(self, connect, collection):
        '''
@@ -86,10 +99,6 @@ class TestLoadCollection:
    def test_load_collection_after_search(self, connect, collection):
        pass

    @pytest.mark.level(2)
    def test_load_collection_before_search(self, connect, collection):
        pass


class TestLoadCollectionInvalid(object):
    """
Loading