Loading core/src/scheduler/task/SearchTask.cpp +65 −65 Original line number Diff line number Diff line Loading @@ -307,71 +307,71 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const s } } void XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) { if (src_ids.empty() || src_distance.empty()) { return; } uint64_t output_k = std::min(topk, tar_input_k + src_input_k); std::vector<int64_t> id_buf(nq * output_k, -1); std::vector<float> dist_buf(nq * output_k, 0.0); uint64_t buf_k, src_k, tar_k; uint64_t src_idx, tar_idx, buf_idx; uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i; for (uint64_t i = 0; i < nq; i++) { src_input_k_multi_i = src_input_k * i; tar_input_k_multi_i = tar_input_k * i; buf_k_multi_i = output_k * i; buf_k = src_k = tar_k = 0; while (buf_k < output_k && src_k < src_input_k && tar_k < tar_input_k) { src_idx = src_input_k_multi_i + src_k; tar_idx = tar_input_k_multi_i + tar_k; buf_idx = buf_k_multi_i + buf_k; if ((ascending && src_distance[src_idx] < tar_distance[tar_idx]) || (!ascending && src_distance[src_idx] > tar_distance[tar_idx])) { id_buf[buf_idx] = src_ids[src_idx]; dist_buf[buf_idx] = src_distance[src_idx]; src_k++; } else { id_buf[buf_idx] = tar_ids[tar_idx]; dist_buf[buf_idx] = tar_distance[tar_idx]; tar_k++; } buf_k++; } if (buf_k < output_k) { if (src_k < src_input_k) { while (buf_k < output_k && src_k < src_input_k) { src_idx = src_input_k_multi_i + src_k; buf_idx = buf_k_multi_i + buf_k; id_buf[buf_idx] = src_ids[src_idx]; dist_buf[buf_idx] = src_distance[src_idx]; src_k++; buf_k++; } } else { while (buf_k < output_k && tar_k < tar_input_k) { tar_idx = tar_input_k_multi_i + tar_k; buf_idx = buf_k_multi_i + buf_k; id_buf[buf_idx] = tar_ids[tar_idx]; dist_buf[buf_idx] = tar_distance[tar_idx]; tar_k++; buf_k++; } } } } tar_ids.swap(id_buf); tar_distance.swap(dist_buf); tar_input_k = output_k; } //void //XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, // const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, // uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) { // if (src_ids.empty() || src_distance.empty()) { // return; // } // // uint64_t output_k = std::min(topk, tar_input_k + src_input_k); // std::vector<int64_t> id_buf(nq * output_k, -1); // std::vector<float> dist_buf(nq * output_k, 0.0); // // uint64_t buf_k, src_k, tar_k; // uint64_t src_idx, tar_idx, buf_idx; // uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i; // // for (uint64_t i = 0; i < nq; i++) { // src_input_k_multi_i = src_input_k * i; // tar_input_k_multi_i = tar_input_k * i; // buf_k_multi_i = output_k * i; // buf_k = src_k = tar_k = 0; // while (buf_k < output_k && src_k < src_input_k && tar_k < tar_input_k) { // src_idx = src_input_k_multi_i + src_k; // tar_idx = tar_input_k_multi_i + tar_k; // buf_idx = buf_k_multi_i + buf_k; // if ((ascending && src_distance[src_idx] < tar_distance[tar_idx]) || // (!ascending && src_distance[src_idx] > tar_distance[tar_idx])) { // id_buf[buf_idx] = src_ids[src_idx]; // dist_buf[buf_idx] = src_distance[src_idx]; // src_k++; // } else { // id_buf[buf_idx] = tar_ids[tar_idx]; // dist_buf[buf_idx] = tar_distance[tar_idx]; // tar_k++; // } // buf_k++; // } // // if (buf_k < output_k) { // if (src_k < src_input_k) { // while (buf_k < output_k && src_k < src_input_k) { // src_idx = src_input_k_multi_i + src_k; // buf_idx = buf_k_multi_i + buf_k; // id_buf[buf_idx] = src_ids[src_idx]; // dist_buf[buf_idx] = src_distance[src_idx]; // src_k++; // buf_k++; // } // } else { // while (buf_k < output_k && tar_k < tar_input_k) { // tar_idx = tar_input_k_multi_i + tar_k; // buf_idx = buf_k_multi_i + buf_k; // id_buf[buf_idx] = tar_ids[tar_idx]; // dist_buf[buf_idx] = tar_distance[tar_idx]; // tar_k++; // buf_k++; // } // } // } // } // // tar_ids.swap(id_buf); // tar_distance.swap(dist_buf); // tar_input_k = output_k; //} } // namespace scheduler } // namespace milvus core/src/scheduler/task/SearchTask.h +4 −4 Original line number Diff line number Diff line Loading @@ -42,10 +42,10 @@ class XSearchTask : public Task { MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result); static void MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending); // static void // MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, // const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k, // uint64_t nq, uint64_t topk, bool ascending); public: TableFileSchemaPtr file_; Loading core/unittest/db/test_search.cpp +194 −176 File changed.Preview size limit exceeded, changes collapsed. Show changes Loading
core/src/scheduler/task/SearchTask.cpp +65 −65 Original line number Diff line number Diff line Loading @@ -307,71 +307,71 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const s } } void XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) { if (src_ids.empty() || src_distance.empty()) { return; } uint64_t output_k = std::min(topk, tar_input_k + src_input_k); std::vector<int64_t> id_buf(nq * output_k, -1); std::vector<float> dist_buf(nq * output_k, 0.0); uint64_t buf_k, src_k, tar_k; uint64_t src_idx, tar_idx, buf_idx; uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i; for (uint64_t i = 0; i < nq; i++) { src_input_k_multi_i = src_input_k * i; tar_input_k_multi_i = tar_input_k * i; buf_k_multi_i = output_k * i; buf_k = src_k = tar_k = 0; while (buf_k < output_k && src_k < src_input_k && tar_k < tar_input_k) { src_idx = src_input_k_multi_i + src_k; tar_idx = tar_input_k_multi_i + tar_k; buf_idx = buf_k_multi_i + buf_k; if ((ascending && src_distance[src_idx] < tar_distance[tar_idx]) || (!ascending && src_distance[src_idx] > tar_distance[tar_idx])) { id_buf[buf_idx] = src_ids[src_idx]; dist_buf[buf_idx] = src_distance[src_idx]; src_k++; } else { id_buf[buf_idx] = tar_ids[tar_idx]; dist_buf[buf_idx] = tar_distance[tar_idx]; tar_k++; } buf_k++; } if (buf_k < output_k) { if (src_k < src_input_k) { while (buf_k < output_k && src_k < src_input_k) { src_idx = src_input_k_multi_i + src_k; buf_idx = buf_k_multi_i + buf_k; id_buf[buf_idx] = src_ids[src_idx]; dist_buf[buf_idx] = src_distance[src_idx]; src_k++; buf_k++; } } else { while (buf_k < output_k && tar_k < tar_input_k) { tar_idx = tar_input_k_multi_i + tar_k; buf_idx = buf_k_multi_i + buf_k; id_buf[buf_idx] = tar_ids[tar_idx]; dist_buf[buf_idx] = tar_distance[tar_idx]; tar_k++; buf_k++; } } } } tar_ids.swap(id_buf); tar_distance.swap(dist_buf); tar_input_k = output_k; } //void //XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, // const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, // uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) { // if (src_ids.empty() || src_distance.empty()) { // return; // } // // uint64_t output_k = std::min(topk, tar_input_k + src_input_k); // std::vector<int64_t> id_buf(nq * output_k, -1); // std::vector<float> dist_buf(nq * output_k, 0.0); // // uint64_t buf_k, src_k, tar_k; // uint64_t src_idx, tar_idx, buf_idx; // uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i; // // for (uint64_t i = 0; i < nq; i++) { // src_input_k_multi_i = src_input_k * i; // tar_input_k_multi_i = tar_input_k * i; // buf_k_multi_i = output_k * i; // buf_k = src_k = tar_k = 0; // while (buf_k < output_k && src_k < src_input_k && tar_k < tar_input_k) { // src_idx = src_input_k_multi_i + src_k; // tar_idx = tar_input_k_multi_i + tar_k; // buf_idx = buf_k_multi_i + buf_k; // if ((ascending && src_distance[src_idx] < tar_distance[tar_idx]) || // (!ascending && src_distance[src_idx] > tar_distance[tar_idx])) { // id_buf[buf_idx] = src_ids[src_idx]; // dist_buf[buf_idx] = src_distance[src_idx]; // src_k++; // } else { // id_buf[buf_idx] = tar_ids[tar_idx]; // dist_buf[buf_idx] = tar_distance[tar_idx]; // tar_k++; // } // buf_k++; // } // // if (buf_k < output_k) { // if (src_k < src_input_k) { // while (buf_k < output_k && src_k < src_input_k) { // src_idx = src_input_k_multi_i + src_k; // buf_idx = buf_k_multi_i + buf_k; // id_buf[buf_idx] = src_ids[src_idx]; // dist_buf[buf_idx] = src_distance[src_idx]; // src_k++; // buf_k++; // } // } else { // while (buf_k < output_k && tar_k < tar_input_k) { // tar_idx = tar_input_k_multi_i + tar_k; // buf_idx = buf_k_multi_i + buf_k; // id_buf[buf_idx] = tar_ids[tar_idx]; // dist_buf[buf_idx] = tar_distance[tar_idx]; // tar_k++; // buf_k++; // } // } // } // } // // tar_ids.swap(id_buf); // tar_distance.swap(dist_buf); // tar_input_k = output_k; //} } // namespace scheduler } // namespace milvus
core/src/scheduler/task/SearchTask.h +4 −4 Original line number Diff line number Diff line Loading @@ -42,10 +42,10 @@ class XSearchTask : public Task { MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result); static void MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending); // static void // MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, // const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k, // uint64_t nq, uint64_t topk, bool ascending); public: TableFileSchemaPtr file_; Loading
core/unittest/db/test_search.cpp +194 −176 File changed.Preview size limit exceeded, changes collapsed. Show changes