Commit 12747146 authored by jinhai's avatar jinhai
Browse files

Merge branch 'caiyd_reduce_parallel_0.5.0' into 'branch-0.5.0'

MS-606 optimize reduce, update unittest

See merge request megasearch/milvus!694

Former-commit-id: a5d433108cf08a46f31f3adea84beab008082d07
parents 077d8882 0bbe9d16
Loading
Loading
Loading
Loading
+90 −29
Original line number Diff line number Diff line
@@ -34,8 +34,6 @@ namespace scheduler {
static constexpr size_t PARALLEL_REDUCE_THRESHOLD = 10000;
static constexpr size_t PARALLEL_REDUCE_BATCH = 1000;

std::mutex XSearchTask::merge_mutex_;

// TODO(wxyu): remove unused code
// bool
// NeedParallelReduce(uint64_t nq, uint64_t topk) {
@@ -162,8 +160,8 @@ XSearchTask::Load(LoadType type, uint8_t device_id) {

    size_t file_size = index_engine_->PhysicalSize();

    std::string info = "Load file id:" + std::to_string(file_->id_) +
                       " file type:" + std::to_string(file_->file_type_) + " size:" + std::to_string(file_size) +
    std::string info = "Load file id:" + std::to_string(file_->id_) + " file type:" +
                       std::to_string(file_->file_type_) + " size:" + std::to_string(file_size) +
                       " bytes from location: " + file_->location_ + " totally cost";
    double span = rc.ElapseFromBegin(info);
    //    for (auto &context : search_contexts_) {
@@ -221,7 +219,8 @@ XSearchTask::Execute() {

            // step 3: pick up topk result
            auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk;
            XSearchTask::TopkResult(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult());
            XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2,
                                              search_job->GetResult());

            span = rc.RecordSection(hdr + ", reduce topk");
            //            search_job->AccumReduceCost(span);
@@ -230,7 +229,7 @@ XSearchTask::Execute() {
            //            search_job->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed
        }

        // step 5: notify to send result to client
        // step 4: notify to send result to client
        search_job->SearchDone(index_id_);
    }

@@ -240,36 +239,37 @@ XSearchTask::Execute() {
    index_engine_ = nullptr;
}

Status
XSearchTask::TopkResult(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance,
                        uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result) {
    scheduler::ResultSet result_buf;

void
XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance,
                                  uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending,
                                  scheduler::ResultSet& result) {
    if (result.empty()) {
        result_buf.resize(nq, scheduler::Id2DistVec(input_k, scheduler::IdDistPair(-1, 0.0)));
        for (auto i = 0; i < nq; ++i) {
            auto& result_buf_i = result_buf[i];
        result.resize(nq);
    }

    for (uint64_t i = 0; i < nq; i++) {
        scheduler::Id2DistVec result_buf;
        auto& result_i = result[i];

        if (result[i].empty()) {
            result_buf.resize(input_k, scheduler::IdDistPair(-1, 0.0));
            uint64_t input_k_multi_i = input_k * i;
            for (auto k = 0; k < input_k; ++k) {
                uint64_t idx = input_k_multi_i + k;
                auto& result_buf_item = result_buf_i[k];
                auto& result_buf_item = result_buf[k];
                result_buf_item.first = input_ids[idx];
                result_buf_item.second = input_distance[idx];
            }
        }
        } else {
        size_t tar_size = result[0].size();
            size_t tar_size = result_i.size();
            uint64_t output_k = std::min(topk, input_k + tar_size);
        result_buf.resize(nq, scheduler::Id2DistVec(output_k, scheduler::IdDistPair(-1, 0.0)));
        for (auto i = 0; i < nq; ++i) {
            result_buf.resize(output_k, scheduler::IdDistPair(-1, 0.0));
            size_t buf_k = 0, src_k = 0, tar_k = 0;
            uint64_t src_idx;
            auto& result_i = result[i];
            auto& result_buf_i = result_buf[i];
            uint64_t input_k_multi_i = input_k * i;
            while (buf_k < output_k && src_k < input_k && tar_k < tar_size) {
                src_idx = input_k_multi_i + src_k;
                auto& result_buf_item = result_buf_i[buf_k];
                auto& result_buf_item = result_buf[buf_k];
                auto& result_item = result_i[tar_k];
                if ((ascending && input_distance[src_idx] < result_item.second) ||
                    (!ascending && input_distance[src_idx] > result_item.second)) {
@@ -283,11 +283,11 @@ XSearchTask::TopkResult(const std::vector<int64_t>& input_ids, const std::vector
                buf_k++;
            }

            if (buf_k < topk) {
            if (buf_k < output_k) {
                if (src_k < input_k) {
                    while (buf_k < output_k && src_k < input_k) {
                        src_idx = input_k_multi_i + src_k;
                        auto& result_buf_item = result_buf_i[buf_k];
                        auto& result_buf_item = result_buf[buf_k];
                        result_buf_item.first = input_ids[src_idx];
                        result_buf_item.second = input_distance[src_idx];
                        src_k++;
@@ -295,18 +295,79 @@ XSearchTask::TopkResult(const std::vector<int64_t>& input_ids, const std::vector
                    }
                } else {
                    while (buf_k < output_k && tar_k < tar_size) {
                        result_buf_i[buf_k] = result_i[tar_k];
                        result_buf[buf_k] = result_i[tar_k];
                        tar_k++;
                        buf_k++;
                    }
                }
            }
        }

        result_i.swap(result_buf);
    }
}

void
XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k,
                            const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance,
                            uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) {
    if (src_ids.empty() || src_distance.empty()) {
        return;
    }

    std::vector<int64_t> id_buf(nq * topk, -1);
    std::vector<float> dist_buf(nq * topk, 0.0);

    uint64_t output_k = std::min(topk, tar_input_k + src_input_k);
    uint64_t buf_k, src_k, tar_k;
    uint64_t src_idx, tar_idx, buf_idx;
    uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i;

    for (uint64_t i = 0; i < nq; i++) {
        src_input_k_multi_i = src_input_k * i;
        tar_input_k_multi_i = tar_input_k * i;
        buf_k_multi_i = output_k * i;
        buf_k = src_k = tar_k = 0;
        while (buf_k < output_k && src_k < src_input_k && tar_k < tar_input_k) {
            src_idx = src_input_k_multi_i + src_k;
            tar_idx = tar_input_k_multi_i + tar_k;
            buf_idx = buf_k_multi_i + buf_k;
            if ((ascending && src_distance[src_idx] < tar_distance[tar_idx]) ||
                (!ascending && src_distance[src_idx] > tar_distance[tar_idx])) {
                id_buf[buf_idx] = src_ids[src_idx];
                dist_buf[buf_idx] = src_distance[src_idx];
                src_k++;
            } else {
                id_buf[buf_idx] = tar_ids[tar_idx];
                dist_buf[buf_idx] = tar_distance[tar_idx];
                tar_k++;
            }
            buf_k++;
        }

    result.swap(result_buf);
        if (buf_k < output_k) {
            if (src_k < src_input_k) {
                while (buf_k < output_k && src_k < src_input_k) {
                    src_idx = src_input_k_multi_i + src_k;
                    id_buf[buf_idx] = src_ids[src_idx];
                    dist_buf[buf_idx] = src_distance[src_idx];
                    src_k++;
                    buf_k++;
                }
            } else {
                while (buf_k < output_k && tar_k < tar_input_k) {
                    id_buf[buf_idx] = tar_ids[tar_idx];
                    dist_buf[buf_idx] = tar_distance[tar_idx];
                    tar_k++;
                    buf_k++;
                }
            }
        }
    }

    return Status::OK();
    tar_ids.swap(id_buf);
    tar_distance.swap(dist_buf);
    tar_input_k = output_k;
}

}  // namespace scheduler
+8 −5
Original line number Diff line number Diff line
@@ -38,9 +38,14 @@ class XSearchTask : public Task {
    Execute() override;

 public:
    static Status
    TopkResult(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, uint64_t input_k,
               uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result);
    static void
    MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance,
                         uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result);

    static void
    MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k,
                   const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k,
                   uint64_t nq, uint64_t topk, bool ascending);

 public:
    TableFileSchemaPtr file_;
@@ -49,8 +54,6 @@ class XSearchTask : public Task {
    int index_type_ = 0;
    ExecutionEnginePtr index_engine_ = nullptr;
    bool metric_l2 = true;

    static std::mutex merge_mutex_;
};

}  // namespace scheduler
+223 −77

File changed.

Preview size limit exceeded, changes collapsed.