mirror of https://github.com/milvus-io/milvus.git
Add Unittest for IDMAP (#1774)
* gpu idmap unittest Signed-off-by: sahuang <xiaohai.xu@zilliz.com> * fix clang format Signed-off-by: sahuang <xiaohai.xu@zilliz.com> * fix build issue Signed-off-by: sahuang <xiaohai.xu@zilliz.com> * small fix Signed-off-by: sahuang <xiaohai.xu@zilliz.com> * fix case Signed-off-by: sahuang <xiaohai.xu@zilliz.com> * fix back Signed-off-by: sahuang <xiaohai.xu@zilliz.com> * Compact code Signed-off-by: sahuang <xiaohai.xu@zilliz.com> * efficient code Signed-off-by: sahuang <xiaohai.xu@zilliz.com> * Fix GPU search hang Signed-off-by: sahuang <xiaohai.xu@zilliz.com>pull/1803/head
parent
4574f9d998
commit
8a731dea39
|
@ -105,7 +105,7 @@ GPUIDMAP::GetRawIds() {
|
|||
void
|
||||
GPUIDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
|
||||
ResScope rs(res_, gpu_id_);
|
||||
index_->search(n, (float*)data, k, distances, labels);
|
||||
index_->search(n, (float*)data, k, distances, labels, bitset_);
|
||||
}
|
||||
|
||||
void
|
||||
|
|
|
@ -310,7 +310,7 @@ void IndexIVF::search (idx_t n, const float *x, idx_t k, float *distances, idx_t
|
|||
invlists->prefetch_lists (idx.get(), n * nprobe);
|
||||
|
||||
search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
|
||||
distances, labels, false, nullptr,bitset);
|
||||
distances, labels, false, nullptr, bitset);
|
||||
indexIVF_stats.search_time += getmillisecs() - t0;
|
||||
}
|
||||
|
||||
|
|
|
@ -18,55 +18,6 @@ class GpuResources;
|
|||
|
||||
/// Calculates brute-force L2 distance between `vectors` and
|
||||
/// `queries`, returning the k closest results seen
|
||||
void runL2Distance(GpuResources* resources,
|
||||
Tensor<float, 2, true>& vectors,
|
||||
bool vectorsRowMajor,
|
||||
// can be optionally pre-computed; nullptr if we
|
||||
// have to compute it upon the call
|
||||
Tensor<float, 1, true>* vectorNorms,
|
||||
Tensor<float, 2, true>& queries,
|
||||
bool queriesRowMajor,
|
||||
int k,
|
||||
Tensor<float, 2, true>& outDistances,
|
||||
Tensor<int, 2, true>& outIndices,
|
||||
// Do we care about `outDistances`? If not, we can
|
||||
// take shortcuts.
|
||||
bool ignoreOutDistances = false);
|
||||
|
||||
/// Calculates brute-force inner product distance between `vectors`
|
||||
/// and `queries`, returning the k closest results seen
|
||||
void runIPDistance(GpuResources* resources,
|
||||
Tensor<float, 2, true>& vectors,
|
||||
bool vectorsRowMajor,
|
||||
Tensor<float, 2, true>& queries,
|
||||
bool queriesRowMajor,
|
||||
int k,
|
||||
Tensor<float, 2, true>& outDistances,
|
||||
Tensor<int, 2, true>& outIndices);
|
||||
|
||||
void runIPDistance(GpuResources* resources,
|
||||
Tensor<half, 2, true>& vectors,
|
||||
bool vectorsRowMajor,
|
||||
Tensor<half, 2, true>& queries,
|
||||
bool queriesRowMajor,
|
||||
int k,
|
||||
Tensor<half, 2, true>& outDistances,
|
||||
Tensor<int, 2, true>& outIndices,
|
||||
bool useHgemm);
|
||||
|
||||
void runL2Distance(GpuResources* resources,
|
||||
Tensor<half, 2, true>& vectors,
|
||||
bool vectorsRowMajor,
|
||||
Tensor<half, 1, true>* vectorNorms,
|
||||
Tensor<half, 2, true>& queries,
|
||||
bool queriesRowMajor,
|
||||
int k,
|
||||
Tensor<half, 2, true>& outDistances,
|
||||
Tensor<int, 2, true>& outIndices,
|
||||
bool useHgemm,
|
||||
bool ignoreOutDistances = false);
|
||||
|
||||
// Bitset added
|
||||
void runL2Distance(GpuResources* resources,
|
||||
Tensor<float, 2, true>& vectors,
|
||||
bool vectorsRowMajor,
|
||||
|
|
|
@ -24,6 +24,7 @@ namespace faiss { namespace gpu {
|
|||
template <typename T, int kRowsPerBlock, int kBlockSize>
|
||||
__global__ void l2SelectMin1(Tensor<T, 2, true> productDistances,
|
||||
Tensor<T, 1, true> centroidDistances,
|
||||
Tensor<uint8_t, 1, true> bitset,
|
||||
Tensor<T, 2, true> outDistances,
|
||||
Tensor<int, 2, true> outIndices) {
|
||||
// Each block handles kRowsPerBlock rows of the distances (results)
|
||||
|
@ -44,6 +45,8 @@ __global__ void l2SelectMin1(Tensor<T, 2, true> productDistances,
|
|||
// FIXME: if we have exact multiples, don't need this
|
||||
bool endRow = (blockIdx.x == gridDim.x - 1);
|
||||
|
||||
bool bitsetIsEmpty = (bitset.getSize(0) == 0);
|
||||
|
||||
if (endRow) {
|
||||
if (productDistances.getSize(0) % kRowsPerBlock == 0) {
|
||||
endRow = false;
|
||||
|
@ -54,8 +57,12 @@ __global__ void l2SelectMin1(Tensor<T, 2, true> productDistances,
|
|||
for (int row = rowStart; row < productDistances.getSize(0); ++row) {
|
||||
for (int col = threadIdx.x; col < productDistances.getSize(1);
|
||||
col += blockDim.x) {
|
||||
distance[0] = Math<T>::add(centroidDistances[col],
|
||||
productDistances[row][col]);
|
||||
if (bitsetIsEmpty || (!(bitset[col >> 3] & (0x1 << (col & 0x7))))) {
|
||||
distance[0] = Math<T>::add(centroidDistances[col],
|
||||
productDistances[row][col]);
|
||||
} else {
|
||||
distance[0] = (T)(1.0 / 0.0);
|
||||
}
|
||||
|
||||
if (Math<T>::lt(distance[0], threadMin[0].k)) {
|
||||
threadMin[0].k = distance[0];
|
||||
|
@ -117,10 +124,12 @@ __global__ void l2SelectMin1(Tensor<T, 2, true> productDistances,
|
|||
}
|
||||
}
|
||||
|
||||
// With bitset included
|
||||
// L2 + select kernel for k > 1, no re-use of ||c||^2
|
||||
template <typename T, int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
|
||||
__global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
|
||||
Tensor<T, 1, true> centroidDistances,
|
||||
Tensor<uint8_t, 1, true> bitset,
|
||||
Tensor<T, 2, true> outDistances,
|
||||
Tensor<int, 2, true> outIndices,
|
||||
int k, T initK) {
|
||||
|
@ -140,15 +149,28 @@ __global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
|
|||
int limit = utils::roundDown(productDistances.getSize(1), kWarpSize);
|
||||
int i = threadIdx.x;
|
||||
|
||||
bool bitsetIsEmpty = (bitset.getSize(0) == 0);
|
||||
T v;
|
||||
|
||||
for (; i < limit; i += blockDim.x) {
|
||||
T v = Math<T>::add(centroidDistances[i],
|
||||
productDistances[row][i]);
|
||||
if (bitsetIsEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
|
||||
v = Math<T>::add(centroidDistances[i],
|
||||
productDistances[row][i]);
|
||||
} else {
|
||||
v = (T)(1.0 / 0.0);
|
||||
}
|
||||
|
||||
heap.add(v, i);
|
||||
}
|
||||
|
||||
if (i < productDistances.getSize(1)) {
|
||||
T v = Math<T>::add(centroidDistances[i],
|
||||
productDistances[row][i]);
|
||||
if (bitsetIsEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
|
||||
v = Math<T>::add(centroidDistances[i],
|
||||
productDistances[row][i]);
|
||||
} else {
|
||||
v = (T)(1.0 / 0.0);
|
||||
}
|
||||
|
||||
heap.addThreadQ(v, i);
|
||||
}
|
||||
|
||||
|
@ -159,155 +181,6 @@ __global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
|
|||
}
|
||||
}
|
||||
|
||||
// With bitset included
|
||||
// L2 + select kernel for k == 1, implements re-use of ||c||^2
|
||||
template <typename T, int kRowsPerBlock, int kBlockSize>
|
||||
__global__ void l2SelectMin1(Tensor<T, 2, true> productDistances,
|
||||
Tensor<T, 1, true> centroidDistances,
|
||||
Tensor<uint8_t, 1, true> bitset,
|
||||
Tensor<T, 2, true> outDistances,
|
||||
Tensor<int, 2, true> outIndices) {
|
||||
// Each block handles kRowsPerBlock rows of the distances (results)
|
||||
Pair<T, int> threadMin[kRowsPerBlock];
|
||||
__shared__ Pair<T, int> blockMin[kRowsPerBlock * (kBlockSize / kWarpSize)];
|
||||
|
||||
T distance[kRowsPerBlock];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kRowsPerBlock; ++i) {
|
||||
threadMin[i].k = Limits<T>::getMax();
|
||||
threadMin[i].v = -1;
|
||||
}
|
||||
|
||||
// blockIdx.x: which chunk of rows we are responsible for updating
|
||||
int rowStart = blockIdx.x * kRowsPerBlock;
|
||||
|
||||
// FIXME: if we have exact multiples, don't need this
|
||||
bool endRow = (blockIdx.x == gridDim.x - 1);
|
||||
|
||||
if (endRow) {
|
||||
if (productDistances.getSize(0) % kRowsPerBlock == 0) {
|
||||
endRow = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (endRow) {
|
||||
for (int row = rowStart; row < productDistances.getSize(0); ++row) {
|
||||
for (int col = threadIdx.x; col < productDistances.getSize(1);
|
||||
col += blockDim.x) {
|
||||
if (!(bitset[col >> 3] & (0x1 << (col & 0x7)))) {
|
||||
distance[0] = Math<T>::add(centroidDistances[col],
|
||||
productDistances[row][col]);
|
||||
|
||||
if (Math<T>::lt(distance[0], threadMin[0].k)) {
|
||||
threadMin[0].k = distance[0];
|
||||
threadMin[0].v = col;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce within the block
|
||||
threadMin[0] =
|
||||
blockReduceAll<Pair<T, int>, Min<Pair<T, int> >, false, false>(
|
||||
threadMin[0], Min<Pair<T, int> >(), blockMin);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
outDistances[row][0] = threadMin[0].k;
|
||||
outIndices[row][0] = threadMin[0].v;
|
||||
}
|
||||
|
||||
// so we can use the shared memory again
|
||||
__syncthreads();
|
||||
|
||||
threadMin[0].k = Limits<T>::getMax();
|
||||
threadMin[0].v = -1;
|
||||
}
|
||||
} else {
|
||||
for (int col = threadIdx.x; col < productDistances.getSize(1);
|
||||
col += blockDim.x) {
|
||||
T centroidDistance = centroidDistances[col];
|
||||
|
||||
#pragma unroll
|
||||
for (int row = 0; row < kRowsPerBlock; ++row) {
|
||||
distance[row] = productDistances[rowStart + row][col];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int row = 0; row < kRowsPerBlock; ++row) {
|
||||
distance[row] = Math<T>::add(distance[row], centroidDistance);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int row = 0; row < kRowsPerBlock; ++row) {
|
||||
if (Math<T>::lt(distance[row], threadMin[row].k)) {
|
||||
threadMin[row].k = distance[row];
|
||||
threadMin[row].v = col;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce within the block
|
||||
blockReduceAll<kRowsPerBlock, Pair<T, int>, Min<Pair<T, int> >, false, false>(
|
||||
threadMin, Min<Pair<T, int> >(), blockMin);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
#pragma unroll
|
||||
for (int row = 0; row < kRowsPerBlock; ++row) {
|
||||
outDistances[rowStart + row][0] = threadMin[row].k;
|
||||
outIndices[rowStart + row][0] = threadMin[row].v;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// With bitset included
|
||||
// L2 + select kernel for k > 1, no re-use of ||c||^2
|
||||
template <typename T, int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
|
||||
__global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
|
||||
Tensor<T, 1, true> centroidDistances,
|
||||
Tensor<uint8_t, 1, true> bitset,
|
||||
Tensor<T, 2, true> outDistances,
|
||||
Tensor<int, 2, true> outIndices,
|
||||
int k, T initK) {
|
||||
// Each block handles a single row of the distances (results)
|
||||
constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
|
||||
|
||||
__shared__ T smemK[kNumWarps * NumWarpQ];
|
||||
__shared__ int smemV[kNumWarps * NumWarpQ];
|
||||
|
||||
BlockSelect<T, int, false, Comparator<T>,
|
||||
NumWarpQ, NumThreadQ, ThreadsPerBlock>
|
||||
heap(initK, -1, smemK, smemV, k);
|
||||
|
||||
int row = blockIdx.x;
|
||||
|
||||
// Whole warps must participate in the selection
|
||||
int limit = utils::roundDown(productDistances.getSize(1), kWarpSize);
|
||||
int i = threadIdx.x;
|
||||
|
||||
for (; i < limit; i += blockDim.x) {
|
||||
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) {
|
||||
T v = Math<T>::add(centroidDistances[i],
|
||||
productDistances[row][i]);
|
||||
heap.add(v, i);
|
||||
}
|
||||
}
|
||||
|
||||
if (i < productDistances.getSize(1)) {
|
||||
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) {
|
||||
T v = Math<T>::add(centroidDistances[i],
|
||||
productDistances[row][i]);
|
||||
heap.addThreadQ(v, i);
|
||||
}
|
||||
}
|
||||
|
||||
heap.reduce();
|
||||
for (int i = threadIdx.x; i < k; i += blockDim.x) {
|
||||
outDistances[row][i] = smemK[i];
|
||||
outIndices[row][i] = smemV[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
void runL2SelectMin(Tensor<T, 2, true>& productDistances,
|
||||
|
@ -338,14 +211,6 @@ void runL2SelectMin(Tensor<T, 2, true>& productDistances,
|
|||
auto grid = dim3(outDistances.getSize(0));
|
||||
|
||||
#define RUN_L2_SELECT(BLOCK, NUM_WARP_Q, NUM_THREAD_Q) \
|
||||
do { \
|
||||
l2SelectMinK<T, NUM_WARP_Q, NUM_THREAD_Q, BLOCK> \
|
||||
<<<grid, BLOCK, 0, stream>>>(productDistances, centroidDistances, \
|
||||
outDistances, outIndices, \
|
||||
k, Limits<T>::getMax()); \
|
||||
} while (0)
|
||||
|
||||
#define RUN_L2_SELECT_BITSET(BLOCK, NUM_WARP_Q, NUM_THREAD_Q) \
|
||||
do { \
|
||||
l2SelectMinK<T, NUM_WARP_Q, NUM_THREAD_Q, BLOCK> \
|
||||
<<<grid, BLOCK, 0, stream>>>(productDistances, centroidDistances, \
|
||||
|
@ -353,55 +218,27 @@ void runL2SelectMin(Tensor<T, 2, true>& productDistances,
|
|||
k, Limits<T>::getMax()); \
|
||||
} while (0)
|
||||
|
||||
if (bitset.getSize(0) == 0) {
|
||||
// block size 128 for everything <= 1024
|
||||
if (k <= 32) {
|
||||
RUN_L2_SELECT(128, 32, 2);
|
||||
} else if (k <= 64) {
|
||||
RUN_L2_SELECT(128, 64, 3);
|
||||
} else if (k <= 128) {
|
||||
RUN_L2_SELECT(128, 128, 3);
|
||||
} else if (k <= 256) {
|
||||
RUN_L2_SELECT(128, 256, 4);
|
||||
} else if (k <= 512) {
|
||||
RUN_L2_SELECT(128, 512, 8);
|
||||
} else if (k <= 1024) {
|
||||
RUN_L2_SELECT(128, 1024, 8);
|
||||
|
||||
#if GPU_MAX_SELECTION_K >= 2048
|
||||
} else if (k <= 2048) {
|
||||
// smaller block for less shared memory
|
||||
RUN_L2_SELECT(64, 2048, 8);
|
||||
#endif
|
||||
|
||||
} else {
|
||||
FAISS_ASSERT(false);
|
||||
}
|
||||
// block size 128 for everything <= 1024
|
||||
if (k <= 32) {
|
||||
RUN_L2_SELECT(128, 32, 2);
|
||||
} else if (k <= 64) {
|
||||
RUN_L2_SELECT(128, 64, 3);
|
||||
} else if (k <= 128) {
|
||||
RUN_L2_SELECT(128, 128, 3);
|
||||
} else if (k <= 256) {
|
||||
RUN_L2_SELECT(128, 256, 4);
|
||||
} else if (k <= 512) {
|
||||
RUN_L2_SELECT(128, 512, 8);
|
||||
} else if (k <= 1024) {
|
||||
RUN_L2_SELECT(128, 1024, 8);
|
||||
|
||||
#if GPU_MAX_SELECTION_K >= 2048
|
||||
} else if (k <= 2048) {
|
||||
// smaller block for less shared memory
|
||||
RUN_L2_SELECT(64, 2048, 8);
|
||||
#endif
|
||||
} else {
|
||||
// With bitset
|
||||
if (k <= 32) {
|
||||
RUN_L2_SELECT_BITSET(128, 32, 2);
|
||||
} else if (k <= 64) {
|
||||
RUN_L2_SELECT_BITSET(128, 64, 3);
|
||||
} else if (k <= 128) {
|
||||
RUN_L2_SELECT_BITSET(128, 128, 3);
|
||||
} else if (k <= 256) {
|
||||
RUN_L2_SELECT_BITSET(128, 256, 4);
|
||||
} else if (k <= 512) {
|
||||
RUN_L2_SELECT_BITSET(128, 512, 8);
|
||||
} else if (k <= 1024) {
|
||||
RUN_L2_SELECT_BITSET(128, 1024, 8);
|
||||
|
||||
#if GPU_MAX_SELECTION_K >= 2048
|
||||
} else if (k <= 2048) {
|
||||
// smaller block for less shared memory
|
||||
RUN_L2_SELECT_BITSET(64, 2048, 8);
|
||||
#endif
|
||||
|
||||
} else {
|
||||
FAISS_ASSERT(false);
|
||||
}
|
||||
FAISS_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -12,20 +12,6 @@
|
|||
|
||||
namespace faiss { namespace gpu {
|
||||
|
||||
void runL2SelectMin(Tensor<float, 2, true>& productDistances,
|
||||
Tensor<float, 1, true>& centroidDistances,
|
||||
Tensor<float, 2, true>& outDistances,
|
||||
Tensor<int, 2, true>& outIndices,
|
||||
int k,
|
||||
cudaStream_t stream);
|
||||
|
||||
void runL2SelectMin(Tensor<half, 2, true>& productDistances,
|
||||
Tensor<half, 1, true>& centroidDistances,
|
||||
Tensor<half, 2, true>& outDistances,
|
||||
Tensor<int, 2, true>& outIndices,
|
||||
int k,
|
||||
cudaStream_t stream);
|
||||
|
||||
void runL2SelectMin(Tensor<float, 2, true>& productDistances,
|
||||
Tensor<float, 1, true>& centroidDistances,
|
||||
Tensor<uint8_t, 1, true>& bitset,
|
||||
|
|
|
@ -142,17 +142,24 @@ __global__ void blockSelect(Tensor<K, 2, true> in,
|
|||
// Whole warps must participate in the selection
|
||||
int limit = utils::roundDown(in.getSize(1), kWarpSize);
|
||||
|
||||
bool bitsetIsEmpty = (bitset.getSize(0) == 0);
|
||||
|
||||
for (; i < limit; i += ThreadsPerBlock) {
|
||||
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) {
|
||||
if (bitsetIsEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
|
||||
heap.add(*inStart, (IndexType) i);
|
||||
inStart += ThreadsPerBlock;
|
||||
} else {
|
||||
heap.add(-1.0, (IndexType) i);
|
||||
}
|
||||
|
||||
inStart += ThreadsPerBlock;
|
||||
}
|
||||
|
||||
// Handle last remainder fraction of a warp of elements
|
||||
if (i < in.getSize(1)) {
|
||||
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) {
|
||||
if (bitsetIsEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
|
||||
heap.addThreadQ(*inStart, (IndexType) i);
|
||||
} else {
|
||||
heap.addThreadQ(-1.0, (IndexType) i);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -197,18 +204,25 @@ __global__ void blockSelectPair(Tensor<K, 2, true> inK,
|
|||
// Whole warps must participate in the selection
|
||||
int limit = utils::roundDown(inK.getSize(1), kWarpSize);
|
||||
|
||||
bool bitsetIsEmpty = (bitset.getSize(0) == 0);
|
||||
|
||||
for (; i < limit; i += ThreadsPerBlock) {
|
||||
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) {
|
||||
if (bitsetIsEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
|
||||
heap.add(*inKStart, *inVStart);
|
||||
inKStart += ThreadsPerBlock;
|
||||
inVStart += ThreadsPerBlock;
|
||||
} else {
|
||||
heap.add(-1.0, *inVStart);
|
||||
}
|
||||
|
||||
inKStart += ThreadsPerBlock;
|
||||
inVStart += ThreadsPerBlock;
|
||||
}
|
||||
|
||||
// Handle last remainder fraction of a warp of elements
|
||||
if (i < inK.getSize(1)) {
|
||||
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) {
|
||||
if (bitsetIsEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
|
||||
heap.addThreadQ(*inKStart, *inVStart);
|
||||
} else {
|
||||
heap.addThreadQ(-1.0, *inVStart);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -9,40 +9,61 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <fiu-control.h>
|
||||
#include <fiu-local.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexIDMAP.h"
|
||||
#include "knowhere/index/vector_index/IndexType.h"
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include <faiss/gpu/GpuCloner.h>
|
||||
#include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h"
|
||||
#include "knowhere/index/vector_index/helpers/Cloner.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
|
||||
#endif
|
||||
#include "Helper.h"
|
||||
#include "unittest/utils.h"
|
||||
|
||||
class IDMAPTest : public DataGen, public TestGpuIndexBase {
|
||||
using ::testing::Combine;
|
||||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
|
||||
class IDMAPTest : public DataGen, public TestWithParam<milvus::knowhere::IndexMode> {
|
||||
protected:
|
||||
void
|
||||
SetUp() override {
|
||||
TestGpuIndexBase::SetUp();
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
milvus::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM);
|
||||
#endif
|
||||
index_mode_ = GetParam();
|
||||
Init_with_default();
|
||||
index_ = std::make_shared<milvus::knowhere::IDMAP>();
|
||||
}
|
||||
|
||||
void
|
||||
TearDown() override {
|
||||
TestGpuIndexBase::TearDown();
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
milvus::knowhere::FaissGpuResourceMgr::GetInstance().Free();
|
||||
#endif
|
||||
}
|
||||
|
||||
protected:
|
||||
milvus::knowhere::IDMAPPtr index_ = nullptr;
|
||||
milvus::knowhere::IndexMode index_mode_;
|
||||
};
|
||||
|
||||
TEST_F(IDMAPTest, idmap_basic) {
|
||||
INSTANTIATE_TEST_CASE_P(IDMAPParameters, IDMAPTest,
|
||||
Values(
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
milvus::knowhere::IndexMode::MODE_GPU,
|
||||
#endif
|
||||
milvus::knowhere::IndexMode::MODE_CPU));
|
||||
|
||||
TEST_P(IDMAPTest, idmap_basic) {
|
||||
ASSERT_TRUE(!xb.empty());
|
||||
|
||||
milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim},
|
||||
|
@ -67,6 +88,13 @@ TEST_F(IDMAPTest, idmap_basic) {
|
|||
AssertAnns(result, nq, k);
|
||||
// PrintResult(result, nq, k);
|
||||
|
||||
if (index_mode_ == milvus::knowhere::IndexMode::MODE_GPU) {
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
// cpu to gpu
|
||||
index_ = std::dynamic_pointer_cast<milvus::knowhere::IDMAP>(index_->CopyCpuToGpu(DEVICEID, conf));
|
||||
#endif
|
||||
}
|
||||
|
||||
auto binaryset = index_->Serialize();
|
||||
auto new_index = std::make_shared<milvus::knowhere::IDMAP>();
|
||||
new_index->Load(binaryset);
|
||||
|
@ -96,7 +124,7 @@ TEST_F(IDMAPTest, idmap_basic) {
|
|||
AssertVec(result_bs_3, base_dataset, xid_dataset, 1, dim, CheckMode::CHECK_NOT_EQUAL);
|
||||
}
|
||||
|
||||
TEST_F(IDMAPTest, idmap_serialize) {
|
||||
TEST_P(IDMAPTest, idmap_serialize) {
|
||||
auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) {
|
||||
FileIOWriter writer(filename);
|
||||
writer(static_cast<void*>(bin->data.get()), bin->size);
|
||||
|
@ -113,6 +141,14 @@ TEST_F(IDMAPTest, idmap_serialize) {
|
|||
// serialize index
|
||||
index_->Train(base_dataset, conf);
|
||||
index_->Add(base_dataset, milvus::knowhere::Config());
|
||||
|
||||
if (index_mode_ == milvus::knowhere::IndexMode::MODE_GPU) {
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
// cpu to gpu
|
||||
index_ = std::dynamic_pointer_cast<milvus::knowhere::IDMAP>(index_->CopyCpuToGpu(DEVICEID, conf));
|
||||
#endif
|
||||
}
|
||||
|
||||
auto re_result = index_->Query(query_dataset, conf);
|
||||
AssertAnns(re_result, nq, k);
|
||||
// PrintResult(re_result, nq, k);
|
||||
|
@ -139,7 +175,7 @@ TEST_F(IDMAPTest, idmap_serialize) {
|
|||
}
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
TEST_F(IDMAPTest, copy_test) {
|
||||
TEST_P(IDMAPTest, copy_test) {
|
||||
ASSERT_TRUE(!xb.empty());
|
||||
|
||||
milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim},
|
||||
|
|
Loading…
Reference in New Issue