Commit a9c5b7b3 authored by jinhai
Browse files

Merge branch 'Refactor_Knowhere' into 'branch-0.5.0'

MS-583 Change return types from ErrorCode to Status

See merge request megasearch/milvus!599

Former-commit-id: 7173cb0bcb9e6669c124e75696b5c271da94b036
parents 481f1e54 72ffa6c4
Loading
Loading
Loading
Loading
+12 −20
Original line number Diff line number Diff line
@@ -99,11 +99,8 @@ VecIndexPtr ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
}

Status ExecutionEngineImpl::AddWithIds(long n, const float *xdata, const long *xids) {
    auto ec = index_->Add(n, xdata, xids);
    if (ec != KNOWHERE_SUCCESS) {
        return Status(DB_ERROR, "Add error");
    }
    return Status::OK();
    auto status = index_->Add(n, xdata, xids);
    return status;
}

size_t ExecutionEngineImpl::Count() const {
@@ -131,11 +128,8 @@ size_t ExecutionEngineImpl::PhysicalSize() const {
}

Status ExecutionEngineImpl::Serialize() {
    auto ec = write_index(index_, location_);
    if (ec != KNOWHERE_SUCCESS) {
        return Status(DB_ERROR, "Serialize: write to disk error");
    }
    return Status::OK();
    auto status = write_index(index_, location_);
    return status;
}

Status ExecutionEngineImpl::Load(bool to_cache) {
@@ -254,12 +248,11 @@ Status ExecutionEngineImpl::Merge(const std::string &location) {
    }

    if (auto file_index = std::dynamic_pointer_cast<BFIndex>(to_merge)) {
        auto ec = index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds());
        if (ec != KNOWHERE_SUCCESS) {
        auto status = index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds());
        if (!status.ok()) {
            ENGINE_LOG_ERROR << "Merge: Add Error";
            return Status(DB_ERROR, "Merge: Add Error");
        }
        return Status::OK();
        return status;
    } else {
        return Status(DB_ERROR, "file index type is not idmap");
    }
@@ -287,11 +280,11 @@ ExecutionEngineImpl::BuildIndex(const std::string &location, EngineType engine_t
    build_cfg["nlist"] = nlist_;
    AutoGenParams(to_index->GetType(), Count(), build_cfg);

    auto ec = to_index->BuildAll(Count(),
    auto status = to_index->BuildAll(Count(),
                                 from_index->GetRawVectors(),
                                 from_index->GetRawIds(),
                                 build_cfg);
    if (ec != KNOWHERE_SUCCESS) { throw Exception(DB_ERROR, "Build index error"); }
    if (!status.ok()) { throw Exception(DB_ERROR, status.message()); }

    return std::make_shared<ExecutionEngineImpl>(to_index, location, engine_type, metric_type_, nlist_);
}
@@ -309,12 +302,11 @@ Status ExecutionEngineImpl::Search(long n,

    ENGINE_LOG_DEBUG << "Search Params: [k]  " << k << " [nprobe] " << nprobe;
    auto cfg = Config::object{{"k", k}, {"nprobe", nprobe}};
    auto ec = index_->Search(n, data, distances, labels, cfg);
    if (ec != KNOWHERE_SUCCESS) {
    auto status = index_->Search(n, data, distances, labels, cfg);
    if (!status.ok()) {
        ENGINE_LOG_ERROR << "Search error";
        return Status(DB_ERROR, "Search: Search Error");
    }
    return Status::OK();
    return status;
}

Status ExecutionEngineImpl::Cache() {
+6 −4
Original line number Diff line number Diff line
@@ -28,7 +28,8 @@ namespace engine {

constexpr int64_t M_BYTE = 1024 * 1024;

ErrorCode KnowhereResource::Initialize() {
Status
KnowhereResource::Initialize() {
    struct GpuResourceSetting {
        int64_t pinned_memory = 300*M_BYTE;
        int64_t temp_memory = 300*M_BYTE;
@@ -65,12 +66,13 @@ ErrorCode KnowhereResource::Initialize() {
                                                                iter->second.resource_num);
    }

    return KNOWHERE_SUCCESS;
    return Status::OK();
}

ErrorCode KnowhereResource::Finalize() {
Status
KnowhereResource::Finalize() {
    knowhere::FaissGpuResourceMgr::GetInstance().Free(); // free gpu resource.
    return KNOWHERE_SUCCESS;
    return Status::OK();
}

}
+6 −3
Original line number Diff line number Diff line
@@ -18,7 +18,7 @@

#pragma once

#include "utils/Error.h"
#include "utils/Status.h"

namespace zilliz {
namespace milvus {
@@ -26,8 +26,11 @@ namespace engine {

class KnowhereResource {
public:
    static ErrorCode Initialize();
    static ErrorCode Finalize();
    static Status
    Initialize();

    static Status
    Finalize();
};


+75 −58
Original line number Diff line number Diff line
@@ -21,7 +21,6 @@
#include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/helpers/Cloner.h"

#include "vec_impl.h"
#include "data_transfer.h"

@@ -32,7 +31,8 @@ namespace engine {

using namespace zilliz::knowhere;

ErrorCode VecIndexImpl::BuildAll(const long &nb,
Status
VecIndexImpl::BuildAll(const long &nb,
                       const float *xb,
                       const long *ids,
                       const Config &cfg,
@@ -49,36 +49,38 @@ ErrorCode VecIndexImpl::BuildAll(const long &nb,
        index_->Add(dataset, cfg);
    } catch (KnowhereException &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_UNEXPECTED_ERROR;
        return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
    } catch (jsoncons::json_exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_INVALID_ARGUMENT;
        return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
    } catch (std::exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_ERROR;
        return Status(KNOWHERE_ERROR, e.what());
    }
    return KNOWHERE_SUCCESS;
    return Status::OK();
}

ErrorCode VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) {
Status
VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) {
    try {
        auto dataset = GenDatasetWithIds(nb, dim, xb, ids);

        index_->Add(dataset, cfg);
    } catch (KnowhereException &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_UNEXPECTED_ERROR;
        return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
    } catch (jsoncons::json_exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_INVALID_ARGUMENT;
        return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
    } catch (std::exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_ERROR;
        return Status(KNOWHERE_ERROR, e.what());
    }
    return KNOWHERE_SUCCESS;
    return Status::OK();
}

ErrorCode VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) {
Status
VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) {
    try {
        auto k = cfg["k"].as<int>();
        auto dataset = GenDataset(nq, dim, xq);
@@ -117,41 +119,47 @@ ErrorCode VecIndexImpl::Search(const long &nq, const float *xq, float *dist, lon

    } catch (KnowhereException &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_UNEXPECTED_ERROR;
        return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
    } catch (jsoncons::json_exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_INVALID_ARGUMENT;
        return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
    } catch (std::exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_ERROR;
        return Status(KNOWHERE_ERROR, e.what());
    }
    return KNOWHERE_SUCCESS;
    return Status::OK();
}

zilliz::knowhere::BinarySet VecIndexImpl::Serialize() {
zilliz::knowhere::BinarySet
VecIndexImpl::Serialize() {
    type = ConvertToCpuIndexType(type);
    return index_->Serialize();
}

ErrorCode VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) {
Status
VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) {
    index_->Load(index_binary);
    dim = Dimension();
    return KNOWHERE_SUCCESS;
    return Status::OK();
}

int64_t VecIndexImpl::Dimension() {
int64_t
VecIndexImpl::Dimension() {
    return index_->Dimension();
}

int64_t VecIndexImpl::Count() {
int64_t
VecIndexImpl::Count() {
    return index_->Count();
}

IndexType VecIndexImpl::GetType() {
IndexType
VecIndexImpl::GetType() {
    return type;
}

VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) {
VecIndexPtr
VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) {
    // TODO(linxj): exception handle
    auto gpu_index = zilliz::knowhere::cloner::CopyCpuToGpu(index_, device_id, cfg);
    auto new_index = std::make_shared<VecIndexImpl>(gpu_index, ConvertToGpuIndexType(type));
@@ -159,7 +167,8 @@ VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg)
    return new_index;
}

VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) {
VecIndexPtr
VecIndexImpl::CopyToCpu(const Config &cfg) {
    // TODO(linxj): exception handle
    auto cpu_index = zilliz::knowhere::cloner::CopyGpuToCpu(index_, cfg);
    auto new_index = std::make_shared<VecIndexImpl>(cpu_index, ConvertToCpuIndexType(type));
@@ -167,14 +176,16 @@ VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) {
    return new_index;
}

VecIndexPtr VecIndexImpl::Clone() {
VecIndexPtr
VecIndexImpl::Clone() {
    // TODO(linxj): exception handle
    auto clone_index = std::make_shared<VecIndexImpl>(index_->Clone(), type);
    clone_index->dim = dim;
    return clone_index;
}

int64_t VecIndexImpl::GetDeviceId() {
int64_t
VecIndexImpl::GetDeviceId() {
    if (auto device_idx = std::dynamic_pointer_cast<GPUIndex>(index_)) {
        return device_idx->GetGpuDevice();
    }
@@ -182,17 +193,20 @@ int64_t VecIndexImpl::GetDeviceId() {
    return -1; // -1 == cpu
}

float *BFIndex::GetRawVectors() {
float *
BFIndex::GetRawVectors() {
    auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
    if (raw_index) { return raw_index->GetRawVectors(); }
    return nullptr;
}

int64_t *BFIndex::GetRawIds() {
int64_t *
BFIndex::GetRawIds() {
    return std::static_pointer_cast<IDMAP>(index_)->GetRawIds();
}

ErrorCode BFIndex::Build(const Config &cfg) {
ErrorCode
BFIndex::Build(const Config &cfg) {
    try {
        dim = cfg["dim"].as<int>();
        std::static_pointer_cast<IDMAP>(index_)->Train(cfg);
@@ -209,7 +223,8 @@ ErrorCode BFIndex::Build(const Config &cfg) {
    return KNOWHERE_SUCCESS;
}

ErrorCode BFIndex::BuildAll(const long &nb,
Status
BFIndex::BuildAll(const long &nb,
                  const float *xb,
                  const long *ids,
                  const Config &cfg,
@@ -223,19 +238,20 @@ ErrorCode BFIndex::BuildAll(const long &nb,
        index_->Add(dataset, cfg);
    } catch (KnowhereException &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_UNEXPECTED_ERROR;
        return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
    } catch (jsoncons::json_exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_INVALID_ARGUMENT;
        return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
    } catch (std::exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_ERROR;
        return Status(KNOWHERE_ERROR, e.what());
    }
    return KNOWHERE_SUCCESS;
    return Status::OK();
}

// TODO(linxj): add lock here.
ErrorCode IVFMixIndex::BuildAll(const long &nb,
Status
IVFMixIndex::BuildAll(const long &nb,
                      const float *xb,
                      const long *ids,
                      const Config &cfg,
@@ -257,26 +273,27 @@ ErrorCode IVFMixIndex::BuildAll(const long &nb,
            type = ConvertToCpuIndexType(type);
        } else {
            WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed";
            return KNOWHERE_ERROR;
            return Status(KNOWHERE_ERROR, "Build IVFMIXIndex Failed");
        }
    } catch (KnowhereException &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_UNEXPECTED_ERROR;
        return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
    } catch (jsoncons::json_exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_INVALID_ARGUMENT;
        return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
    } catch (std::exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_ERROR;
        return Status(KNOWHERE_ERROR, e.what());
    }
    return KNOWHERE_SUCCESS;
    return Status::OK();
}

ErrorCode IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
Status
IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
    //index_ = std::make_shared<IVF>();
    index_->Load(index_binary);
    dim = Dimension();
    return KNOWHERE_SUCCESS;
    return Status::OK();
}

}
+70 −34
Original line number Diff line number Diff line
@@ -19,7 +19,6 @@
#pragma once

#include "knowhere/index/vector_index/VectorIndex.h"

#include "vec_index.h"


@@ -31,27 +30,53 @@ class VecIndexImpl : public VecIndex {
 public:
    explicit VecIndexImpl(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType &type)
        : index_(std::move(index)), type(type) {};
    ErrorCode BuildAll(const long &nb,

    Status
    BuildAll(const long &nb,
             const float *xb,
             const long *ids,
             const Config &cfg,
             const long &nt,
             const float *xt) override;
    VecIndexPtr CopyToGpu(const int64_t &device_id, const Config &cfg) override;
    VecIndexPtr CopyToCpu(const Config &cfg) override;
    IndexType GetType() override;
    int64_t Dimension() override;
    int64_t Count() override;
    ErrorCode Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override;
    zilliz::knowhere::BinarySet Serialize() override;
    ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) override;
    VecIndexPtr Clone() override;
    int64_t GetDeviceId() override;
    ErrorCode Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) override;

    VecIndexPtr
    CopyToGpu(const int64_t &device_id, const Config &cfg) override;

    VecIndexPtr
    CopyToCpu(const Config &cfg) override;

    IndexType
    GetType() override;

    int64_t
    Dimension() override;

    int64_t
    Count() override;

    Status
    Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override;

    zilliz::knowhere::BinarySet
    Serialize() override;

    Status
    Load(const zilliz::knowhere::BinarySet &index_binary) override;

    VecIndexPtr
    Clone() override;

    int64_t
    GetDeviceId() override;

    Status
    Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) override;

 protected:
    int64_t dim = 0;

    IndexType type = IndexType::INVALID;

    std::shared_ptr<zilliz::knowhere::VectorIndex> index_ = nullptr;
};

@@ -60,28 +85,39 @@ class IVFMixIndex : public VecIndexImpl {
    explicit IVFMixIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType &type)
        : VecIndexImpl(std::move(index), type) {};

    ErrorCode BuildAll(const long &nb,
    Status
    BuildAll(const long &nb,
             const float *xb,
             const long *ids,
             const Config &cfg,
             const long &nt,
             const float *xt) override;
    ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) override;

    Status
    Load(const zilliz::knowhere::BinarySet &index_binary) override;
};

class BFIndex : public VecIndexImpl {
 public:
    explicit BFIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index),
                                                                                          IndexType::FAISS_IDMAP) {};
    ErrorCode Build(const Config& cfg);
    float *GetRawVectors();
    ErrorCode BuildAll(const long &nb,

    ErrorCode
    Build(const Config &cfg);

    float *
    GetRawVectors();

    Status
    BuildAll(const long &nb,
             const float *xb,
             const long *ids,
             const Config &cfg,
             const long &nt,
             const float *xt) override;
    int64_t *GetRawIds();

    int64_t *
    GetRawIds();
};

}
Loading