Commit 818abfab authored by 蔡宇东's avatar 蔡宇东
Browse files

MS-606 support result reduce parallel


Former-commit-id: 337bb3ac7b48bc508aba2e93c5a978bda683ee55
parent 21840934
Loading
Loading
Loading
Loading
+96 −28
Original line number Diff line number Diff line
@@ -33,8 +33,6 @@ namespace scheduler {
static constexpr size_t PARALLEL_REDUCE_THRESHOLD = 10000;
static constexpr size_t PARALLEL_REDUCE_BATCH = 1000;

std::mutex XSearchTask::merge_mutex_;

// TODO(wxyu): remove unused code
// bool
// NeedParallelReduce(uint64_t nq, uint64_t topk) {
@@ -211,7 +209,7 @@ XSearchTask::Execute() {

            // step 3: pick up topk result
            auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk;
            XSearchTask::TopkResult(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult());
            XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult());

            span = rc.RecordSection(hdr + ", reduce topk");
            //            search_job->AccumReduceCost(span);
@@ -220,7 +218,7 @@ XSearchTask::Execute() {
            //            search_job->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed
        }

        // step 5: notify to send result to client
        // step 4: notify to send result to client
        search_job->SearchDone(index_id_);
    }

@@ -230,36 +228,41 @@ XSearchTask::Execute() {
    index_engine_ = nullptr;
}

Status
XSearchTask::TopkResult(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance,
                        uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result) {
    scheduler::ResultSet result_buf;

void
XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
                                  const std::vector<float>& input_distance,
                                  uint64_t input_k,
                                  uint64_t nq,
                                  uint64_t topk,
                                  bool ascending,
                                  scheduler::ResultSet& result) {
    if (result.empty()) {
        result_buf.resize(nq, scheduler::Id2DistVec(input_k, scheduler::IdDistPair(-1, 0.0)));
        for (auto i = 0; i < nq; ++i) {
            auto& result_buf_i = result_buf[i];
        result.resize(nq);
    }

    for (uint64_t i = 0; i < nq; i++) {
        scheduler::Id2DistVec result_buf;
        auto &result_i = result[i];

        if (result[i].empty()) {
            result_buf.resize(input_k, scheduler::IdDistPair(-1, 0.0));
            uint64_t input_k_multi_i = input_k * i;
            for (auto k = 0; k < input_k; ++k) {
                uint64_t idx = input_k_multi_i + k;
                auto& result_buf_item = result_buf_i[k];
                auto &result_buf_item = result_buf[k];
                result_buf_item.first = input_ids[idx];
                result_buf_item.second = input_distance[idx];
            }
        }
        } else {
        size_t tar_size = result[0].size();
            size_t tar_size = result_i.size();
            uint64_t output_k = std::min(topk, input_k + tar_size);
        result_buf.resize(nq, scheduler::Id2DistVec(output_k, scheduler::IdDistPair(-1, 0.0)));
        for (auto i = 0; i < nq; ++i) {
            result_buf.resize(output_k, scheduler::IdDistPair(-1, 0.0));
            size_t buf_k = 0, src_k = 0, tar_k = 0;
            uint64_t src_idx;
            auto& result_i = result[i];
            auto& result_buf_i = result_buf[i];
            uint64_t input_k_multi_i = input_k * i;
            while (buf_k < output_k && src_k < input_k && tar_k < tar_size) {
                src_idx = input_k_multi_i + src_k;
                auto& result_buf_item = result_buf_i[buf_k];
                auto &result_buf_item = result_buf[buf_k];
                auto &result_item = result_i[tar_k];
                if ((ascending && input_distance[src_idx] < result_item.second) ||
                    (!ascending && input_distance[src_idx] > result_item.second)) {
@@ -273,11 +276,11 @@ XSearchTask::TopkResult(const std::vector<int64_t>& input_ids, const std::vector
                buf_k++;
            }

            if (buf_k < topk) {
            if (buf_k < output_k) {
                if (src_k < input_k) {
                    while (buf_k < output_k && src_k < input_k) {
                        src_idx = input_k_multi_i + src_k;
                        auto& result_buf_item = result_buf_i[buf_k];
                        auto &result_buf_item = result_buf[buf_k];
                        result_buf_item.first = input_ids[src_idx];
                        result_buf_item.second = input_distance[src_idx];
                        src_k++;
@@ -285,18 +288,83 @@ XSearchTask::TopkResult(const std::vector<int64_t>& input_ids, const std::vector
                    }
                } else {
                    while (buf_k < output_k && tar_k < tar_size) {
                        result_buf_i[buf_k] = result_i[tar_k];
                        result_buf[buf_k] = result_i[tar_k];
                        tar_k++;
                        buf_k++;
                    }
                }
            }
        }

        result_i.swap(result_buf);
    }
}

    result.swap(result_buf);
// Merge the per-query candidate lists of (src_ids, src_distance) into
// (tar_ids, tar_distance), keeping at most `topk` entries per query.
//
// Layout: both inputs are flat arrays holding `nq` consecutive rows; row i of
// the source occupies [src_input_k * i, src_input_k * (i + 1)) and row i of the
// target occupies [tar_input_k * i, tar_input_k * (i + 1)).
//
// Each row is merged with a two-way merge that always takes the better head
// element (smaller distance when `ascending`, larger otherwise), so the rows
// are expected to be ordered best-first already; nothing is sorted here.
//
// On return:
//   - tar_input_k == min(topk, old tar_input_k + src_input_k) (the new row stride),
//   - row i of the merged result lives at [tar_input_k * i, tar_input_k * (i + 1)),
//   - tar_ids / tar_distance have nq * topk elements; slots beyond the merged
//     rows keep the fill values (-1 / 0.0).
// If either source array is empty the target is left untouched.
void
XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids,
                            std::vector<float>& tar_distance,
                            uint64_t& tar_input_k,
                            const std::vector<int64_t>& src_ids,
                            const std::vector<float>& src_distance,
                            uint64_t src_input_k,
                            uint64_t nq,
                            uint64_t topk,
                            bool ascending) {
    if (src_ids.empty() || src_distance.empty()) return;

    // Scratch buffers sized for the upper bound (topk per query); they replace
    // the target arrays once every row has been merged.
    std::vector<int64_t> id_buf(nq*topk, -1);
    std::vector<float> dist_buf(nq*topk, 0.0);

    const uint64_t output_k = std::min(topk, tar_input_k + src_input_k);

    for (uint64_t i = 0; i < nq; i++) {
        const uint64_t src_base = src_input_k * i;
        const uint64_t tar_base = tar_input_k * i;
        const uint64_t buf_base = output_k * i;
        uint64_t buf_k = 0, src_k = 0, tar_k = 0;

        // Two-way merge while both rows still have candidates.
        while (buf_k < output_k && src_k < src_input_k && tar_k < tar_input_k) {
            const uint64_t src_idx = src_base + src_k;
            const uint64_t tar_idx = tar_base + tar_k;
            const uint64_t buf_idx = buf_base + buf_k;
            if ((ascending && src_distance[src_idx] < tar_distance[tar_idx]) ||
                (!ascending && src_distance[src_idx] > tar_distance[tar_idx])) {
                id_buf[buf_idx] = src_ids[src_idx];
                dist_buf[buf_idx] = src_distance[src_idx];
                src_k++;
            } else {
                // Ties keep the element that is already in the target.
                id_buf[buf_idx] = tar_ids[tar_idx];
                dist_buf[buf_idx] = tar_distance[tar_idx];
                tar_k++;
            }
            buf_k++;
        }

        // At most one of the two rows still has elements here. The indices are
        // recomputed on every iteration: reusing the ones from the merge loop
        // would write every leftover element into the same slot (or through an
        // uninitialized index when the merge loop never ran).
        while (buf_k < output_k && src_k < src_input_k) {
            const uint64_t src_idx = src_base + src_k;
            const uint64_t buf_idx = buf_base + buf_k;
            id_buf[buf_idx] = src_ids[src_idx];
            dist_buf[buf_idx] = src_distance[src_idx];
            src_k++;
            buf_k++;
        }

        while (buf_k < output_k && tar_k < tar_input_k) {
            const uint64_t tar_idx = tar_base + tar_k;
            const uint64_t buf_idx = buf_base + buf_k;
            id_buf[buf_idx] = tar_ids[tar_idx];
            dist_buf[buf_idx] = tar_distance[tar_idx];
            tar_k++;
            buf_k++;
        }
    }

    tar_ids.swap(id_buf);
    tar_distance.swap(dist_buf);
    tar_input_k = output_k;
}

}  // namespace scheduler
+19 −5
Original line number Diff line number Diff line
@@ -38,9 +38,25 @@ class XSearchTask : public Task {
    Execute() override;

 public:
    static Status
    TopkResult(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, uint64_t input_k,
               uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result);
    // Fold one batch of search output into the accumulated per-query result set.
    //
    // input_ids / input_distance are flat arrays indexed as [input_k * i + k]
    // for query i. If `result` is empty it is first resized to `nq` rows. For
    // each query, an empty row is filled with the `input_k` incoming entries;
    // a non-empty row is merged with them into a row of
    // min(topk, input_k + existing row size) entries, preferring smaller
    // distances when `ascending` is true and larger ones otherwise.
    // NOTE(review): the merge appears to rely on both sides already being
    // ordered best-first - confirm against the callers.
    static void
    MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
                         const std::vector<float>& input_distance,
                         uint64_t input_k,
                         uint64_t nq,
                         uint64_t topk,
                         bool ascending,
                         scheduler::ResultSet& result);

    // Merge the flat per-query candidate arrays (src_ids, src_distance) into
    // (tar_ids, tar_distance), keeping at most `topk` entries per query.
    //
    // Row i of the source starts at src_input_k * i and row i of the target at
    // tar_input_k * i. Smaller distances win when `ascending` is true, larger
    // ones otherwise. The call is a no-op when either source array is empty.
    // Otherwise the target arrays are replaced by buffers of nq * topk
    // elements and `tar_input_k` is updated to
    // min(topk, tar_input_k + src_input_k), which is the row stride of the
    // merged data.
    static void
    MergeTopkArray(std::vector<int64_t>& tar_ids,
                   std::vector<float>& tar_distance,
                   uint64_t& tar_input_k,
                   const std::vector<int64_t>& src_ids,
                   const std::vector<float>& src_distance,
                   uint64_t src_input_k,
                   uint64_t nq,
                   uint64_t topk,
                   bool ascending);

 public:
    TableFileSchemaPtr file_;
@@ -49,8 +65,6 @@ class XSearchTask : public Task {
    int index_type_ = 0;
    ExecutionEnginePtr index_engine_ = nullptr;
    bool metric_l2 = true;

    static std::mutex merge_mutex_;
};

}  // namespace scheduler
+116 −48
Original line number Diff line number Diff line
@@ -21,6 +21,7 @@

#include "scheduler/task/SearchTask.h"
#include "utils/TimeRecorder.h"
#include "utils/ThreadPool.h"

namespace {

@@ -91,20 +92,17 @@ TEST(DBSearchTest, TOPK_TEST) {
    bool ascending;
    std::vector<int64_t> ids1, ids2;
    std::vector<float> dist1, dist2;
    milvus::scheduler::ResultSet result;
    milvus::Status status;
    ms::ResultSet result;

    /* test1, id1/dist1 valid, id2/dist2 empty */
    ascending = true;
    BuildResult(NQ, TOP_K, ascending, ids1, dist1);
    status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
    ASSERT_TRUE(status.ok());
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);

    /* test2, id1/dist1 valid, id2/dist2 valid */
    BuildResult(NQ, TOP_K, ascending, ids2, dist2);
    status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
    ASSERT_TRUE(status.ok());
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);

    /* test3, id1/dist1 small topk */
@@ -112,10 +110,8 @@ TEST(DBSearchTest, TOPK_TEST) {
    dist1.clear();
    result.clear();
    BuildResult(NQ, TOP_K/2, ascending, ids1, dist1);
    status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result);
    ASSERT_TRUE(status.ok());
    status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
    ASSERT_TRUE(status.ok());
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);

    /* test4, id1/dist1 small topk, id2/dist2 small topk */
@@ -123,10 +119,8 @@ TEST(DBSearchTest, TOPK_TEST) {
    dist2.clear();
    result.clear();
    BuildResult(NQ, TOP_K/3, ascending, ids2, dist2);
    status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result);
    ASSERT_TRUE(status.ok());
    status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K / 3, NQ, TOP_K, ascending, result);
    ASSERT_TRUE(status.ok());
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);

/////////////////////////////////////////////////////////////////////////////////////////
@@ -139,14 +133,12 @@ TEST(DBSearchTest, TOPK_TEST) {

    /* test1, id1/dist1 valid, id2/dist2 empty */
    BuildResult(NQ, TOP_K, ascending, ids1, dist1);
    status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
    ASSERT_TRUE(status.ok());
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);

    /* test2, id1/dist1 valid, id2/dist2 valid */
    BuildResult(NQ, TOP_K, ascending, ids2, dist2);
    status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
    ASSERT_TRUE(status.ok());
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);

    /* test3, id1/dist1 small topk */
@@ -154,10 +146,8 @@ TEST(DBSearchTest, TOPK_TEST) {
    dist1.clear();
    result.clear();
    BuildResult(NQ, TOP_K/2, ascending, ids1, dist1);
    status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result);
    ASSERT_TRUE(status.ok());
    status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
    ASSERT_TRUE(status.ok());
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);

    /* test4, id1/dist1 small topk, id2/dist2 small topk */
@@ -165,10 +155,8 @@ TEST(DBSearchTest, TOPK_TEST) {
    dist2.clear();
    result.clear();
    BuildResult(NQ, TOP_K/3, ascending, ids2, dist2);
    status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result);
    ASSERT_TRUE(status.ok());
    status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K / 3, NQ, TOP_K, ascending, result);
    ASSERT_TRUE(status.ok());
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
}

@@ -177,32 +165,112 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) {
    int32_t top_k = 1000;
    int32_t index_file_num = 478;   /* sift1B dataset, index files num */
    bool ascending = true;
    std::vector<std::vector<int64_t>> id_vec;
    std::vector<std::vector<float>> dist_vec;
    std::vector<uint64_t> k_vec;
    std::vector<int64_t> input_ids;
    std::vector<float> input_distance;
    milvus::scheduler::ResultSet final_result;
    milvus::Status status;
    ms::ResultSet final_result, final_result_2, final_result_3;

    double span, reduce_cost = 0.0;
    int32_t i, k, step;
    double reduce_cost = 0.0;
    milvus::TimeRecorder rc("");

    for (int32_t i = 0; i < index_file_num; i++) {
    for (i = 0; i < index_file_num; i++) {
        BuildResult(nq, top_k, ascending, input_ids, input_distance);
        id_vec.push_back(input_ids);
        dist_vec.push_back(input_distance);
        k_vec.push_back(top_k);
    }

    rc.RecordSection("Method-1 result reduce start");

        rc.RecordSection("do search for context: " + std::to_string(i));

        // pick up topk result
        status = milvus::scheduler::XSearchTask::TopkResult(input_ids,
                                                            input_distance,
                                                            top_k,
                                                            nq,
                                                            top_k,
                                                            ascending,
                                                            final_result);
        ASSERT_TRUE(status.ok());
    /* method-1 */
    for (i = 0; i < index_file_num; i++) {
        ms::XSearchTask::MergeTopkToResultSet(id_vec[i], dist_vec[i], k_vec[i], nq, top_k, ascending, final_result);
        ASSERT_EQ(final_result.size(), nq);
    }

    reduce_cost = rc.RecordSection("Method-1 result reduce done");
    std::cout << "Method-1: total reduce time " << reduce_cost/1000 << " ms" << std::endl;

    /* method-2 */
    std::vector<std::vector<int64_t>> id_vec_2(id_vec);
    std::vector<std::vector<float>> dist_vec_2(dist_vec);
    std::vector<uint64_t> k_vec_2(k_vec);

    rc.RecordSection("Method-2 result reduce start");

    for (step = 1; step < index_file_num; step *= 2) {
        for (i = 0; i+step < index_file_num; i += step*2) {
            ms::XSearchTask::MergeTopkArray(id_vec_2[i], dist_vec_2[i], k_vec_2[i],
                                            id_vec_2[i+step], dist_vec_2[i+step], k_vec_2[i+step],
                                            nq, top_k, ascending);
        }
    }
    ms::XSearchTask::MergeTopkToResultSet(id_vec_2[0], dist_vec_2[0], k_vec_2[0], nq, top_k, ascending, final_result_2);
    ASSERT_EQ(final_result_2.size(), nq);

    reduce_cost = rc.RecordSection("Method-2 result reduce done");
    std::cout << "Method-2: total reduce time " << reduce_cost/1000 << " ms" << std::endl;

    for (i = 0; i < nq; i++) {
        ASSERT_EQ(final_result[i].size(), final_result_2[i].size());
        for (k = 0; k < final_result.size(); k++) {
            ASSERT_EQ(final_result[i][k].first, final_result_2[i][k].first);
            ASSERT_EQ(final_result[i][k].second, final_result_2[i][k].second);
        }
    }

    /* method-3 parallel */
    std::vector<std::vector<int64_t>> id_vec_3(id_vec);
    std::vector<std::vector<float>> dist_vec_3(dist_vec);
    std::vector<uint64_t> k_vec_3(k_vec);

        span = rc.RecordSection("reduce topk for context: " + std::to_string(i));
        reduce_cost += span;
    uint32_t max_thread_count = std::min(std::thread::hardware_concurrency() - 1, (uint32_t)MAX_THREADS_NUM);
    milvus::ThreadPool threadPool(max_thread_count);
    std::list<std::future<void>> threads_list;

    rc.RecordSection("Method-3 parallel result reduce start");

    for (step = 1; step < index_file_num; step *= 2) {
        for (i = 0; i+step < index_file_num; i += step*2) {
            threads_list.push_back(
                threadPool.enqueue(ms::XSearchTask::MergeTopkArray,
                                   std::ref(id_vec_3[i]), std::ref(dist_vec_3[i]), std::ref(k_vec_3[i]),
                                   std::ref(id_vec_3[i+step]), std::ref(dist_vec_3[i+step]), std::ref(k_vec_3[i+step]),
                                   nq, top_k, ascending));
        }

        while (threads_list.size() > 0) {
            int nready = 0;
            for (auto it = threads_list.begin(); it != threads_list.end(); it = it) {
                auto &p = *it;
                std::chrono::milliseconds span(0);
                if (p.wait_for(span) == std::future_status::ready) {
                    threads_list.erase(it++);
                    ++nready;
                } else {
                    ++it;
                }
            }

            if (nready == 0) {
                std::this_thread::yield();
            }
        }
    }
    ms::XSearchTask::MergeTopkToResultSet(id_vec_3[0], dist_vec_3[0], k_vec_3[0], nq, top_k, ascending, final_result_3);
    ASSERT_EQ(final_result_3.size(), nq);

    reduce_cost = rc.RecordSection("Method-3 parallel result reduce done");
    std::cout << "Method-3 parallel: total reduce time " << reduce_cost/1000 << " ms" << std::endl;

    for (i = 0; i < nq; i++) {
        ASSERT_EQ(final_result[i].size(), final_result_3[i].size());
        for (k = 0; k < final_result.size(); k++) {
            ASSERT_EQ(final_result[i][k].first, final_result_3[i][k].first);
            ASSERT_EQ(final_result[i][k].second, final_result_3[i][k].second);
        }
    }
    std::cout << "total reduce time: " << reduce_cost / 1000 << " ms" << std::endl;
}