Commit e728c3f1 authored by shengjun.li's avatar shengjun.li Committed by JinHai-CN
Browse files

fix MatchNlist (#2401)

parent 500d769e
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -5,9 +5,11 @@ Please mark all change in change log and use the issue from GitHub

## Bug
-   \#2378 Duplicate data after server restart
-   \#2399 The nlist set by the user may not take effect

## Feature
-   \#2363 Update branch version to 0.9.1
-   \#2381 Upgrade to faiss_1.6.3

## Improvement
-   \#2370 Clean compile warning
+11 −12
Original line number Diff line number Diff line
@@ -75,14 +75,13 @@ ConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode m
}

int64_t
MatchNlist(const int64_t& size, const int64_t& nlist, const int64_t& per_nlist) {
    static float TYPICAL_COUNT = 1000000.0;
    if (size <= TYPICAL_COUNT / per_nlist + 1) {
        // handle less row count, avoid nlist set to 0
        return 1;
    } else if (int(size / TYPICAL_COUNT) * nlist <= 0) {
        // calculate a proper nlist if nlist not specified or size less than TYPICAL_COUNT
        return int(size / TYPICAL_COUNT * per_nlist);
MatchNlist(int64_t size, int64_t nlist) {
    const int64_t TYPICAL_COUNT = 1000000;
    const int64_t PER_NLIST = 16384;

    if (nlist * TYPICAL_COUNT > size * PER_NLIST) {
        // nlist is too large, adjust to a proper value
        nlist = std::max(1L, size * PER_NLIST / TYPICAL_COUNT);
    }
    return nlist;
}
@@ -101,7 +100,7 @@ IVFConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
    // auto tune params
    int64_t nq = oricfg[knowhere::meta::ROWS].get<int64_t>();
    int64_t nlist = oricfg[knowhere::IndexParams::nlist].get<int64_t>();
    oricfg[knowhere::IndexParams::nlist] = MatchNlist(nq, nlist, 16384);
    oricfg[knowhere::IndexParams::nlist] = MatchNlist(nq, nlist);

    // Best Practice
    // static int64_t MIN_POINTS_PER_CENTROID = 40;
@@ -157,8 +156,8 @@ IVFPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
    // CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);

    // auto tune params
    oricfg[knowhere::IndexParams::nlist] = MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(),
                                                      oricfg[knowhere::IndexParams::nlist].get<int64_t>(), 16384);
    oricfg[knowhere::IndexParams::nlist] =
        MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(), oricfg[knowhere::IndexParams::nlist].get<int64_t>());

    // Best Practice
    // static int64_t MIN_POINTS_PER_CENTROID = 40;
@@ -216,7 +215,7 @@ NSGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
    CheckIntByRange(knowhere::IndexParams::candidate, MIN_CANDIDATE_POOL_SIZE, MAX_CANDIDATE_POOL_SIZE);

    // auto tune params
    oricfg[knowhere::IndexParams::nlist] = MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(), 8192, 8192);
    oricfg[knowhere::IndexParams::nlist] = MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(), 8192);

    int64_t nprobe = int(oricfg[knowhere::IndexParams::nlist].get<int64_t>() * 0.1);
    oricfg[knowhere::IndexParams::nprobe] = nprobe < 1 ? 1 : nprobe;
+1 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@

#pragma once

#include <algorithm>
#include <memory>
#include <vector>