Commit b43b9bac authored by 蔡宇东's avatar 蔡宇东
Browse files

#89 do normalize() for IP test


Former-commit-id: 11be41ee43e3dd0e2ecbcd50ba70bab2df58688d
parent 2f762900
Loading
Loading
Loading
Loading
+19 −0
Original line number Diff line number Diff line
@@ -54,6 +54,20 @@ elapsed() {
    return tv.tv_sec + tv.tv_usec * 1e-6;
}

/**
 * Scale every row of a row-major matrix to unit Euclidean (L2) length, in place.
 *
 * @param arr  buffer of nq * dim floats; row i occupies arr[i * dim] .. arr[i * dim + dim - 1]
 * @param nq   number of rows (vectors) in arr
 * @param dim  number of components per row
 *
 * The squared norm is accumulated in double precision and each component is
 * narrowed back to float after the division. A row whose norm is exactly zero
 * is left untouched: dividing by zero would turn every component into NaN.
 */
void normalize(float* arr, size_t nq, size_t dim) {
    for (size_t i = 0; i < nq; i++) {
        float* row = arr + i * dim;

        double sq_sum = 0.0;
        for (size_t j = 0; j < dim; j++) {
            const double val = row[j];
            sq_sum += val * val;
        }

        const double vec_len = std::sqrt(sq_sum);
        if (vec_len == 0.0) {
            // All-zero row (or dim == 0): nothing to scale, and 0/0 would yield NaN.
            continue;
        }

        for (size_t j = 0; j < dim; j++) {
            row[j] = static_cast<float>(row[j] / vec_len);
        }
    }
}

void*
hdf5_read(const char* file_name, const char* dataset_name, H5T_class_t dataset_class, size_t& d_out, size_t& n_out) {
    hid_t file, dataset, datatype, dataspace, memspace;
@@ -237,6 +251,11 @@ test_ann_hdf5(const std::string& ann_test_name, const std::string& index_key, in
        float* xb = (float*)hdf5_read(ann_file_name.c_str(), "train", H5T_FLOAT, d, nb);
        assert(d == dim || !"dataset does not have correct dimension");

        if (metric_type == faiss::METRIC_INNER_PRODUCT) {
            printf("[%.3f s] Normalizing data set \n", elapsed() - t0);
            normalize(xb, nb, d);
        }

        printf("[%.3f s] Preparing index \"%s\" d=%ld\n", elapsed() - t0, index_key.c_str(), d);

        index = faiss::index_factory(d, index_key.c_str(), metric_type);