Unverified Commit 9899cfef authored by Jin Hai's avatar Jin Hai Committed by GitHub
Browse files

#1475 add debug timing for map uids (#1513)

parent fc8a6e07
Loading
Loading
Loading
Loading
+23 −28
Original line number Diff line number Diff line
@@ -741,6 +741,17 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t
    return std::make_shared<ExecutionEngineImpl>(to_index, location, engine_type, metric_type_, nlist_);
}

// map offsets to ids
void
MapUids(const std::vector<segment::doc_id_t>& uids, int64_t* labels, size_t num) {
    for (int64_t i = 0; i < num; ++i) {
        int64_t& offset = labels[i];
        if (offset != -1) {
            offset = uids[offset];
        }
    }
}

Status
ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
                            bool hybrid) {
@@ -798,7 +809,7 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
        }
    }
#endif
    TimeRecorder rc("ExecutionEngineImpl::Search");
    TimeRecorder rc("ExecutionEngineImpl::Search float");

    if (index_ == nullptr) {
        ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, failed to search";
@@ -824,15 +835,9 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
    rc.RecordSection("search done");

    // map offsets to ids
    const std::vector<segment::doc_id_t>& uids = index_->GetUids();
    for (int64_t i = 0; i < n * k; i++) {
        int64_t offset = labels[i];
        if (offset != -1) {
            labels[i] = uids[offset];
        }
    }
    MapUids(index_->GetUids(), labels, n * k);

    rc.RecordSection("map uids");
    rc.RecordSection("map uids " + std::to_string(n * k));

    if (hybrid) {
        HybridUnset();
@@ -847,7 +852,7 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
Status
ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, int64_t nprobe, float* distances,
                            int64_t* labels, bool hybrid) {
    TimeRecorder rc("ExecutionEngineImpl::Search");
    TimeRecorder rc("ExecutionEngineImpl::Search uint8");

    if (index_ == nullptr) {
        ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, failed to search";
@@ -873,15 +878,9 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, int64_t n
    rc.RecordSection("search done");

    // map offsets to ids
    const std::vector<segment::doc_id_t>& uids = index_->GetUids();
    for (int64_t i = 0; i < n * k; i++) {
        int64_t offset = labels[i];
        if (offset != -1) {
            labels[i] = uids[offset];
        }
    }
    MapUids(index_->GetUids(), labels, n * k);

    rc.RecordSection("map uids");
    rc.RecordSection("map uids " + std::to_string(n * k));

    if (hybrid) {
        HybridUnset();
@@ -896,7 +895,7 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, int64_t n
Status
ExecutionEngineImpl::Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, int64_t nprobe, float* distances,
                            int64_t* labels, bool hybrid) {
    TimeRecorder rc("ExecutionEngineImpl::Search");
    TimeRecorder rc("ExecutionEngineImpl::Search vector of ids");

    if (index_ == nullptr) {
        ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, failed to search";
@@ -961,16 +960,12 @@ ExecutionEngineImpl::Search(int64_t n, const std::vector<int64_t>& ids, int64_t
    auto status = Status::OK();
    if (!offsets.empty()) {
        status = index_->SearchById(offsets.size(), offsets.data(), distances, labels, conf);
        rc.RecordSection("search by id done");
        rc.RecordSection("search done");

        // map offsets to ids
        for (int64_t i = 0; i < offsets.size() * k; i++) {
            int64_t offset = labels[i];
            if (offset != -1) {
                labels[i] = uids[offset];
            }
        }
        rc.RecordSection("map uids");
        MapUids(uids, labels, offsets.size() * k);

        rc.RecordSection("map uids " + std::to_string(offsets.size() * k));
    }

    if (hybrid) {