Commit 6d1e1578 authored by 蔡宇东's avatar 蔡宇东
Browse files

MS-606 speed up result reduce


Former-commit-id: 3414caf6afa687d79637890dd0f34c4d6c6dcd03
parent e884fdb0
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -14,15 +14,16 @@ Please mark all change in change log and use the ticket from JIRA.
## Improvement
- MS-552 - Add and change the easylogging library
- MS-553 - Refine cache code
- MS-557 - Merge Log.h
- MS-555 - Remove old scheduler
- MS-556 - Add Job Definition in Scheduler
- MS-557 - Merge Log.h
- MS-558 - Refine status code
- MS-562 - Add JobMgr and TaskCreator in Scheduler
- MS-566 - Refactor cmake
- MS-555 - Remove old scheduler
- MS-574 - Milvus configuration refactor
- MS-578 - Make sure milvus5.0 doesn't break 0.3.1 data
- MS-585 - Update namespace in scheduler
- MS-606 - Speed up result reduce
- MS-608 - Update TODO names
- MS-609 - Update task construct function

+3 −2
Original line number Diff line number Diff line
@@ -37,8 +37,9 @@ namespace scheduler {
using engine::meta::TableFileSchemaPtr;

using Id2IndexMap = std::unordered_map<size_t, TableFileSchemaPtr>;
using Id2DistanceMap = std::vector<std::pair<int64_t, double>>;
using ResultSet = std::vector<Id2DistanceMap>;
using IdDistPair = std::pair<int64_t, double>;
using Id2DistVec = std::vector<IdDistPair>;
using ResultSet = std::vector<Id2DistVec>;

class SearchJob : public Job {
 public:
+70 −143
Original line number Diff line number Diff line
@@ -78,18 +78,19 @@ std::mutex XSearchTask::merge_mutex_;

void
CollectFileMetrics(int file_type, size_t file_size) {
    server::MetricsBase& inst = server::Metrics::GetInstance();
    switch (file_type) {
        case TableFileSchema::RAW:
        case TableFileSchema::TO_INDEX: {
            server::Metrics::GetInstance().RawFileSizeHistogramObserve(file_size);
            server::Metrics::GetInstance().RawFileSizeTotalIncrement(file_size);
            server::Metrics::GetInstance().RawFileSizeGaugeSet(file_size);
            inst.RawFileSizeHistogramObserve(file_size);
            inst.RawFileSizeTotalIncrement(file_size);
            inst.RawFileSizeGaugeSet(file_size);
            break;
        }
        default: {
            server::Metrics::GetInstance().IndexFileSizeHistogramObserve(file_size);
            server::Metrics::GetInstance().IndexFileSizeTotalIncrement(file_size);
            server::Metrics::GetInstance().IndexFileSizeGaugeSet(file_size);
            inst.IndexFileSizeHistogramObserve(file_size);
            inst.IndexFileSizeTotalIncrement(file_size);
            inst.IndexFileSizeGaugeSet(file_size);
            break;
        }
    }
@@ -206,16 +207,9 @@ XSearchTask::Execute() {
            double span = rc.RecordSection(hdr + ", do search");
            //            search_job->AccumSearchCost(span);

            // step 3: cluster result
            scheduler::ResultSet result_set;
            // step 3: pick up topk result
            auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk;
            XSearchTask::ClusterResult(output_ids, output_distance, nq, spec_k, result_set);

            span = rc.RecordSection(hdr + ", cluster result");
            //            search_job->AccumReduceCost(span);

            // step 4: pick up topk result
            XSearchTask::TopkResult(result_set, topk, metric_l2, search_job->GetResult());
            XSearchTask::TopkResult(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult());

            span = rc.RecordSection(hdr + ", reduce topk");
            //            search_job->AccumReduceCost(span);
@@ -235,142 +229,75 @@ XSearchTask::Execute() {
}

Status
XSearchTask::ClusterResult(const std::vector<int64_t>& output_ids, const std::vector<float>& output_distance,
                           uint64_t nq, uint64_t topk, scheduler::ResultSet& result_set) {
    if (output_ids.size() < nq * topk || output_distance.size() < nq * topk) {
        std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) + " distance array size: " +
                          std::to_string(output_distance.size());
        ENGINE_LOG_ERROR << msg;
        return Status(DB_ERROR, msg);
    }

    result_set.clear();
    result_set.resize(nq);

    std::function<void(size_t, size_t)> reduce_worker = [&](size_t from_index, size_t to_index) {
        for (auto i = from_index; i < to_index; i++) {
            scheduler::Id2DistanceMap id_distance;
            id_distance.reserve(topk);
            for (auto k = 0; k < topk; k++) {
                uint64_t index = i * topk + k;
                if (output_ids[index] < 0) {
                    continue;
                }
                id_distance.push_back(std::make_pair(output_ids[index], output_distance[index]));
            }
            result_set[i] = id_distance;
        }
    };

    //    if (NeedParallelReduce(nq, topk)) {
    //        ParallelReduce(reduce_worker, nq);
    //    } else {
    reduce_worker(0, nq);
    //    }

    return Status::OK();
}

Status
XSearchTask::MergeResult(scheduler::Id2DistanceMap& distance_src, scheduler::Id2DistanceMap& distance_target,
                         uint64_t topk, bool ascending) {
    // Note: the score_src and score_target are already arranged by score in ascending order
    if (distance_src.empty()) {
        ENGINE_LOG_WARNING << "Empty distance source array";
        return Status::OK();
    }

    std::unique_lock<std::mutex> lock(merge_mutex_);
    if (distance_target.empty()) {
        distance_target.swap(distance_src);
        return Status::OK();
    }

    size_t src_count = distance_src.size();
    size_t target_count = distance_target.size();
    scheduler::Id2DistanceMap distance_merged;
    distance_merged.reserve(topk);
    size_t src_index = 0, target_index = 0;
    while (true) {
        // all score_src items are merged, if score_merged.size() still less than topk
        // move items from score_target to score_merged until score_merged.size() equal topk
        if (src_index >= src_count) {
            for (size_t i = target_index; i < target_count && distance_merged.size() < topk; ++i) {
                distance_merged.push_back(distance_target[i]);
            }
            break;
        }

        // all score_target items are merged, if score_merged.size() still less than topk
        // move items from score_src to score_merged until score_merged.size() equal topk
        if (target_index >= target_count) {
            for (size_t i = src_index; i < src_count && distance_merged.size() < topk; ++i) {
                distance_merged.push_back(distance_src[i]);
XSearchTask::TopkResult(const std::vector<long> &input_ids,
                        const std::vector<float> &input_distance,
                        uint64_t input_k,
                        uint64_t nq,
                        uint64_t topk,
                        bool ascending,
                        scheduler::ResultSet &result) {
    scheduler::ResultSet result_buf;

    if (result.empty()) {
        result_buf.resize(nq, scheduler::Id2DistVec(input_k, scheduler::IdDistPair(-1, 0.0)));
        for (auto i = 0; i < nq; ++i) {
            auto& result_buf_i = result_buf[i];
            uint64_t input_k_multi_i = input_k * i;
            for (auto k = 0; k < input_k; ++k) {
                uint64_t idx = input_k_multi_i + k;
                auto& result_buf_item = result_buf_i[k];
                result_buf_item.first = input_ids[idx];
                result_buf_item.second = input_distance[idx];
            }
            break;
        }

        // compare score,
        // if ascending = true, put smallest score to score_merged one by one
        // else, put largest score to score_merged one by one
        auto& src_pair = distance_src[src_index];
        auto& target_pair = distance_target[target_index];
        if (ascending) {
            if (src_pair.second > target_pair.second) {
                distance_merged.push_back(target_pair);
                target_index++;
    } else {
                distance_merged.push_back(src_pair);
                src_index++;
            }
        size_t tar_size = result[0].size();
        uint64_t output_k = std::min(topk, input_k + tar_size);
        result_buf.resize(nq, scheduler::Id2DistVec(output_k, scheduler::IdDistPair(-1, 0.0)));
        for (auto i = 0; i < nq; ++i) {
            size_t buf_k = 0, src_k = 0, tar_k = 0;
            uint64_t src_idx;
            auto& result_i = result[i];
            auto& result_buf_i = result_buf[i];
            uint64_t input_k_multi_i = input_k * i;
            while (buf_k < output_k && src_k < input_k && tar_k < tar_size) {
                src_idx = input_k_multi_i + src_k;
                auto& result_buf_item = result_buf_i[buf_k];
                auto& result_item = result_i[tar_k];
                if ((ascending && input_distance[src_idx] < result_item.second) ||
                   (!ascending && input_distance[src_idx] > result_item.second)) {
                    result_buf_item.first = input_ids[src_idx];
                    result_buf_item.second = input_distance[src_idx];
                    src_k++;
                } else {
            if (src_pair.second < target_pair.second) {
                distance_merged.push_back(target_pair);
                target_index++;
            } else {
                distance_merged.push_back(src_pair);
                src_index++;
                    result_buf_item = result_item;
                    tar_k++;
                }
                buf_k++;
            }

        // score_merged.size() already equal topk
        if (distance_merged.size() >= topk) {
            break;
            if (buf_k < topk) {
                if (src_k < input_k) {
                    while (buf_k < output_k && src_k < input_k) {
                        src_idx = input_k_multi_i + src_k;
                        auto& result_buf_item = result_buf_i[buf_k];
                        result_buf_item.first = input_ids[src_idx];
                        result_buf_item.second = input_distance[src_idx];
                        src_k++;
                        buf_k++;
                    }
                } else {
                    while (buf_k < output_k && tar_k < tar_size) {
                        result_buf_i[buf_k] = result_i[tar_k];
                        tar_k++;
                        buf_k++;
                    }

    distance_target.swap(distance_merged);

    return Status::OK();
                }

Status
XSearchTask::TopkResult(scheduler::ResultSet& result_src, uint64_t topk, bool ascending,
                        scheduler::ResultSet& result_target) {
    if (result_target.empty()) {
        result_target.swap(result_src);
        return Status::OK();
            }

    if (result_src.size() != result_target.size()) {
        std::string msg = "Invalid result set size";
        ENGINE_LOG_ERROR << msg;
        return Status(DB_ERROR, msg);
        }

    std::function<void(size_t, size_t)> ReduceWorker = [&](size_t from_index, size_t to_index) {
        for (size_t i = from_index; i < to_index; i++) {
            scheduler::Id2DistanceMap& score_src = result_src[i];
            scheduler::Id2DistanceMap& score_target = result_target[i];
            XSearchTask::MergeResult(score_src, score_target, topk, ascending);
    }
    };

    //    if (NeedParallelReduce(result_src.size(), topk)) {
    //        ParallelReduce(ReduceWorker, result_src.size());
    //    } else {
    ReduceWorker(0, result_src.size());
    //    }
    result.swap(result_buf);

    return Status::OK();
}
+7 −9
Original line number Diff line number Diff line
@@ -39,15 +39,13 @@ class XSearchTask : public Task {

 public:
    static Status
    ClusterResult(const std::vector<int64_t>& output_ids, const std::vector<float>& output_distence, uint64_t nq,
                  uint64_t topk, scheduler::ResultSet& result_set);

    static Status
    MergeResult(scheduler::Id2DistanceMap& distance_src, scheduler::Id2DistanceMap& distance_target, uint64_t topk,
                bool ascending);

    static Status
    TopkResult(scheduler::ResultSet& result_src, uint64_t topk, bool ascending, scheduler::ResultSet& result_target);
    TopkResult(const std::vector<long> &input_ids,
               const std::vector<float> &input_distance,
               uint64_t input_k,
               uint64_t nq,
               uint64_t topk,
               bool ascending,
               scheduler::ResultSet &result);

 public:
    TableFileSchemaPtr file_;
+121 −226

File changed.

Preview size limit exceeded, changes collapsed.