Unverified commit 1676f7cf, authored by shengjun.li and committed by GitHub
Browse files

fix GPU search (#2455)

parent 8878951b
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ Please mark all change in change log and use the issue from GitHub
-   \#2395 Fix large nq cudaMalloc error
-   \#2399 The nlist set by the user may not take effect
-   \#2403 MySQL max_idle_time is 10 by default
-   \#2450 The deleted vectors may be found on GPU

## Feature

+2 −5
Original line number Diff line number Diff line
@@ -75,10 +75,9 @@ pass1SelectLists(void** listIndices,
                         topQueryToCentroid,
                         opt);
    if (bitsetEmpty || (!(bitset[index >> 3] & (0x1 << (index & 0x7))))) {
      heap.add(distanceStart[i], start + i);
    } else {
      heap.add((1.0 / 0.0), start + i);
      heap.addThreadQ(distanceStart[i], start + i);
    }
    heap.checkThreadQ();
  }

  // Handle warp divergence separately
@@ -91,8 +90,6 @@ pass1SelectLists(void** listIndices,
                         opt);
    if (bitsetEmpty || (!(bitset[index >> 3] & (0x1 << (index & 0x7))))) {
      heap.addThreadQ(distanceStart[i], start + i);
    } else {
      heap.addThreadQ((1.0 / 0.0), start + i);
    }
  }

+3 −8
Original line number Diff line number Diff line
@@ -156,23 +156,18 @@ __global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
    if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
      v = Math<T>::add(centroidDistances[i],
                        productDistances[row][i]);
    } else {
      v = (T)(1.0 / 0.0);
      heap.addThreadQ(v, i);
    }
      
    heap.add(v, i);
    heap.checkThreadQ();
  }

  if (i < productDistances.getSize(1)) {
    if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
      v = Math<T>::add(centroidDistances[i],
                        productDistances[row][i]);
    } else {
      v = (T)(1.0 / 0.0);
    }

      heap.addThreadQ(v, i);
    }
  }

  heap.reduce();
  for (int i = threadIdx.x; i < k; i += blockDim.x) {
+5 −11
Original line number Diff line number Diff line
@@ -146,10 +146,9 @@ __global__ void blockSelect(Tensor<K, 2, true> in,

  for (; i < limit; i += ThreadsPerBlock) {
    if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
      heap.add(*inStart, (IndexType) i);
    } else {
      heap.add(-1.0, (IndexType) i);
      heap.addThreadQ(*inStart, (IndexType) i);
    }
    heap.checkThreadQ();

    inStart += ThreadsPerBlock;
  }
@@ -158,8 +157,6 @@ __global__ void blockSelect(Tensor<K, 2, true> in,
  if (i < in.getSize(1)) {
    if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
      heap.addThreadQ(*inStart, (IndexType) i);
    } else {
      heap.addThreadQ(-1.0, (IndexType) i);
    }
  }

@@ -208,10 +205,9 @@ __global__ void blockSelectPair(Tensor<K, 2, true> inK,

  for (; i < limit; i += ThreadsPerBlock) {
    if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
      heap.add(*inKStart, *inVStart);
    } else {
      heap.add(-1.0, *inVStart);
      heap.addThreadQ(*inKStart, *inVStart);
    }
    heap.checkThreadQ();

    inKStart += ThreadsPerBlock;
    inVStart += ThreadsPerBlock;
@@ -221,8 +217,6 @@ __global__ void blockSelectPair(Tensor<K, 2, true> inK,
  if (i < inK.getSize(1)) {
    if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
      heap.addThreadQ(*inKStart, *inVStart);
    } else {
      heap.addThreadQ(-1.0, *inVStart);
    }
  }

+0 −9
Original line number Diff line number Diff line
@@ -283,15 +283,6 @@ XSearchTask::Execute() {

                {
                    std::unique_lock<std::mutex> lock(search_job->mutex());

                    if (search_job->GetResultIds().size() > spec_k) {
                        if (search_job->GetResultIds().front() == -1) {
                            // initialized results set
                            search_job->GetResultIds().resize(spec_k * nq);
                            search_job->GetResultDistances().resize(spec_k * nq);
                        }
                    }

                    search_job->vector_count() = nq;
                    XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, ascending_reduce,
                                                      search_job->GetResultIds(), search_job->GetResultDistances());