Commit 39aed079 authored by xiaojun.lin's avatar xiaojun.lin
Browse files

update v2


Former-commit-id: 9daec92b461857062ea0d38fa61e093e65390d00
parent 3389bf2b
Loading
Loading
Loading
Loading
+48 −0
Original line number Diff line number Diff line
@@ -696,6 +696,54 @@ TEST_F(GPURESTEST, copyandsearch) {

    std::thread search_thread(search_func);
    std::thread load_thread(load_func);
    search_thread.join();
    load_thread.join();
    tc.RecordSection("Copy&search total");
}

TEST_F(GPURESTEST, TrainAndSearch) {
    index_type = "GPUIVFSQ";
    index_ = IndexFactory(index_type);

    auto conf = std::make_shared<knowhere::IVFSQCfg>();
    conf->nlist = 1638;
    conf->d = dim;
    conf->gpu_id = device_id;
    conf->metric_type = knowhere::METRICTYPE::L2;
    conf->k = k;
    conf->nbits = 8;
    conf->nprobe = 1;

    auto preprocessor = index_->BuildPreprocessor(base_dataset, conf);
    index_->set_preprocessor(preprocessor);
    auto model = index_->Train(base_dataset, conf);
    auto new_index = IndexFactory(index_type);
    new_index->set_index_model(model);
    new_index->Add(base_dataset, conf);
    auto cpu_idx = knowhere::cloner::CopyGpuToCpu(new_index, knowhere::Config());
    cpu_idx->Seal();
    auto search_idx = knowhere::cloner::CopyCpuToGpu(cpu_idx, device_id, knowhere::Config());

    constexpr int train_count = 1;
    constexpr int search_count = 5000;
    auto train_stage = [&] {
        for (int i = 0; i < train_count; ++i) {
            auto model = index_->Train(base_dataset, conf);
            auto test_idx = IndexFactory(index_type);
            test_idx->set_index_model(model);
            test_idx->Add(base_dataset, conf);
        }
    };
    auto search_stage = [&](knowhere::VectorIndexPtr& search_idx) {
        for (int i = 0; i < search_count; ++i) {
            auto result = search_idx->Search(query_dataset, conf);
            AssertAnns(result, nq, k);
        }
    };

    // TimeRecorder tc("record");
    // train_stage();
    // tc.RecordSection("train cost");
    // search_stage(search_idx);
    // tc.RecordSection("search cost");