Commit 8540b356 authored by 王翔宇's avatar 王翔宇
Browse files

SQ8H in GPU part2


Former-commit-id: 4c48987574ed24bf8a543d97520eb3a6b554fca5
parent e675158b
Loading
Loading
Loading
Loading
+5 −3
Original line number Diff line number Diff line
@@ -180,7 +180,7 @@ IVFSQHybrid::UnsetQuantizer() {
    ivf_index->quantizer = nullptr;
}

void
VectorIndexPtr
IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
    auto quantizer_conf = std::dynamic_pointer_cast<QuantizerCfg>(conf);
    if (quantizer_conf != nullptr) {
@@ -207,8 +207,10 @@ IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
        index_composition->mode = quantizer_conf->mode;  // only 2

        auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index_composition, &option);
        index_.reset(gpu_index);
        gpu_mode = 2;  // all in gpu
        std::shared_ptr<faiss::Index> new_idx;
        new_idx.reset(gpu_index);
        auto sq_idx = std::make_shared<IVFSQHybrid>(new_idx, gpu_id_, res);
        return sq_idx;
    } else {
        KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
    }
+1 −2
Original line number Diff line number Diff line
@@ -60,8 +60,7 @@ class IVFSQHybrid : public GPUIVFSQ {
    void
    UnsetQuantizer();

    // todo(xiaojun): return void => VecIndex
    void
    VectorIndexPtr
    LoadData(const knowhere::QuantizerPtr& q, const Config& conf);

    IndexModelPtr
+2 −2
Original line number Diff line number Diff line
@@ -253,9 +253,9 @@ TEST_P(IVFTest, hybrid) {
        quantizer_conf->gpu_id = device_id;
        auto q = hybrid_2_idx->LoadQuantizer(quantizer_conf);
        quantizer_conf->mode = 2;
        hybrid_2_idx->LoadData(q, quantizer_conf);
        auto gpu_idx = hybrid_2_idx->LoadData(q, quantizer_conf);

        auto result = hybrid_2_idx->Search(query_dataset, conf);
        auto result = gpu_idx->Search(query_dataset, conf);
        AssertAnns(result, nq, conf->k);
        PrintResult(result, nq, k);
    }
+10 −5
Original line number Diff line number Diff line
@@ -256,11 +256,14 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id, bool hybrid) {
        conf->gpu_id = device_id;

        if (quantizer) {
            std::cout << "cache hit" << std::endl;
            // cache hit
            conf->mode = 2;
            index_->SetQuantizer(quantizer->Data());
            index_->LoadData(quantizer->Data(), conf);
            auto new_index = index_->LoadData(quantizer->Data(), conf);
            index_ = new_index;
        } else {
            std::cout << "cache miss" << std::endl;
            // cache hit
            // cache miss
            if (index_ == nullptr) {
                ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, failed to copy to gpu";
@@ -268,9 +271,9 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id, bool hybrid) {
            }
            conf->mode = 1;
            auto q = index_->LoadQuantizer(conf);
            index_->SetQuantizer(q);
            conf->mode = 2;
            index_->LoadData(q, conf);
            auto new_index = index_->LoadData(q, conf);
            index_ = new_index;

            // cache
            auto cached_quantizer = std::make_shared<CachedQuantizer>(q);
@@ -445,7 +448,9 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr

    auto status = index_->Search(n, data, distances, labels, conf);

    if (hybrid) {
        HybridUnset();
    }

    if (!status.ok()) {
        ENGINE_LOG_ERROR << "Search error";
+3 −6
Original line number Diff line number Diff line
@@ -315,24 +315,21 @@ IVFHybridIndex::UnsetQuantizer() {
    return Status::OK();
}

Status
VecIndexPtr
IVFHybridIndex::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
    try {
        // TODO(linxj): Hardcode here
        if (auto new_idx = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_)) {
            new_idx->LoadData(q, conf);
            return std::make_shared<IVFHybridIndex>(new_idx->LoadData(q, conf), type);
        } else {
            WRAPPER_LOG_ERROR << "Hybrid mode not support for index type: " << int(type);
            return Status(KNOWHERE_ERROR, "not support");
        }
    } catch (knowhere::KnowhereException& e) {
        WRAPPER_LOG_ERROR << e.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
    } catch (std::exception& e) {
        WRAPPER_LOG_ERROR << e.what();
        return Status(KNOWHERE_ERROR, e.what());
    }
    return Status::OK();
    return nullptr;
}

}  // namespace engine
Loading