Commit 18b0303f authored by JinHai-CN's avatar JinHai-CN
Browse files

Add one more interface: UnsetQuantizer


Former-commit-id: 34b6b4ac1f9b2841a8af6e5e4166e9ad45a9de1f
parent 2cce8978
Loading
Loading
Loading
Loading
+24 −7
Original line number Diff line number Diff line
@@ -27,7 +27,8 @@
namespace zilliz {
namespace knowhere {

IndexModelPtr IVFSQHybrid::Train(const DatasetPtr &dataset, const Config &config) {
IndexModelPtr
IVFSQHybrid::Train(const DatasetPtr &dataset, const Config &config) {
    auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
    if (build_cfg != nullptr) {
        build_cfg->CheckValid(); // throw exception
@@ -58,7 +59,8 @@ IndexModelPtr IVFSQHybrid::Train(const DatasetPtr &dataset, const Config &config
    }
}

VectorIndexPtr IVFSQHybrid::CopyGpuToCpu(const Config &config) {
VectorIndexPtr
IVFSQHybrid::CopyGpuToCpu(const Config &config) {
    std::lock_guard<std::mutex> lk(mutex_);

    if (auto device_idx = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_)) {
@@ -74,7 +76,8 @@ VectorIndexPtr IVFSQHybrid::CopyGpuToCpu(const Config &config) {
    }
}

VectorIndexPtr IVFSQHybrid::CopyCpuToGpu(const int64_t &device_id, const Config &config) {
VectorIndexPtr
IVFSQHybrid::CopyCpuToGpu(const int64_t &device_id, const Config &config) {
    if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
        ResScope rs(res, device_id, false);
        faiss::gpu::GpuClonerOptions option;
@@ -95,11 +98,13 @@ VectorIndexPtr IVFSQHybrid::CopyCpuToGpu(const int64_t &device_id, const Config
    }
}

void IVFSQHybrid::LoadImpl(const BinarySet &index_binary) {
void
IVFSQHybrid::LoadImpl(const BinarySet &index_binary) {
    FaissBaseIndex::LoadImpl(index_binary); // load on cpu
}

void IVFSQHybrid::search_impl(int64_t n,
void
IVFSQHybrid::search_impl(int64_t n,
                              const float *data,
                              int64_t k,
                              float *distances,
@@ -112,7 +117,8 @@ void IVFSQHybrid::search_impl(int64_t n,
    }
}

QuantizerPtr IVFSQHybrid::LoadQuantizer(const Config &conf) {
QuantizerPtr
IVFSQHybrid::LoadQuantizer(const Config &conf) {
    auto quantizer_conf = std::dynamic_pointer_cast<QuantizerCfg>(conf);
    if (quantizer_conf != nullptr) {
        quantizer_conf->CheckValid(); // throw exception
@@ -140,7 +146,8 @@ QuantizerPtr IVFSQHybrid::LoadQuantizer(const Config &conf) {
    }
}

void IVFSQHybrid::SetQuantizer(QuantizerPtr q) {
void
IVFSQHybrid::SetQuantizer(const QuantizerPtr& q) {
    auto ivf_quantizer = std::dynamic_pointer_cast<FaissIVFQuantizer>(q);
    if (ivf_quantizer == nullptr) {
        KNOWHERE_THROW_MSG("Quantizer type error");
@@ -158,5 +165,15 @@ void IVFSQHybrid::SetQuantizer(QuantizerPtr q) {
    }
}

void
IVFSQHybrid::UnsetQuantizer() {
    auto *ivf_index = dynamic_cast<faiss::IndexIVF *>(index_.get());
    if(ivf_index == nullptr) {
        KNOWHERE_THROW_MSG("Index type error");
    }

    ivf_index->quantizer = nullptr;
}

}
}
+4 −1
Original line number Diff line number Diff line
@@ -49,7 +49,10 @@ class IVFSQHybrid : public GPUIVFSQ {
    LoadQuantizer(const Config &conf);

    void
    SetQuantizer(QuantizerPtr q);
    SetQuantizer(const QuantizerPtr& q);

    void
    UnsetQuantizer();

    IndexModelPtr
    Train(const DatasetPtr &dataset, const Config &config) override;
+23 −2
Original line number Diff line number Diff line
@@ -277,7 +277,8 @@ IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
    return Status::OK();
}

knowhere::QuantizerPtr IVFHybridIndex::LoadQuantizer(const Config& conf) {
knowhere::QuantizerPtr
IVFHybridIndex::LoadQuantizer(const Config& conf) {
    // TODO(linxj): Hardcode here
    if (auto new_idx = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_)){
        return new_idx->LoadQuantizer(conf);
@@ -286,7 +287,8 @@ knowhere::QuantizerPtr IVFHybridIndex::LoadQuantizer(const Config& conf) {
    }
}

Status IVFHybridIndex::SetQuantizer(knowhere::QuantizerPtr q) {
Status
IVFHybridIndex::SetQuantizer(const knowhere::QuantizerPtr& q) {
    try {
        // TODO(linxj): Hardcode here
        if (auto new_idx = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_)) {
@@ -304,6 +306,25 @@ Status IVFHybridIndex::SetQuantizer(knowhere::QuantizerPtr q) {
    }
}

Status
IVFHybridIndex::UnsetQuantizer() {
    try {
        // TODO(linxj): Hardcode here
        if (auto new_idx = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_)) {
            new_idx->UnsetQuantizer();
        } else {
            WRAPPER_LOG_ERROR << "Hybrid mode not support for index type: " << int(type);
            return Status(KNOWHERE_ERROR, "not support");
        }
    } catch (knowhere::KnowhereException &e) {
        WRAPPER_LOG_ERROR << e.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
    } catch (std::exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return Status(KNOWHERE_ERROR, e.what());
    }
}

} // namespace engine
} // namespace milvus
} // namespace zilliz
+7 −2
Original line number Diff line number Diff line
@@ -103,9 +103,14 @@ class IVFMixIndex : public VecIndexImpl {

class IVFHybridIndex : public IVFMixIndex {
 public:
    knowhere::QuantizerPtr LoadQuantizer(const Config& conf) override;
    knowhere::QuantizerPtr
    LoadQuantizer(const Config& conf) override;

    Status SetQuantizer(knowhere::QuantizerPtr q) override;
    Status
    SetQuantizer(const knowhere::QuantizerPtr& q) override;

    Status
    UnsetQuantizer() override;
};

class BFIndex : public VecIndexImpl {
+5 −2
Original line number Diff line number Diff line
@@ -105,11 +105,14 @@ class VecIndex {

    // TODO(linxj): refactor later
    virtual knowhere::QuantizerPtr
    LoadQuantizer(const Config& conf) {}
    LoadQuantizer(const Config& conf) { return Status::OK(); }

    // TODO(linxj): refactor later
    virtual Status
    SetQuantizer(knowhere::QuantizerPtr q) {}
    SetQuantizer(const knowhere::QuantizerPtr& q) { return Status::OK(); }

    virtual Status
    UnsetQuantizer() { return Status::OK(); }
};

extern Status