Commit 0bbe9d16 authored by 蔡宇东
Browse files

MS-606 optimize reduce API, update unittest


Former-commit-id: b41935d10774bff17ec175e69d04b288d773cc9a
parent 818abfab
Loading
Loading
Loading
Loading
+20 −27
Original line number Diff line number Diff line
@@ -155,8 +155,8 @@ XSearchTask::Load(LoadType type, uint8_t device_id) {

    size_t file_size = index_engine_->PhysicalSize();

    std::string info = "Load file id:" + std::to_string(file_->id_) +
                       " file type:" + std::to_string(file_->file_type_) + " size:" + std::to_string(file_size) +
    std::string info = "Load file id:" + std::to_string(file_->id_) + " file type:" +
                       std::to_string(file_->file_type_) + " size:" + std::to_string(file_size) +
                       " bytes from location: " + file_->location_ + " totally cost";
    double span = rc.ElapseFromBegin(info);
    //    for (auto &context : search_contexts_) {
@@ -209,7 +209,8 @@ XSearchTask::Execute() {

            // step 3: pick up topk result
            auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk;
            XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult());
            XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2,
                                              search_job->GetResult());

            span = rc.RecordSection(hdr + ", reduce topk");
            //            search_job->AccumReduceCost(span);
@@ -229,12 +230,8 @@ XSearchTask::Execute() {
}

void
XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
                                  const std::vector<float>& input_distance,
                                  uint64_t input_k,
                                  uint64_t nq,
                                  uint64_t topk,
                                  bool ascending,
XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance,
                                  uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending,
                                  scheduler::ResultSet& result) {
    if (result.empty()) {
        result.resize(nq);
@@ -301,16 +298,12 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
}

void
XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids,
                            std::vector<float>& tar_distance,
                            uint64_t& tar_input_k,
                            const std::vector<int64_t>& src_ids,
                            const std::vector<float>& src_distance,
                            uint64_t src_input_k,
                            uint64_t nq,
                            uint64_t topk,
                            bool ascending) {
    if (src_ids.empty() || src_distance.empty()) return;
XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k,
                            const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance,
                            uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) {
    if (src_ids.empty() || src_distance.empty()) {
        return;
    }

    std::vector<int64_t> id_buf(nq * topk, -1);
    std::vector<float> dist_buf(nq * topk, 0.0);
+5 −16
Original line number Diff line number Diff line
@@ -39,24 +39,13 @@ class XSearchTask : public Task {

 public:
    static void
    MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
                         const std::vector<float>& input_distance,
                         uint64_t input_k,
                         uint64_t nq,
                         uint64_t topk,
                         bool ascending,
                         scheduler::ResultSet& result);
    MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance,
                         uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result);

    static void
    MergeTopkArray(std::vector<int64_t>& tar_ids,
                   std::vector<float>& tar_distance,
                   uint64_t& tar_input_k,
                   const std::vector<int64_t>& src_ids,
                   const std::vector<float>& src_distance,
                   uint64_t src_input_k,
                   uint64_t nq,
                   uint64_t topk,
                   bool ascending);
    MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k,
                   const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k,
                   uint64_t nq, uint64_t topk, bool ascending);

 public:
    TableFileSchemaPtr file_;
+188 −110
Original line number Diff line number Diff line
@@ -28,20 +28,44 @@ namespace {
namespace ms = milvus::scheduler;

// Test helper: fills the two output vectors with nq * topk synthetic search results.
//
// NOTE(review): this span interleaves the removed and the added lines of the diff
// (two parameter lists, and both the misspelled `output_distence` and the corrected
// `output_distance`); the description below holds for either version.
//
// Both vectors are cleared and resized to nq * topk, laid out row-major with one row
// of `topk` entries per query (index = i * topk + j).
//   - ids:       (int64_t)(drand48() * 100000), i.e. pseudo-random values.
//   - distances: `j + drand48()` when `ascending` is true, `(topk - j) + drand48()`
//                otherwise, so each row trends upward or downward with its column index.
//                (presumably drand48() yields values in [0, 1), which would make each
//                row strictly ordered before float rounding - confirm against POSIX.)
void
BuildResult(uint64_t nq,
BuildResult(std::vector<int64_t>& output_ids,
            std::vector<float>& output_distance,
            uint64_t topk,
            bool ascending,
            std::vector<int64_t>& output_ids,
            std::vector<float>& output_distence) {
            uint64_t nq,
            bool ascending) {
    output_ids.clear();
    output_ids.resize(nq * topk);
    output_distence.clear();
    output_distence.resize(nq * topk);
    output_distance.clear();
    output_distance.resize(nq * topk);

    for (uint64_t i = 0; i < nq; i++) {
        for (uint64_t j = 0; j < topk; j++) {
            output_ids[i * topk + j] = (int64_t)(drand48() * 100000);
            output_distence[i * topk + j] = ascending ? (j + drand48()) : ((topk - j) + drand48());
            output_distance[i * topk + j] = ascending ? (j + drand48()) : ((topk - j) + drand48());
        }
    }
}

/**
 * Test helper: copies the first `output_topk` entries of every query row from a
 * wider result buffer into a narrower one.
 *
 * Both buffers are row-major with one row per query: the inputs hold `input_topk`
 * entries per row, the outputs are resized to hold `output_topk` entries per row.
 *
 * @param output_ids       destination ids; cleared and resized to nq * output_topk
 * @param output_distance  destination distances; cleared and resized to nq * output_topk
 * @param output_topk      entries to keep per query row; must be <= input_topk
 * @param input_ids        source ids; at least nq * input_topk elements (read-only)
 * @param input_distance   source distances; at least nq * input_topk elements (read-only)
 * @param input_topk       row width of the source buffers
 * @param nq               number of query rows
 *
 * Preconditions are checked with ASSERT_TRUE. NOTE(review): a failed gtest ASSERT_*
 * presumably returns only from this helper, not from the calling test - confirm and
 * check HasFatalFailure() in the caller if that matters.
 */
void
CopyResult(std::vector<int64_t>& output_ids,
           std::vector<float>& output_distance,
           uint64_t output_topk,
           const std::vector<int64_t>& input_ids,
           const std::vector<float>& input_distance,
           uint64_t input_topk,
           uint64_t nq) {
    ASSERT_TRUE(input_ids.size() >= nq * input_topk);
    ASSERT_TRUE(input_distance.size() >= nq * input_topk);
    ASSERT_TRUE(output_topk <= input_topk);
    // The outputs are cleared before the inputs are read, so source and destination
    // must be distinct objects or the data would be wiped before it is copied.
    ASSERT_TRUE(&output_ids != &input_ids);
    ASSERT_TRUE(&output_distance != &input_distance);

    output_ids.clear();
    output_ids.resize(nq * output_topk);
    output_distance.clear();
    output_distance.resize(nq * output_topk);

    for (uint64_t i = 0; i < nq; i++) {
        // Row i starts at a different offset in each buffer because the row widths differ.
        const uint64_t src_offset = i * input_topk;
        const uint64_t dst_offset = i * output_topk;
        for (uint64_t j = 0; j < output_topk; j++) {
            output_ids[dst_offset + j] = input_ids[src_offset + j];
            output_distance[dst_offset + j] = input_distance[src_offset + j];
        }
    }
}
@@ -51,8 +75,8 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1,
                const std::vector<float>& input_distance_1,
                const std::vector<int64_t>& input_ids_2,
                const std::vector<float>& input_distance_2,
                uint64_t nq,
                uint64_t topk,
                uint64_t nq,
                bool ascending,
                const milvus::scheduler::ResultSet& result) {
    ASSERT_EQ(result.size(), nq);
@@ -96,32 +120,32 @@ TEST(DBSearchTest, TOPK_TEST) {

    /* test1, id1/dist1 valid, id2/dist2 empty */
    ascending = true;
    BuildResult(NQ, TOP_K, ascending, ids1, dist1);
    BuildResult(ids1, dist1, TOP_K, NQ, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);

    /* test2, id1/dist1 valid, id2/dist2 valid */
    BuildResult(NQ, TOP_K, ascending, ids2, dist2);
    BuildResult(ids2, dist2, TOP_K, NQ, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);

    /* test3, id1/dist1 small topk */
    ids1.clear();
    dist1.clear();
    result.clear();
    BuildResult(NQ, TOP_K/2, ascending, ids1, dist1);
    BuildResult(ids1, dist1, TOP_K/2, NQ, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);

    /* test4, id1/dist1 small topk, id2/dist2 small topk */
    ids2.clear();
    dist2.clear();
    result.clear();
    BuildResult(NQ, TOP_K/3, ascending, ids2, dist2);
    BuildResult(ids2, dist2, TOP_K/3, NQ, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);

/////////////////////////////////////////////////////////////////////////////////////////
    ascending = false;
@@ -132,74 +156,103 @@ TEST(DBSearchTest, TOPK_TEST) {
    result.clear();

    /* test1, id1/dist1 valid, id2/dist2 empty */
    BuildResult(NQ, TOP_K, ascending, ids1, dist1);
    BuildResult(ids1, dist1, TOP_K, NQ, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);

    /* test2, id1/dist1 valid, id2/dist2 valid */
    BuildResult(NQ, TOP_K, ascending, ids2, dist2);
    BuildResult(ids2, dist2, TOP_K, NQ, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);

    /* test3, id1/dist1 small topk */
    ids1.clear();
    dist1.clear();
    result.clear();
    BuildResult(NQ, TOP_K/2, ascending, ids1, dist1);
    BuildResult(ids1, dist1, TOP_K/2, NQ, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);

    /* test4, id1/dist1 small topk, id2/dist2 small topk */
    ids2.clear();
    dist2.clear();
    result.clear();
    BuildResult(NQ, TOP_K/3, ascending, ids2, dist2);
    BuildResult(ids2, dist2, TOP_K/3, NQ, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
}

TEST(DBSearchTest, REDUCE_PERF_TEST) {
    int32_t nq = 100;
    int32_t top_k = 1000;
    int32_t index_file_num = 478;   /* sift1B dataset, index files num */
    bool ascending = true;

    std::vector<int32_t> thread_vec = {4, 8, 11};
    std::vector<int32_t> nq_vec = {1, 10, 100, 1000};
    std::vector<int32_t> topk_vec = {1, 4, 16, 64, 256, 1024};
    int32_t NQ = nq_vec[nq_vec.size()-1];
    int32_t TOPK = topk_vec[topk_vec.size()-1];

    std::vector<std::vector<int64_t>> id_vec;
    std::vector<std::vector<float>> dist_vec;
    std::vector<uint64_t> k_vec;
    std::vector<int64_t> input_ids;
    std::vector<float> input_distance;
    ms::ResultSet final_result, final_result_2, final_result_3;

    int32_t i, k, step;
    double reduce_cost = 0.0;
    milvus::TimeRecorder rc("");

    /* generate testing data */
    for (i = 0; i < index_file_num; i++) {
        BuildResult(nq, top_k, ascending, input_ids, input_distance);
        BuildResult(input_ids, input_distance, TOPK, NQ, ascending);
        id_vec.push_back(input_ids);
        dist_vec.push_back(input_distance);
        k_vec.push_back(top_k);
    }

    rc.RecordSection("Method-1 result reduce start");
    for (int32_t max_thread_num : thread_vec) {
        milvus::ThreadPool threadPool(max_thread_num);
        std::list<std::future<void>> threads_list;

        for (int32_t nq : nq_vec) {
            for (int32_t top_k : topk_vec) {
                ms::ResultSet final_result, final_result_2, final_result_3;

                std::vector<std::vector<int64_t>> id_vec_1(index_file_num);
                std::vector<std::vector<float>> dist_vec_1(index_file_num);
                for (i = 0; i < index_file_num; i++) {
                    CopyResult(id_vec_1[i], dist_vec_1[i], top_k, id_vec[i], dist_vec[i], TOPK, nq);
                }

                std::string str1 = "Method-1 " + std::to_string(max_thread_num) + " " +
                                    std::to_string(nq) + " " + std::to_string(top_k);
                milvus::TimeRecorder rc1(str1);

                ///////////////////////////////////////////////////////////////////////////////////////
                /* method-1 */
                for (i = 0; i < index_file_num; i++) {
        ms::XSearchTask::MergeTopkToResultSet(id_vec[i], dist_vec[i], k_vec[i], nq, top_k, ascending, final_result);
                    ms::XSearchTask::MergeTopkToResultSet(id_vec_1[i],
                                                          dist_vec_1[i],
                                                          top_k,
                                                          nq,
                                                          top_k,
                                                          ascending,
                                                          final_result);
                    ASSERT_EQ(final_result.size(), nq);
                }

    reduce_cost = rc.RecordSection("Method-1 result reduce done");
    std::cout << "Method-1: total reduce time " << reduce_cost/1000 << " ms" << std::endl;
                rc1.RecordSection("reduce done");

                ///////////////////////////////////////////////////////////////////////////////////////
                /* method-2 */
    std::vector<std::vector<int64_t>> id_vec_2(id_vec);
    std::vector<std::vector<float>> dist_vec_2(dist_vec);
    std::vector<uint64_t> k_vec_2(k_vec);
                std::vector<std::vector<int64_t>> id_vec_2(index_file_num);
                std::vector<std::vector<float>> dist_vec_2(index_file_num);
                std::vector<uint64_t> k_vec_2(index_file_num);
                for (i = 0; i < index_file_num; i++) {
                    CopyResult(id_vec_2[i], dist_vec_2[i], top_k, id_vec[i], dist_vec[i], TOPK, nq);
                    k_vec_2[i] = top_k;
                }

    rc.RecordSection("Method-2 result reduce start");
                std::string str2 = "Method-2 " + std::to_string(max_thread_num) + " " +
                                    std::to_string(nq) + " " + std::to_string(top_k);
                milvus::TimeRecorder rc2(str2);

                for (step = 1; step < index_file_num; step *= 2) {
                    for (i = 0; i + step < index_file_num; i += step * 2) {
@@ -208,38 +261,55 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) {
                                                        nq, top_k, ascending);
                    }
                }
    ms::XSearchTask::MergeTopkToResultSet(id_vec_2[0], dist_vec_2[0], k_vec_2[0], nq, top_k, ascending, final_result_2);
                ms::XSearchTask::MergeTopkToResultSet(id_vec_2[0],
                                                      dist_vec_2[0],
                                                      k_vec_2[0],
                                                      nq,
                                                      top_k,
                                                      ascending,
                                                      final_result_2);
                ASSERT_EQ(final_result_2.size(), nq);

    reduce_cost = rc.RecordSection("Method-2 result reduce done");
    std::cout << "Method-2: total reduce time " << reduce_cost/1000 << " ms" << std::endl;
                rc2.RecordSection("reduce done");

                for (i = 0; i < nq; i++) {
                    ASSERT_EQ(final_result[i].size(), final_result_2[i].size());
        for (k = 0; k < final_result.size(); k++) {
                    for (k = 0; k < final_result[i].size(); k++) {
                        if (final_result[i][k].first != final_result_2[i][k].first) {
                            std::cout << i << " " << k << std::endl;
                        }
                        ASSERT_EQ(final_result[i][k].first, final_result_2[i][k].first);
                        ASSERT_EQ(final_result[i][k].second, final_result_2[i][k].second);
                    }
                }

                ///////////////////////////////////////////////////////////////////////////////////////
                /* method-3 parallel */
    std::vector<std::vector<int64_t>> id_vec_3(id_vec);
    std::vector<std::vector<float>> dist_vec_3(dist_vec);
    std::vector<uint64_t> k_vec_3(k_vec);

    uint32_t max_thread_count = std::min(std::thread::hardware_concurrency() - 1, (uint32_t)MAX_THREADS_NUM);
    milvus::ThreadPool threadPool(max_thread_count);
    std::list<std::future<void>> threads_list;
                std::vector<std::vector<int64_t>> id_vec_3(index_file_num);
                std::vector<std::vector<float>> dist_vec_3(index_file_num);
                std::vector<uint64_t> k_vec_3(index_file_num);
                for (i = 0; i < index_file_num; i++) {
                    CopyResult(id_vec_3[i], dist_vec_3[i], top_k, id_vec[i], dist_vec[i], TOPK, nq);
                    k_vec_3[i] = top_k;
                }

    rc.RecordSection("Method-3 parallel result reduce start");
                std::string str3 = "Method-3 " + std::to_string(max_thread_num) + " " +
                                    std::to_string(nq) + " " + std::to_string(top_k);
                milvus::TimeRecorder rc3(str3);

                for (step = 1; step < index_file_num; step *= 2) {
                    for (i = 0; i + step < index_file_num; i += step * 2) {
                        threads_list.push_back(
                            threadPool.enqueue(ms::XSearchTask::MergeTopkArray,
                                   std::ref(id_vec_3[i]), std::ref(dist_vec_3[i]), std::ref(k_vec_3[i]),
                                   std::ref(id_vec_3[i+step]), std::ref(dist_vec_3[i+step]), std::ref(k_vec_3[i+step]),
                                   nq, top_k, ascending));
                                               std::ref(id_vec_3[i]),
                                               std::ref(dist_vec_3[i]),
                                               std::ref(k_vec_3[i]),
                                               std::ref(id_vec_3[i + step]),
                                               std::ref(dist_vec_3[i + step]),
                                               std::ref(k_vec_3[i + step]),
                                               nq,
                                               top_k,
                                               ascending));
                    }

                    while (threads_list.size() > 0) {
@@ -260,17 +330,25 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) {
                        }
                    }
                }
    ms::XSearchTask::MergeTopkToResultSet(id_vec_3[0], dist_vec_3[0], k_vec_3[0], nq, top_k, ascending, final_result_3);
                ms::XSearchTask::MergeTopkToResultSet(id_vec_3[0],
                                                      dist_vec_3[0],
                                                      k_vec_3[0],
                                                      nq,
                                                      top_k,
                                                      ascending,
                                                      final_result_3);
                ASSERT_EQ(final_result_3.size(), nq);

    reduce_cost = rc.RecordSection("Method-3 parallel result reduce done");
    std::cout << "Method-3 parallel: total reduce time " << reduce_cost/1000 << " ms" << std::endl;
                rc3.RecordSection("reduce done");

                for (i = 0; i < nq; i++) {
                    ASSERT_EQ(final_result[i].size(), final_result_3[i].size());
        for (k = 0; k < final_result.size(); k++) {
                    for (k = 0; k < final_result[i].size(); k++) {
                        ASSERT_EQ(final_result[i][k].first, final_result_3[i][k].first);
                        ASSERT_EQ(final_result[i][k].second, final_result_3[i][k].second);
                    }
                }
            }
        }
    }
}