Loading cpp/src/scheduler/task/SearchTask.cpp +96 −28 Original line number Diff line number Diff line Loading @@ -33,8 +33,6 @@ namespace scheduler { static constexpr size_t PARALLEL_REDUCE_THRESHOLD = 10000; static constexpr size_t PARALLEL_REDUCE_BATCH = 1000; std::mutex XSearchTask::merge_mutex_; // TODO(wxyu): remove unused code // bool // NeedParallelReduce(uint64_t nq, uint64_t topk) { Loading Loading @@ -211,7 +209,7 @@ XSearchTask::Execute() { // step 3: pick up topk result auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk; XSearchTask::TopkResult(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult()); XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult()); span = rc.RecordSection(hdr + ", reduce topk"); // search_job->AccumReduceCost(span); Loading @@ -220,7 +218,7 @@ XSearchTask::Execute() { // search_job->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed } // step 5: notify to send result to client // step 4: notify to send result to client search_job->SearchDone(index_id_); } Loading @@ -230,36 +228,41 @@ XSearchTask::Execute() { index_engine_ = nullptr; } Status XSearchTask::TopkResult(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result) { scheduler::ResultSet result_buf; void XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result) { if (result.empty()) { result_buf.resize(nq, scheduler::Id2DistVec(input_k, scheduler::IdDistPair(-1, 0.0))); for (auto i = 0; i < nq; ++i) { auto& result_buf_i = result_buf[i]; result.resize(nq); } for (uint64_t i = 0; i < nq; i++) { scheduler::Id2DistVec result_buf; auto &result_i = result[i]; if (result[i].empty()) { result_buf.resize(input_k, scheduler::IdDistPair(-1, 0.0)); uint64_t input_k_multi_i = input_k * i; for (auto k = 0; k < input_k; ++k) { uint64_t idx = input_k_multi_i + k; auto& result_buf_item = result_buf_i[k]; auto &result_buf_item = result_buf[k]; result_buf_item.first = input_ids[idx]; result_buf_item.second = input_distance[idx]; } } } else { size_t tar_size = result[0].size(); size_t tar_size = result_i.size(); uint64_t output_k = std::min(topk, input_k + tar_size); result_buf.resize(nq, scheduler::Id2DistVec(output_k, scheduler::IdDistPair(-1, 0.0))); for (auto i = 0; i < nq; ++i) { result_buf.resize(output_k, scheduler::IdDistPair(-1, 0.0)); size_t buf_k = 0, src_k = 0, tar_k = 0; uint64_t src_idx; auto& result_i = result[i]; auto& result_buf_i = result_buf[i]; uint64_t input_k_multi_i = input_k * i; while (buf_k < output_k && src_k < input_k && tar_k < tar_size) { src_idx = input_k_multi_i + src_k; auto& result_buf_item = result_buf_i[buf_k]; auto &result_buf_item = result_buf[buf_k]; auto &result_item = result_i[tar_k]; if ((ascending && input_distance[src_idx] < result_item.second) || (!ascending && input_distance[src_idx] > result_item.second)) { Loading @@ -273,11 +276,11 @@ XSearchTask::TopkResult(const std::vector<int64_t>& input_ids, const std::vector buf_k++; } if (buf_k < topk) { if (buf_k < output_k) { if (src_k < input_k) { while (buf_k < output_k && src_k < input_k) { src_idx = input_k_multi_i + src_k; auto& result_buf_item = result_buf_i[buf_k]; auto &result_buf_item = result_buf[buf_k]; result_buf_item.first = input_ids[src_idx]; result_buf_item.second = input_distance[src_idx]; src_k++; Loading @@ -285,18 +288,83 @@ XSearchTask::TopkResult(const std::vector<int64_t>& input_ids, const std::vector } } else { while (buf_k < output_k && tar_k < tar_size) { result_buf_i[buf_k] = result_i[tar_k]; result_buf[buf_k] = result_i[tar_k]; tar_k++; buf_k++; } } } } result_i.swap(result_buf); } } result.swap(result_buf); void XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) { if (src_ids.empty() || src_distance.empty()) return; std::vector<int64_t> id_buf(nq*topk, -1); std::vector<float> dist_buf(nq*topk, 0.0); uint64_t output_k = std::min(topk, tar_input_k + src_input_k); uint64_t buf_k, src_k, tar_k; uint64_t src_idx, tar_idx, buf_idx; uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i; for (uint64_t i = 0; i < nq; i++) { src_input_k_multi_i = src_input_k * i; tar_input_k_multi_i = tar_input_k * i; buf_k_multi_i = output_k * i; buf_k = src_k = tar_k = 0; while (buf_k < output_k && src_k < src_input_k && tar_k < tar_input_k) { src_idx = src_input_k_multi_i + src_k; tar_idx = tar_input_k_multi_i + tar_k; buf_idx = buf_k_multi_i + buf_k; if ((ascending && src_distance[src_idx] < tar_distance[tar_idx]) || (!ascending && src_distance[src_idx] > tar_distance[tar_idx])) { id_buf[buf_idx] = src_ids[src_idx]; dist_buf[buf_idx] = src_distance[src_idx]; src_k++; } else { id_buf[buf_idx] = tar_ids[tar_idx]; dist_buf[buf_idx] = tar_distance[tar_idx]; tar_k++; } buf_k++; } if (buf_k < output_k) { if (src_k < src_input_k) { while (buf_k < output_k && src_k < src_input_k) { src_idx = src_input_k_multi_i + src_k; id_buf[buf_idx] = src_ids[src_idx]; dist_buf[buf_idx] = src_distance[src_idx]; src_k++; buf_k++; } } else { while (buf_k < output_k && tar_k < tar_input_k) { id_buf[buf_idx] = tar_ids[tar_idx]; dist_buf[buf_idx] = tar_distance[tar_idx]; tar_k++; buf_k++; } } } } return Status::OK(); tar_ids.swap(id_buf); tar_distance.swap(dist_buf); tar_input_k = output_k; } } // namespace scheduler Loading cpp/src/scheduler/task/SearchTask.h +19 −5 Original line number Diff line number Diff line Loading @@ -38,9 +38,25 @@ class XSearchTask : public Task { Execute() override; public: static Status TopkResult(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result); static void MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result); static void MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending); public: TableFileSchemaPtr file_; Loading @@ -49,8 +65,6 @@ class XSearchTask : public Task { int index_type_ = 0; ExecutionEnginePtr index_engine_ = nullptr; bool metric_l2 = true; static std::mutex merge_mutex_; }; } // namespace scheduler Loading cpp/unittest/db/test_search.cpp +116 −48 Original line number Diff line number Diff line Loading @@ -21,6 +21,7 @@ #include "scheduler/task/SearchTask.h" #include "utils/TimeRecorder.h" #include "utils/ThreadPool.h" namespace { Loading Loading @@ -91,20 +92,17 @@ TEST(DBSearchTest, TOPK_TEST) { bool ascending; std::vector<int64_t> ids1, ids2; std::vector<float> dist1, dist2; milvus::scheduler::ResultSet result; milvus::Status status; ms::ResultSet result; /* test1, id1/dist1 valid, id2/dist2 empty */ ascending = true; BuildResult(NQ, TOP_K, ascending, ids1, dist1); status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); /* test2, id1/dist1 valid, id2/dist2 valid */ BuildResult(NQ, TOP_K, ascending, ids2, dist2); status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); /* test3, id1/dist1 small topk */ Loading @@ -112,10 +110,8 @@ TEST(DBSearchTest, TOPK_TEST) { dist1.clear(); result.clear(); BuildResult(NQ, TOP_K/2, ascending, ids1, dist1); status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); /* test4, id1/dist1 small topk, id2/dist2 small topk */ Loading @@ -123,10 +119,8 @@ TEST(DBSearchTest, TOPK_TEST) { dist2.clear(); result.clear(); BuildResult(NQ, TOP_K/3, ascending, ids2, dist2); status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K / 3, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); ///////////////////////////////////////////////////////////////////////////////////////// Loading @@ -139,14 +133,12 @@ TEST(DBSearchTest, TOPK_TEST) { /* test1, id1/dist1 valid, id2/dist2 empty */ BuildResult(NQ, TOP_K, ascending, ids1, dist1); status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); /* test2, id1/dist1 valid, id2/dist2 valid */ BuildResult(NQ, TOP_K, ascending, ids2, dist2); status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); /* test3, id1/dist1 small topk */ Loading @@ -154,10 +146,8 @@ TEST(DBSearchTest, TOPK_TEST) { dist1.clear(); result.clear(); BuildResult(NQ, TOP_K/2, ascending, ids1, dist1); status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); /* test4, id1/dist1 small topk, id2/dist2 small topk */ Loading @@ -165,10 +155,8 @@ TEST(DBSearchTest, TOPK_TEST) { dist2.clear(); result.clear(); BuildResult(NQ, TOP_K/3, ascending, ids2, dist2); status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K / 3, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); } Loading @@ -177,32 +165,112 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { int32_t top_k = 1000; int32_t index_file_num = 478; /* sift1B dataset, index files num */ bool ascending = true; std::vector<std::vector<int64_t>> id_vec; std::vector<std::vector<float>> dist_vec; std::vector<uint64_t> k_vec; std::vector<int64_t> input_ids; std::vector<float> input_distance; milvus::scheduler::ResultSet final_result; milvus::Status status; ms::ResultSet final_result, final_result_2, final_result_3; double span, reduce_cost = 0.0; int32_t i, k, step; double reduce_cost = 0.0; milvus::TimeRecorder rc(""); for (int32_t i = 0; i < index_file_num; i++) { for (i = 0; i < index_file_num; i++) { BuildResult(nq, top_k, ascending, input_ids, input_distance); id_vec.push_back(input_ids); dist_vec.push_back(input_distance); k_vec.push_back(top_k); } rc.RecordSection("Method-1 result reduce start"); rc.RecordSection("do search for context: " + std::to_string(i)); // pick up topk result status = milvus::scheduler::XSearchTask::TopkResult(input_ids, input_distance, top_k, nq, top_k, ascending, final_result); ASSERT_TRUE(status.ok()); /* method-1 */ for (i = 0; i < index_file_num; i++) { ms::XSearchTask::MergeTopkToResultSet(id_vec[i], dist_vec[i], k_vec[i], nq, top_k, ascending, final_result); ASSERT_EQ(final_result.size(), nq); } reduce_cost = rc.RecordSection("Method-1 result reduce done"); std::cout << "Method-1: total reduce time " << reduce_cost/1000 << " ms" << std::endl; /* method-2 */ std::vector<std::vector<int64_t>> id_vec_2(id_vec); std::vector<std::vector<float>> dist_vec_2(dist_vec); std::vector<uint64_t> k_vec_2(k_vec); rc.RecordSection("Method-2 result reduce start"); for (step = 1; step < index_file_num; step *= 2) { for (i = 0; i+step < index_file_num; i += step*2) { ms::XSearchTask::MergeTopkArray(id_vec_2[i], dist_vec_2[i], k_vec_2[i], id_vec_2[i+step], dist_vec_2[i+step], k_vec_2[i+step], nq, top_k, ascending); } } ms::XSearchTask::MergeTopkToResultSet(id_vec_2[0], dist_vec_2[0], k_vec_2[0], nq, top_k, ascending, final_result_2); ASSERT_EQ(final_result_2.size(), nq); reduce_cost = rc.RecordSection("Method-2 result reduce done"); std::cout << "Method-2: total reduce time " << reduce_cost/1000 << " ms" << std::endl; for (i = 0; i < nq; i++) { ASSERT_EQ(final_result[i].size(), final_result_2[i].size()); for (k = 0; k < final_result.size(); k++) { ASSERT_EQ(final_result[i][k].first, final_result_2[i][k].first); ASSERT_EQ(final_result[i][k].second, final_result_2[i][k].second); } } /* method-3 parallel */ std::vector<std::vector<int64_t>> id_vec_3(id_vec); std::vector<std::vector<float>> dist_vec_3(dist_vec); std::vector<uint64_t> k_vec_3(k_vec); span = rc.RecordSection("reduce topk for context: " + std::to_string(i)); reduce_cost += span; uint32_t max_thread_count = std::min(std::thread::hardware_concurrency() - 1, (uint32_t)MAX_THREADS_NUM); milvus::ThreadPool threadPool(max_thread_count); std::list<std::future<void>> threads_list; rc.RecordSection("Method-3 parallel result reduce start"); for (step = 1; step < index_file_num; step *= 2) { for (i = 0; i+step < index_file_num; i += step*2) { threads_list.push_back( threadPool.enqueue(ms::XSearchTask::MergeTopkArray, std::ref(id_vec_3[i]), std::ref(dist_vec_3[i]), std::ref(k_vec_3[i]), std::ref(id_vec_3[i+step]), std::ref(dist_vec_3[i+step]), std::ref(k_vec_3[i+step]), nq, top_k, ascending)); } while (threads_list.size() > 0) { int nready = 0; for (auto it = threads_list.begin(); it != threads_list.end(); it = it) { auto &p = *it; std::chrono::milliseconds span(0); if (p.wait_for(span) == std::future_status::ready) { threads_list.erase(it++); ++nready; } else { ++it; } } if (nready == 0) { std::this_thread::yield(); } } } ms::XSearchTask::MergeTopkToResultSet(id_vec_3[0], dist_vec_3[0], k_vec_3[0], nq, top_k, ascending, final_result_3); ASSERT_EQ(final_result_3.size(), nq); reduce_cost = rc.RecordSection("Method-3 parallel result reduce done"); std::cout << "Method-3 parallel: total reduce time " << reduce_cost/1000 << " ms" << std::endl; for (i = 0; i < nq; i++) { ASSERT_EQ(final_result[i].size(), final_result_3[i].size()); for (k = 0; k < final_result.size(); k++) { ASSERT_EQ(final_result[i][k].first, final_result_3[i][k].first); ASSERT_EQ(final_result[i][k].second, final_result_3[i][k].second); } } std::cout << "total reduce time: " << reduce_cost / 1000 << " ms" << std::endl; } Loading
cpp/src/scheduler/task/SearchTask.cpp +96 −28 Original line number Diff line number Diff line Loading @@ -33,8 +33,6 @@ namespace scheduler { static constexpr size_t PARALLEL_REDUCE_THRESHOLD = 10000; static constexpr size_t PARALLEL_REDUCE_BATCH = 1000; std::mutex XSearchTask::merge_mutex_; // TODO(wxyu): remove unused code // bool // NeedParallelReduce(uint64_t nq, uint64_t topk) { Loading Loading @@ -211,7 +209,7 @@ XSearchTask::Execute() { // step 3: pick up topk result auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk; XSearchTask::TopkResult(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult()); XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult()); span = rc.RecordSection(hdr + ", reduce topk"); // search_job->AccumReduceCost(span); Loading @@ -220,7 +218,7 @@ XSearchTask::Execute() { // search_job->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed } // step 5: notify to send result to client // step 4: notify to send result to client search_job->SearchDone(index_id_); } Loading @@ -230,36 +228,41 @@ XSearchTask::Execute() { index_engine_ = nullptr; } Status XSearchTask::TopkResult(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result) { scheduler::ResultSet result_buf; void XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result) { if (result.empty()) { result_buf.resize(nq, scheduler::Id2DistVec(input_k, scheduler::IdDistPair(-1, 0.0))); for (auto i = 0; i < nq; ++i) { auto& result_buf_i = result_buf[i]; result.resize(nq); } for (uint64_t i = 0; i < nq; i++) { scheduler::Id2DistVec result_buf; auto &result_i = result[i]; if (result[i].empty()) { result_buf.resize(input_k, scheduler::IdDistPair(-1, 0.0)); uint64_t input_k_multi_i = input_k * i; for (auto k = 0; k < input_k; ++k) { uint64_t idx = input_k_multi_i + k; auto& result_buf_item = result_buf_i[k]; auto &result_buf_item = result_buf[k]; result_buf_item.first = input_ids[idx]; result_buf_item.second = input_distance[idx]; } } } else { size_t tar_size = result[0].size(); size_t tar_size = result_i.size(); uint64_t output_k = std::min(topk, input_k + tar_size); result_buf.resize(nq, scheduler::Id2DistVec(output_k, scheduler::IdDistPair(-1, 0.0))); for (auto i = 0; i < nq; ++i) { result_buf.resize(output_k, scheduler::IdDistPair(-1, 0.0)); size_t buf_k = 0, src_k = 0, tar_k = 0; uint64_t src_idx; auto& result_i = result[i]; auto& result_buf_i = result_buf[i]; uint64_t input_k_multi_i = input_k * i; while (buf_k < output_k && src_k < input_k && tar_k < tar_size) { src_idx = input_k_multi_i + src_k; auto& result_buf_item = result_buf_i[buf_k]; auto &result_buf_item = result_buf[buf_k]; auto &result_item = result_i[tar_k]; if ((ascending && input_distance[src_idx] < result_item.second) || (!ascending && input_distance[src_idx] > result_item.second)) { Loading @@ -273,11 +276,11 @@ XSearchTask::TopkResult(const std::vector<int64_t>& input_ids, const std::vector buf_k++; } if (buf_k < topk) { if (buf_k < output_k) { if (src_k < input_k) { while (buf_k < output_k && src_k < input_k) { src_idx = input_k_multi_i + src_k; auto& result_buf_item = result_buf_i[buf_k]; auto &result_buf_item = result_buf[buf_k]; result_buf_item.first = input_ids[src_idx]; result_buf_item.second = input_distance[src_idx]; src_k++; Loading @@ -285,18 +288,83 @@ XSearchTask::TopkResult(const std::vector<int64_t>& input_ids, const std::vector } } else { while (buf_k < output_k && tar_k < tar_size) { result_buf_i[buf_k] = result_i[tar_k]; result_buf[buf_k] = result_i[tar_k]; tar_k++; buf_k++; } } } } result_i.swap(result_buf); } } result.swap(result_buf); void XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) { if (src_ids.empty() || src_distance.empty()) return; std::vector<int64_t> id_buf(nq*topk, -1); std::vector<float> dist_buf(nq*topk, 0.0); uint64_t output_k = std::min(topk, tar_input_k + src_input_k); uint64_t buf_k, src_k, tar_k; uint64_t src_idx, tar_idx, buf_idx; uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i; for (uint64_t i = 0; i < nq; i++) { src_input_k_multi_i = src_input_k * i; tar_input_k_multi_i = tar_input_k * i; buf_k_multi_i = output_k * i; buf_k = src_k = tar_k = 0; while (buf_k < output_k && src_k < src_input_k && tar_k < tar_input_k) { src_idx = src_input_k_multi_i + src_k; tar_idx = tar_input_k_multi_i + tar_k; buf_idx = buf_k_multi_i + buf_k; if ((ascending && src_distance[src_idx] < tar_distance[tar_idx]) || (!ascending && src_distance[src_idx] > tar_distance[tar_idx])) { id_buf[buf_idx] = src_ids[src_idx]; dist_buf[buf_idx] = src_distance[src_idx]; src_k++; } else { id_buf[buf_idx] = tar_ids[tar_idx]; dist_buf[buf_idx] = tar_distance[tar_idx]; tar_k++; } buf_k++; } if (buf_k < output_k) { if (src_k < src_input_k) { while (buf_k < output_k && src_k < src_input_k) { src_idx = src_input_k_multi_i + src_k; id_buf[buf_idx] = src_ids[src_idx]; dist_buf[buf_idx] = src_distance[src_idx]; src_k++; buf_k++; } } else { while (buf_k < output_k && tar_k < tar_input_k) { id_buf[buf_idx] = tar_ids[tar_idx]; dist_buf[buf_idx] = tar_distance[tar_idx]; tar_k++; buf_k++; } } } } return Status::OK(); tar_ids.swap(id_buf); tar_distance.swap(dist_buf); tar_input_k = output_k; } } // namespace scheduler Loading
cpp/src/scheduler/task/SearchTask.h +19 −5 Original line number Diff line number Diff line Loading @@ -38,9 +38,25 @@ class XSearchTask : public Task { Execute() override; public: static Status TopkResult(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result); static void MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result); static void MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending); public: TableFileSchemaPtr file_; Loading @@ -49,8 +65,6 @@ class XSearchTask : public Task { int index_type_ = 0; ExecutionEnginePtr index_engine_ = nullptr; bool metric_l2 = true; static std::mutex merge_mutex_; }; } // namespace scheduler Loading
cpp/unittest/db/test_search.cpp +116 −48 Original line number Diff line number Diff line Loading @@ -21,6 +21,7 @@ #include "scheduler/task/SearchTask.h" #include "utils/TimeRecorder.h" #include "utils/ThreadPool.h" namespace { Loading Loading @@ -91,20 +92,17 @@ TEST(DBSearchTest, TOPK_TEST) { bool ascending; std::vector<int64_t> ids1, ids2; std::vector<float> dist1, dist2; milvus::scheduler::ResultSet result; milvus::Status status; ms::ResultSet result; /* test1, id1/dist1 valid, id2/dist2 empty */ ascending = true; BuildResult(NQ, TOP_K, ascending, ids1, dist1); status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); /* test2, id1/dist1 valid, id2/dist2 valid */ BuildResult(NQ, TOP_K, ascending, ids2, dist2); status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); /* test3, id1/dist1 small topk */ Loading @@ -112,10 +110,8 @@ TEST(DBSearchTest, TOPK_TEST) { dist1.clear(); result.clear(); BuildResult(NQ, TOP_K/2, ascending, ids1, dist1); status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); /* test4, id1/dist1 small topk, id2/dist2 small topk */ Loading @@ -123,10 +119,8 @@ TEST(DBSearchTest, TOPK_TEST) { dist2.clear(); result.clear(); BuildResult(NQ, TOP_K/3, ascending, ids2, dist2); status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K / 3, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); ///////////////////////////////////////////////////////////////////////////////////////// Loading @@ -139,14 +133,12 @@ TEST(DBSearchTest, TOPK_TEST) { /* test1, id1/dist1 valid, id2/dist2 empty */ BuildResult(NQ, TOP_K, ascending, ids1, dist1); status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); /* test2, id1/dist1 valid, id2/dist2 valid */ BuildResult(NQ, TOP_K, ascending, ids2, dist2); status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); /* test3, id1/dist1 small topk */ Loading @@ -154,10 +146,8 @@ TEST(DBSearchTest, TOPK_TEST) { dist1.clear(); result.clear(); BuildResult(NQ, TOP_K/2, ascending, ids1, dist1); status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); /* test4, id1/dist1 small topk, id2/dist2 small topk */ Loading @@ -165,10 +155,8 @@ TEST(DBSearchTest, TOPK_TEST) { dist2.clear(); result.clear(); BuildResult(NQ, TOP_K/3, ascending, ids2, dist2); status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K / 3, NQ, TOP_K, ascending, result); ASSERT_TRUE(status.ok()); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); } Loading @@ -177,32 +165,112 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { int32_t top_k = 1000; int32_t index_file_num = 478; /* sift1B dataset, index files num */ bool ascending = true; std::vector<std::vector<int64_t>> id_vec; std::vector<std::vector<float>> dist_vec; std::vector<uint64_t> k_vec; std::vector<int64_t> input_ids; std::vector<float> input_distance; milvus::scheduler::ResultSet final_result; milvus::Status status; ms::ResultSet final_result, final_result_2, final_result_3; double span, reduce_cost = 0.0; int32_t i, k, step; double reduce_cost = 0.0; milvus::TimeRecorder rc(""); for (int32_t i = 0; i < index_file_num; i++) { for (i = 0; i < index_file_num; i++) { BuildResult(nq, top_k, ascending, input_ids, input_distance); id_vec.push_back(input_ids); dist_vec.push_back(input_distance); k_vec.push_back(top_k); } rc.RecordSection("Method-1 result reduce start"); rc.RecordSection("do search for context: " + std::to_string(i)); // pick up topk result status = milvus::scheduler::XSearchTask::TopkResult(input_ids, input_distance, top_k, nq, top_k, ascending, final_result); ASSERT_TRUE(status.ok()); /* method-1 */ for (i = 0; i < index_file_num; i++) { ms::XSearchTask::MergeTopkToResultSet(id_vec[i], dist_vec[i], k_vec[i], nq, top_k, ascending, final_result); ASSERT_EQ(final_result.size(), nq); } reduce_cost = rc.RecordSection("Method-1 result reduce done"); std::cout << "Method-1: total reduce time " << reduce_cost/1000 << " ms" << std::endl; /* method-2 */ std::vector<std::vector<int64_t>> id_vec_2(id_vec); std::vector<std::vector<float>> dist_vec_2(dist_vec); std::vector<uint64_t> k_vec_2(k_vec); rc.RecordSection("Method-2 result reduce start"); for (step = 1; step < index_file_num; step *= 2) { for (i = 0; i+step < index_file_num; i += step*2) { ms::XSearchTask::MergeTopkArray(id_vec_2[i], dist_vec_2[i], k_vec_2[i], id_vec_2[i+step], dist_vec_2[i+step], k_vec_2[i+step], nq, top_k, ascending); } } ms::XSearchTask::MergeTopkToResultSet(id_vec_2[0], dist_vec_2[0], k_vec_2[0], nq, top_k, ascending, final_result_2); ASSERT_EQ(final_result_2.size(), nq); reduce_cost = rc.RecordSection("Method-2 result reduce done"); std::cout << "Method-2: total reduce time " << reduce_cost/1000 << " ms" << std::endl; for (i = 0; i < nq; i++) { ASSERT_EQ(final_result[i].size(), final_result_2[i].size()); for (k = 0; k < final_result.size(); k++) { ASSERT_EQ(final_result[i][k].first, final_result_2[i][k].first); ASSERT_EQ(final_result[i][k].second, final_result_2[i][k].second); } } /* method-3 parallel */ std::vector<std::vector<int64_t>> id_vec_3(id_vec); std::vector<std::vector<float>> dist_vec_3(dist_vec); std::vector<uint64_t> k_vec_3(k_vec); span = rc.RecordSection("reduce topk for context: " + std::to_string(i)); reduce_cost += span; uint32_t max_thread_count = std::min(std::thread::hardware_concurrency() - 1, (uint32_t)MAX_THREADS_NUM); milvus::ThreadPool threadPool(max_thread_count); std::list<std::future<void>> threads_list; rc.RecordSection("Method-3 parallel result reduce start"); for (step = 1; step < index_file_num; step *= 2) { for (i = 0; i+step < index_file_num; i += step*2) { threads_list.push_back( threadPool.enqueue(ms::XSearchTask::MergeTopkArray, std::ref(id_vec_3[i]), std::ref(dist_vec_3[i]), std::ref(k_vec_3[i]), std::ref(id_vec_3[i+step]), std::ref(dist_vec_3[i+step]), std::ref(k_vec_3[i+step]), nq, top_k, ascending)); } while (threads_list.size() > 0) { int nready = 0; for (auto it = threads_list.begin(); it != threads_list.end(); it = it) { auto &p = *it; std::chrono::milliseconds span(0); if (p.wait_for(span) == std::future_status::ready) { threads_list.erase(it++); ++nready; } else { ++it; } } if (nready == 0) { std::this_thread::yield(); } } } ms::XSearchTask::MergeTopkToResultSet(id_vec_3[0], dist_vec_3[0], k_vec_3[0], nq, top_k, ascending, final_result_3); ASSERT_EQ(final_result_3.size(), nq); reduce_cost = rc.RecordSection("Method-3 parallel result reduce done"); std::cout << "Method-3 parallel: total reduce time " << reduce_cost/1000 << " ms" << std::endl; for (i = 0; i < nq; i++) { ASSERT_EQ(final_result[i].size(), final_result_3[i].size()); for (k = 0; k < final_result.size(); k++) { ASSERT_EQ(final_result[i][k].first, final_result_3[i][k].first); ASSERT_EQ(final_result[i][k].second, final_result_3[i][k].second); } } std::cout << "total reduce time: " << reduce_cost / 1000 << " ms" << std::endl; }