Commit 6fcd2a13 authored by 余昆's avatar 余昆
Browse files

fix CPU version bug

parent fd304cf4
Loading
Loading
Loading
Loading
+11 −11
Original line number Diff line number Diff line
@@ -116,29 +116,29 @@ NSG::Train(const DatasetPtr& dataset, const Config& config) {
    }

    // TODO(linxj): dev IndexFactory, support more IndexType
    bool use_gpu = false;
    Graph knng;
#ifdef MILVUS_GPU_VERSION
    use_gpu = true;
    auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(build_cfg->gpu_id);
    if (temp_resource == nullptr)
        use_gpu = false;
#endif
    Graph knng;
    if (use_gpu) {
        auto preprocess_index = std::make_shared<GPUIVF>(build_cfg->gpu_id);
    if (temp_resource == nullptr) {
        auto preprocess_index = std::make_shared<IVF>();
        auto model = preprocess_index->Train(dataset, config);
        preprocess_index->set_index_model(model);
        preprocess_index->AddWithoutIds(dataset, config);

        preprocess_index->GenGraph(build_cfg->knng, knng, dataset, config);
    } else {
        auto preprocess_index = std::make_shared<IVF>();
        auto preprocess_index = std::make_shared<GPUIVF>(build_cfg->gpu_id);
        auto model = preprocess_index->Train(dataset, config);
        preprocess_index->set_index_model(model);
        preprocess_index->AddWithoutIds(dataset, config);

        preprocess_index->GenGraph(build_cfg->knng, knng, dataset, config);
    }
#else
    auto preprocess_index = std::make_shared<IVF>();
    auto model = preprocess_index->Train(dataset, config);
    preprocess_index->set_index_model(model);
    preprocess_index->AddWithoutIds(dataset, config);
    preprocess_index->GenGraph(build_cfg->knng, knng, dataset, config);
#endif

    algo::BuildParams b_params;
    b_params.candidate_pool_size = build_cfg->candidate_pool_size;