Commit b6660c4f authored by 蔡宇东's avatar 蔡宇东
Browse files

MS-606 fix result reduce bug


Former-commit-id: 3c7be785d9eeb979e5050016b2330bc5f5d4841a
parent 1be3bc51
Loading
Loading
Loading
Loading
+6 −3
Original line number Diff line number Diff line
@@ -315,10 +315,10 @@ XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& t
        return;
    }

    std::vector<int64_t> id_buf(nq * topk, -1);
    std::vector<float> dist_buf(nq * topk, 0.0);

    uint64_t output_k = std::min(topk, tar_input_k + src_input_k);
    std::vector<int64_t> id_buf(nq * output_k, -1);
    std::vector<float> dist_buf(nq * output_k, 0.0);

    uint64_t buf_k, src_k, tar_k;
    uint64_t src_idx, tar_idx, buf_idx;
    uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i;
@@ -349,6 +349,7 @@ XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& t
            if (src_k < src_input_k) {
                while (buf_k < output_k && src_k < src_input_k) {
                    src_idx = src_input_k_multi_i + src_k;
                    buf_idx = buf_k_multi_i + buf_k;
                    id_buf[buf_idx] = src_ids[src_idx];
                    dist_buf[buf_idx] = src_distance[src_idx];
                    src_k++;
@@ -356,6 +357,8 @@ XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& t
                }
            } else {
                while (buf_k < output_k && tar_k < tar_input_k) {
                    tar_idx = tar_input_k_multi_i + tar_k;
                    buf_idx = buf_k_multi_i + buf_k;
                    id_buf[buf_idx] = tar_ids[tar_idx];
                    dist_buf[buf_idx] = tar_distance[tar_idx];
                    tar_k++;
+71 −56
Original line number Diff line number Diff line
@@ -110,87 +110,102 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1,

} // namespace

TEST(DBSearchTest, TOPK_TEST) {
    uint64_t NQ = 15;
    uint64_t TOP_K = 64;
    bool ascending;
void MergeTopkToResultSetTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) {
    std::vector<int64_t> ids1, ids2;
    std::vector<float> dist1, dist2;
    ms::ResultSet result;
    BuildResult(ids1, dist1, topk_1, nq, ascending);
    BuildResult(ids2, dist2, topk_2, nq, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, topk_1, nq, topk, ascending, result);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, topk_2, nq, topk, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, topk, nq, ascending, result);
}

// Calls MergeTopkToResultSetTest for several (topk_1, topk_2) combinations derived from TOP_K,
// each once with ascending == true and once with ascending == false.
// NOTE(review): this listing appears to interleave removed and added diff lines. The statements
// that use ids1/dist1/ids2/dist2/result/ascending directly have no declarations inside this
// TEST body and are presumably the removed side of the change — confirm against the repository.
TEST(DBSearchTest, MERGE_RESULT_SET_TEST) {
    uint64_t NQ = 15;
    uint64_t TOP_K = 64;

    /* test1, id1/dist1 valid, id2/dist2 empty */
    ascending = true;
    BuildResult(ids1, dist1, TOP_K, NQ, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
    MergeTopkToResultSetTest(TOP_K, 0, NQ, TOP_K, true);
    MergeTopkToResultSetTest(TOP_K, 0, NQ, TOP_K, false);

    /* test2, id1/dist1 valid, id2/dist2 valid */
    BuildResult(ids2, dist2, TOP_K, NQ, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
    MergeTopkToResultSetTest(TOP_K, TOP_K, NQ, TOP_K, true);
    MergeTopkToResultSetTest(TOP_K, TOP_K, NQ, TOP_K, false);

    /* test3, id1/dist1 small topk */
    ids1.clear();
    dist1.clear();
    result.clear();
    BuildResult(ids1, dist1, TOP_K/2, NQ, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
    MergeTopkToResultSetTest(TOP_K/2, TOP_K, NQ, TOP_K, true);
    MergeTopkToResultSetTest(TOP_K/2, TOP_K, NQ, TOP_K, false);

    /* test4, id1/dist1 small topk, id2/dist2 small topk */
    ids2.clear();
    dist2.clear();
    result.clear();
    BuildResult(ids2, dist2, TOP_K/3, NQ, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);

/////////////////////////////////////////////////////////////////////////////////////////
    ascending = false;
    ids1.clear();
    dist1.clear();
    ids2.clear();
    dist2.clear();
    result.clear();
    MergeTopkToResultSetTest(TOP_K/2, TOP_K/3, NQ, TOP_K, true);
    MergeTopkToResultSetTest(TOP_K/2, TOP_K/3, NQ, TOP_K, false);
}

// Generates two candidate lists with BuildResult, merges them with XSearchTask::MergeTopkArray
// (ids1/dist1 are passed as the first id/distance pair and inspected afterwards), and verifies
// the size and ordering of the merged distances.
//
//   topk_1, topk_2  entries requested from BuildResult for the first / second list
//   nq              number of queries
//   topk            result size passed to MergeTopkArray
//   ascending       ordering passed to BuildResult and MergeTopkArray; also the ordering
//                   that is asserted on the merged distances
//
// Expected outcome: ids1 and dist1 each hold min(topk, topk_1 + topk_2) * nq entries, and within
// every query's slice the distances are non-decreasing (ascending) or non-increasing (otherwise).
// Only distances are checked for ordering; ids are checked for size only.
void MergeTopkArrayTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) {
    std::vector<int64_t> ids1, ids2;
    std::vector<float> dist1, dist2;
    BuildResult(ids1, dist1, topk_1, nq, ascending);
    BuildResult(ids2, dist2, topk_2, nq, ascending);

    const uint64_t result_topk = std::min(topk, topk_1 + topk_2);
    ms::XSearchTask::MergeTopkArray(ids1, dist1, topk_1, ids2, dist2, topk_2, nq, topk, ascending);

    // Diagnostics are streamed onto the assertion so they appear with the failure itself
    // instead of being printed separately to stdout.
    const uint64_t expected_size = result_topk * nq;
    ASSERT_TRUE(ids1.size() == expected_size)
        << "ids size " << ids1.size() << ", expected " << expected_size;
    ASSERT_TRUE(dist1.size() == expected_size)
        << "distance size " << dist1.size() << ", expected " << expected_size;

    for (uint64_t i = 0; i < nq; i++) {
        const uint64_t row = i * result_topk;  // start of query i's slice in the flat array
        for (uint64_t k = 1; k < result_topk; k++) {
            const float prev = dist1[row + k - 1];
            const float curr = dist1[row + k];
            if (ascending) {
                ASSERT_TRUE(curr >= prev)
                    << "query " << i << ", position " << k << ": " << prev << " " << curr;
            } else {
                ASSERT_TRUE(curr <= prev)
                    << "query " << i << ", position " << k << ": " << prev << " " << curr;
            }
        }
    }
}

// Calls MergeTopkArrayTest for several (topk_1, topk_2) combinations derived from TOP_K, in both
// argument orders where the two sizes differ, each once with ascending == true and once with
// ascending == false.
// NOTE(review): this listing appears to interleave removed and added diff lines. The statements
// that use ids1/dist1/ids2/dist2/result/ascending directly have no declarations inside this
// TEST body and are presumably the removed side of the change — confirm against the repository.
TEST(DBSearchTest, MERGE_ARRAY_TEST) {
    uint64_t NQ = 15;
    uint64_t TOP_K = 64;

    /* test1, id1/dist1 valid, id2/dist2 empty */
    BuildResult(ids1, dist1, TOP_K, NQ, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
    MergeTopkArrayTest(TOP_K, 0, NQ, TOP_K, true);
    MergeTopkArrayTest(TOP_K, 0, NQ, TOP_K, false);
    MergeTopkArrayTest(0, TOP_K, NQ, TOP_K, true);
    MergeTopkArrayTest(0, TOP_K, NQ, TOP_K, false);

    /* test2, id1/dist1 valid, id2/dist2 valid */
    BuildResult(ids2, dist2, TOP_K, NQ, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
    MergeTopkArrayTest(TOP_K, TOP_K, NQ, TOP_K, true);
    MergeTopkArrayTest(TOP_K, TOP_K, NQ, TOP_K, false);

    /* test3, id1/dist1 small topk */
    ids1.clear();
    dist1.clear();
    result.clear();
    BuildResult(ids1, dist1, TOP_K/2, NQ, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
    MergeTopkArrayTest(TOP_K/2, TOP_K, NQ, TOP_K, true);
    MergeTopkArrayTest(TOP_K/2, TOP_K, NQ, TOP_K, false);
    MergeTopkArrayTest(TOP_K, TOP_K/2, NQ, TOP_K, true);
    MergeTopkArrayTest(TOP_K, TOP_K/2, NQ, TOP_K, false);

    /* test4, id1/dist1 small topk, id2/dist2 small topk */
    ids2.clear();
    dist2.clear();
    result.clear();
    BuildResult(ids2, dist2, TOP_K/3, NQ, ascending);
    ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
    ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result);
    CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
    MergeTopkArrayTest(TOP_K/2, TOP_K/3, NQ, TOP_K, true);
    MergeTopkArrayTest(TOP_K/2, TOP_K/3, NQ, TOP_K, false);
    MergeTopkArrayTest(TOP_K/3, TOP_K/2, NQ, TOP_K, true);
    MergeTopkArrayTest(TOP_K/3, TOP_K/2, NQ, TOP_K, false);
}

TEST(DBSearchTest, REDUCE_PERF_TEST) {
    int32_t index_file_num = 478;   /* sift1B dataset, index files num */
    bool ascending = true;

    std::vector<int32_t> thread_vec = {4, 8, 11};
    std::vector<int32_t> nq_vec = {1, 10, 100, 1000};
    std::vector<int32_t> topk_vec = {1, 4, 16, 64, 256, 1024};
    std::vector<int32_t> thread_vec = {4, 8};
    std::vector<int32_t> nq_vec = {1, 10, 100};
    std::vector<int32_t> topk_vec = {1, 4, 16, 64};
    int32_t NQ = nq_vec[nq_vec.size()-1];
    int32_t TOPK = topk_vec[topk_vec.size()-1];