Add unit test for IDMAP (#1774)

* GPU IDMAP unit test

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* fix clang-format

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* fix build issue

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* small fix

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* fix case

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* fix back

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* Compact code

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* efficient code

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* Fix GPU search hang

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>
pull/1803/head
Xiaohai Xu 2020-03-30 14:08:56 +08:00 committed by GitHub
parent 4574f9d998
commit 8a731dea39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 114 additions and 290 deletions

View File

@ -105,7 +105,7 @@ GPUIDMAP::GetRawIds() {
void void
GPUIDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { GPUIDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
ResScope rs(res_, gpu_id_); ResScope rs(res_, gpu_id_);
index_->search(n, (float*)data, k, distances, labels); index_->search(n, (float*)data, k, distances, labels, bitset_);
} }
void void

View File

@ -310,7 +310,7 @@ void IndexIVF::search (idx_t n, const float *x, idx_t k, float *distances, idx_t
invlists->prefetch_lists (idx.get(), n * nprobe); invlists->prefetch_lists (idx.get(), n * nprobe);
search_preassigned (n, x, k, idx.get(), coarse_dis.get(), search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
distances, labels, false, nullptr,bitset); distances, labels, false, nullptr, bitset);
indexIVF_stats.search_time += getmillisecs() - t0; indexIVF_stats.search_time += getmillisecs() - t0;
} }

View File

@ -18,55 +18,6 @@ class GpuResources;
/// Calculates brute-force L2 distance between `vectors` and /// Calculates brute-force L2 distance between `vectors` and
/// `queries`, returning the k closest results seen /// `queries`, returning the k closest results seen
void runL2Distance(GpuResources* resources,
Tensor<float, 2, true>& vectors,
bool vectorsRowMajor,
// can be optionally pre-computed; nullptr if we
// have to compute it upon the call
Tensor<float, 1, true>* vectorNorms,
Tensor<float, 2, true>& queries,
bool queriesRowMajor,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
// Do we care about `outDistances`? If not, we can
// take shortcuts.
bool ignoreOutDistances = false);
/// Calculates brute-force inner product distance between `vectors`
/// and `queries`, returning the k closest results seen
void runIPDistance(GpuResources* resources,
Tensor<float, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<float, 2, true>& queries,
bool queriesRowMajor,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices);
void runIPDistance(GpuResources* resources,
Tensor<half, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<half, 2, true>& queries,
bool queriesRowMajor,
int k,
Tensor<half, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
bool useHgemm);
void runL2Distance(GpuResources* resources,
Tensor<half, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<half, 1, true>* vectorNorms,
Tensor<half, 2, true>& queries,
bool queriesRowMajor,
int k,
Tensor<half, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
bool useHgemm,
bool ignoreOutDistances = false);
// Bitset added
void runL2Distance(GpuResources* resources, void runL2Distance(GpuResources* resources,
Tensor<float, 2, true>& vectors, Tensor<float, 2, true>& vectors,
bool vectorsRowMajor, bool vectorsRowMajor,

View File

@ -24,6 +24,7 @@ namespace faiss { namespace gpu {
template <typename T, int kRowsPerBlock, int kBlockSize> template <typename T, int kRowsPerBlock, int kBlockSize>
__global__ void l2SelectMin1(Tensor<T, 2, true> productDistances, __global__ void l2SelectMin1(Tensor<T, 2, true> productDistances,
Tensor<T, 1, true> centroidDistances, Tensor<T, 1, true> centroidDistances,
Tensor<uint8_t, 1, true> bitset,
Tensor<T, 2, true> outDistances, Tensor<T, 2, true> outDistances,
Tensor<int, 2, true> outIndices) { Tensor<int, 2, true> outIndices) {
// Each block handles kRowsPerBlock rows of the distances (results) // Each block handles kRowsPerBlock rows of the distances (results)
@ -44,6 +45,8 @@ __global__ void l2SelectMin1(Tensor<T, 2, true> productDistances,
// FIXME: if we have exact multiples, don't need this // FIXME: if we have exact multiples, don't need this
bool endRow = (blockIdx.x == gridDim.x - 1); bool endRow = (blockIdx.x == gridDim.x - 1);
bool bitsetIsEmpty = (bitset.getSize(0) == 0);
if (endRow) { if (endRow) {
if (productDistances.getSize(0) % kRowsPerBlock == 0) { if (productDistances.getSize(0) % kRowsPerBlock == 0) {
endRow = false; endRow = false;
@ -54,8 +57,12 @@ __global__ void l2SelectMin1(Tensor<T, 2, true> productDistances,
for (int row = rowStart; row < productDistances.getSize(0); ++row) { for (int row = rowStart; row < productDistances.getSize(0); ++row) {
for (int col = threadIdx.x; col < productDistances.getSize(1); for (int col = threadIdx.x; col < productDistances.getSize(1);
col += blockDim.x) { col += blockDim.x) {
distance[0] = Math<T>::add(centroidDistances[col], if (bitsetIsEmpty || (!(bitset[col >> 3] & (0x1 << (col & 0x7))))) {
productDistances[row][col]); distance[0] = Math<T>::add(centroidDistances[col],
productDistances[row][col]);
} else {
distance[0] = (T)(1.0 / 0.0);
}
if (Math<T>::lt(distance[0], threadMin[0].k)) { if (Math<T>::lt(distance[0], threadMin[0].k)) {
threadMin[0].k = distance[0]; threadMin[0].k = distance[0];
@ -117,10 +124,12 @@ __global__ void l2SelectMin1(Tensor<T, 2, true> productDistances,
} }
} }
// With bitset included
// L2 + select kernel for k > 1, no re-use of ||c||^2 // L2 + select kernel for k > 1, no re-use of ||c||^2
template <typename T, int NumWarpQ, int NumThreadQ, int ThreadsPerBlock> template <typename T, int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void l2SelectMinK(Tensor<T, 2, true> productDistances, __global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
Tensor<T, 1, true> centroidDistances, Tensor<T, 1, true> centroidDistances,
Tensor<uint8_t, 1, true> bitset,
Tensor<T, 2, true> outDistances, Tensor<T, 2, true> outDistances,
Tensor<int, 2, true> outIndices, Tensor<int, 2, true> outIndices,
int k, T initK) { int k, T initK) {
@ -140,15 +149,28 @@ __global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
int limit = utils::roundDown(productDistances.getSize(1), kWarpSize); int limit = utils::roundDown(productDistances.getSize(1), kWarpSize);
int i = threadIdx.x; int i = threadIdx.x;
bool bitsetIsEmpty = (bitset.getSize(0) == 0);
T v;
for (; i < limit; i += blockDim.x) { for (; i < limit; i += blockDim.x) {
T v = Math<T>::add(centroidDistances[i], if (bitsetIsEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
productDistances[row][i]); v = Math<T>::add(centroidDistances[i],
productDistances[row][i]);
} else {
v = (T)(1.0 / 0.0);
}
heap.add(v, i); heap.add(v, i);
} }
if (i < productDistances.getSize(1)) { if (i < productDistances.getSize(1)) {
T v = Math<T>::add(centroidDistances[i], if (bitsetIsEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
productDistances[row][i]); v = Math<T>::add(centroidDistances[i],
productDistances[row][i]);
} else {
v = (T)(1.0 / 0.0);
}
heap.addThreadQ(v, i); heap.addThreadQ(v, i);
} }
@ -159,155 +181,6 @@ __global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
} }
} }
// With bitset included
// L2 + select kernel for k == 1, implements re-use of ||c||^2
template <typename T, int kRowsPerBlock, int kBlockSize>
__global__ void l2SelectMin1(Tensor<T, 2, true> productDistances,
Tensor<T, 1, true> centroidDistances,
Tensor<uint8_t, 1, true> bitset,
Tensor<T, 2, true> outDistances,
Tensor<int, 2, true> outIndices) {
// Each block handles kRowsPerBlock rows of the distances (results)
Pair<T, int> threadMin[kRowsPerBlock];
__shared__ Pair<T, int> blockMin[kRowsPerBlock * (kBlockSize / kWarpSize)];
T distance[kRowsPerBlock];
#pragma unroll
for (int i = 0; i < kRowsPerBlock; ++i) {
threadMin[i].k = Limits<T>::getMax();
threadMin[i].v = -1;
}
// blockIdx.x: which chunk of rows we are responsible for updating
int rowStart = blockIdx.x * kRowsPerBlock;
// FIXME: if we have exact multiples, don't need this
bool endRow = (blockIdx.x == gridDim.x - 1);
if (endRow) {
if (productDistances.getSize(0) % kRowsPerBlock == 0) {
endRow = false;
}
}
if (endRow) {
for (int row = rowStart; row < productDistances.getSize(0); ++row) {
for (int col = threadIdx.x; col < productDistances.getSize(1);
col += blockDim.x) {
if (!(bitset[col >> 3] & (0x1 << (col & 0x7)))) {
distance[0] = Math<T>::add(centroidDistances[col],
productDistances[row][col]);
if (Math<T>::lt(distance[0], threadMin[0].k)) {
threadMin[0].k = distance[0];
threadMin[0].v = col;
}
}
}
// Reduce within the block
threadMin[0] =
blockReduceAll<Pair<T, int>, Min<Pair<T, int> >, false, false>(
threadMin[0], Min<Pair<T, int> >(), blockMin);
if (threadIdx.x == 0) {
outDistances[row][0] = threadMin[0].k;
outIndices[row][0] = threadMin[0].v;
}
// so we can use the shared memory again
__syncthreads();
threadMin[0].k = Limits<T>::getMax();
threadMin[0].v = -1;
}
} else {
for (int col = threadIdx.x; col < productDistances.getSize(1);
col += blockDim.x) {
T centroidDistance = centroidDistances[col];
#pragma unroll
for (int row = 0; row < kRowsPerBlock; ++row) {
distance[row] = productDistances[rowStart + row][col];
}
#pragma unroll
for (int row = 0; row < kRowsPerBlock; ++row) {
distance[row] = Math<T>::add(distance[row], centroidDistance);
}
#pragma unroll
for (int row = 0; row < kRowsPerBlock; ++row) {
if (Math<T>::lt(distance[row], threadMin[row].k)) {
threadMin[row].k = distance[row];
threadMin[row].v = col;
}
}
}
// Reduce within the block
blockReduceAll<kRowsPerBlock, Pair<T, int>, Min<Pair<T, int> >, false, false>(
threadMin, Min<Pair<T, int> >(), blockMin);
if (threadIdx.x == 0) {
#pragma unroll
for (int row = 0; row < kRowsPerBlock; ++row) {
outDistances[rowStart + row][0] = threadMin[row].k;
outIndices[rowStart + row][0] = threadMin[row].v;
}
}
}
}
// With bitset included
// L2 + select kernel for k > 1, no re-use of ||c||^2
template <typename T, int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
Tensor<T, 1, true> centroidDistances,
Tensor<uint8_t, 1, true> bitset,
Tensor<T, 2, true> outDistances,
Tensor<int, 2, true> outIndices,
int k, T initK) {
// Each block handles a single row of the distances (results)
constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
__shared__ T smemK[kNumWarps * NumWarpQ];
__shared__ int smemV[kNumWarps * NumWarpQ];
BlockSelect<T, int, false, Comparator<T>,
NumWarpQ, NumThreadQ, ThreadsPerBlock>
heap(initK, -1, smemK, smemV, k);
int row = blockIdx.x;
// Whole warps must participate in the selection
int limit = utils::roundDown(productDistances.getSize(1), kWarpSize);
int i = threadIdx.x;
for (; i < limit; i += blockDim.x) {
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) {
T v = Math<T>::add(centroidDistances[i],
productDistances[row][i]);
heap.add(v, i);
}
}
if (i < productDistances.getSize(1)) {
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) {
T v = Math<T>::add(centroidDistances[i],
productDistances[row][i]);
heap.addThreadQ(v, i);
}
}
heap.reduce();
for (int i = threadIdx.x; i < k; i += blockDim.x) {
outDistances[row][i] = smemK[i];
outIndices[row][i] = smemV[i];
}
}
template <typename T> template <typename T>
void runL2SelectMin(Tensor<T, 2, true>& productDistances, void runL2SelectMin(Tensor<T, 2, true>& productDistances,
@ -338,14 +211,6 @@ void runL2SelectMin(Tensor<T, 2, true>& productDistances,
auto grid = dim3(outDistances.getSize(0)); auto grid = dim3(outDistances.getSize(0));
#define RUN_L2_SELECT(BLOCK, NUM_WARP_Q, NUM_THREAD_Q) \ #define RUN_L2_SELECT(BLOCK, NUM_WARP_Q, NUM_THREAD_Q) \
do { \
l2SelectMinK<T, NUM_WARP_Q, NUM_THREAD_Q, BLOCK> \
<<<grid, BLOCK, 0, stream>>>(productDistances, centroidDistances, \
outDistances, outIndices, \
k, Limits<T>::getMax()); \
} while (0)
#define RUN_L2_SELECT_BITSET(BLOCK, NUM_WARP_Q, NUM_THREAD_Q) \
do { \ do { \
l2SelectMinK<T, NUM_WARP_Q, NUM_THREAD_Q, BLOCK> \ l2SelectMinK<T, NUM_WARP_Q, NUM_THREAD_Q, BLOCK> \
<<<grid, BLOCK, 0, stream>>>(productDistances, centroidDistances, \ <<<grid, BLOCK, 0, stream>>>(productDistances, centroidDistances, \
@ -353,55 +218,27 @@ void runL2SelectMin(Tensor<T, 2, true>& productDistances,
k, Limits<T>::getMax()); \ k, Limits<T>::getMax()); \
} while (0) } while (0)
if (bitset.getSize(0) == 0) { // block size 128 for everything <= 1024
// block size 128 for everything <= 1024 if (k <= 32) {
if (k <= 32) { RUN_L2_SELECT(128, 32, 2);
RUN_L2_SELECT(128, 32, 2); } else if (k <= 64) {
} else if (k <= 64) { RUN_L2_SELECT(128, 64, 3);
RUN_L2_SELECT(128, 64, 3); } else if (k <= 128) {
} else if (k <= 128) { RUN_L2_SELECT(128, 128, 3);
RUN_L2_SELECT(128, 128, 3); } else if (k <= 256) {
} else if (k <= 256) { RUN_L2_SELECT(128, 256, 4);
RUN_L2_SELECT(128, 256, 4); } else if (k <= 512) {
} else if (k <= 512) { RUN_L2_SELECT(128, 512, 8);
RUN_L2_SELECT(128, 512, 8); } else if (k <= 1024) {
} else if (k <= 1024) { RUN_L2_SELECT(128, 1024, 8);
RUN_L2_SELECT(128, 1024, 8);
#if GPU_MAX_SELECTION_K >= 2048
} else if (k <= 2048) {
// smaller block for less shared memory
RUN_L2_SELECT(64, 2048, 8);
#endif
} else {
FAISS_ASSERT(false);
}
#if GPU_MAX_SELECTION_K >= 2048
} else if (k <= 2048) {
// smaller block for less shared memory
RUN_L2_SELECT(64, 2048, 8);
#endif
} else { } else {
// With bitset FAISS_ASSERT(false);
if (k <= 32) {
RUN_L2_SELECT_BITSET(128, 32, 2);
} else if (k <= 64) {
RUN_L2_SELECT_BITSET(128, 64, 3);
} else if (k <= 128) {
RUN_L2_SELECT_BITSET(128, 128, 3);
} else if (k <= 256) {
RUN_L2_SELECT_BITSET(128, 256, 4);
} else if (k <= 512) {
RUN_L2_SELECT_BITSET(128, 512, 8);
} else if (k <= 1024) {
RUN_L2_SELECT_BITSET(128, 1024, 8);
#if GPU_MAX_SELECTION_K >= 2048
} else if (k <= 2048) {
// smaller block for less shared memory
RUN_L2_SELECT_BITSET(64, 2048, 8);
#endif
} else {
FAISS_ASSERT(false);
}
} }
} }

View File

@ -12,20 +12,6 @@
namespace faiss { namespace gpu { namespace faiss { namespace gpu {
void runL2SelectMin(Tensor<float, 2, true>& productDistances,
Tensor<float, 1, true>& centroidDistances,
Tensor<float, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
int k,
cudaStream_t stream);
void runL2SelectMin(Tensor<half, 2, true>& productDistances,
Tensor<half, 1, true>& centroidDistances,
Tensor<half, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
int k,
cudaStream_t stream);
void runL2SelectMin(Tensor<float, 2, true>& productDistances, void runL2SelectMin(Tensor<float, 2, true>& productDistances,
Tensor<float, 1, true>& centroidDistances, Tensor<float, 1, true>& centroidDistances,
Tensor<uint8_t, 1, true>& bitset, Tensor<uint8_t, 1, true>& bitset,

View File

@ -142,17 +142,24 @@ __global__ void blockSelect(Tensor<K, 2, true> in,
// Whole warps must participate in the selection // Whole warps must participate in the selection
int limit = utils::roundDown(in.getSize(1), kWarpSize); int limit = utils::roundDown(in.getSize(1), kWarpSize);
bool bitsetIsEmpty = (bitset.getSize(0) == 0);
for (; i < limit; i += ThreadsPerBlock) { for (; i < limit; i += ThreadsPerBlock) {
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) { if (bitsetIsEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
heap.add(*inStart, (IndexType) i); heap.add(*inStart, (IndexType) i);
inStart += ThreadsPerBlock; } else {
heap.add(-1.0, (IndexType) i);
} }
inStart += ThreadsPerBlock;
} }
// Handle last remainder fraction of a warp of elements // Handle last remainder fraction of a warp of elements
if (i < in.getSize(1)) { if (i < in.getSize(1)) {
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) { if (bitsetIsEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
heap.addThreadQ(*inStart, (IndexType) i); heap.addThreadQ(*inStart, (IndexType) i);
} else {
heap.addThreadQ(-1.0, (IndexType) i);
} }
} }
@ -197,18 +204,25 @@ __global__ void blockSelectPair(Tensor<K, 2, true> inK,
// Whole warps must participate in the selection // Whole warps must participate in the selection
int limit = utils::roundDown(inK.getSize(1), kWarpSize); int limit = utils::roundDown(inK.getSize(1), kWarpSize);
bool bitsetIsEmpty = (bitset.getSize(0) == 0);
for (; i < limit; i += ThreadsPerBlock) { for (; i < limit; i += ThreadsPerBlock) {
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) { if (bitsetIsEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
heap.add(*inKStart, *inVStart); heap.add(*inKStart, *inVStart);
inKStart += ThreadsPerBlock; } else {
inVStart += ThreadsPerBlock; heap.add(-1.0, *inVStart);
} }
inKStart += ThreadsPerBlock;
inVStart += ThreadsPerBlock;
} }
// Handle last remainder fraction of a warp of elements // Handle last remainder fraction of a warp of elements
if (i < inK.getSize(1)) { if (i < inK.getSize(1)) {
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) { if (bitsetIsEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
heap.addThreadQ(*inKStart, *inVStart); heap.addThreadQ(*inKStart, *inVStart);
} else {
heap.addThreadQ(-1.0, *inVStart);
} }
} }

View File

@ -9,40 +9,61 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License. // or implied. See the License for the specific language governing permissions and limitations under the License.
#include <gtest/gtest.h>
#include <fiu-control.h> #include <fiu-control.h>
#include <fiu-local.h> #include <fiu-local.h>
#include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <thread>
#include "knowhere/common/Exception.h" #include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexIDMAP.h" #include "knowhere/index/vector_index/IndexIDMAP.h"
#include "knowhere/index/vector_index/IndexType.h"
#ifdef MILVUS_GPU_VERSION #ifdef MILVUS_GPU_VERSION
#include <faiss/gpu/GpuCloner.h>
#include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h" #include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h"
#include "knowhere/index/vector_index/helpers/Cloner.h" #include "knowhere/index/vector_index/helpers/Cloner.h"
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
#endif #endif
#include "Helper.h" #include "Helper.h"
#include "unittest/utils.h" #include "unittest/utils.h"
class IDMAPTest : public DataGen, public TestGpuIndexBase { using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
class IDMAPTest : public DataGen, public TestWithParam<milvus::knowhere::IndexMode> {
protected: protected:
void void
SetUp() override { SetUp() override {
TestGpuIndexBase::SetUp(); #ifdef MILVUS_GPU_VERSION
milvus::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM);
#endif
index_mode_ = GetParam();
Init_with_default(); Init_with_default();
index_ = std::make_shared<milvus::knowhere::IDMAP>(); index_ = std::make_shared<milvus::knowhere::IDMAP>();
} }
void void
TearDown() override { TearDown() override {
TestGpuIndexBase::TearDown(); #ifdef MILVUS_GPU_VERSION
milvus::knowhere::FaissGpuResourceMgr::GetInstance().Free();
#endif
} }
protected: protected:
milvus::knowhere::IDMAPPtr index_ = nullptr; milvus::knowhere::IDMAPPtr index_ = nullptr;
milvus::knowhere::IndexMode index_mode_;
}; };
TEST_F(IDMAPTest, idmap_basic) { INSTANTIATE_TEST_CASE_P(IDMAPParameters, IDMAPTest,
Values(
#ifdef MILVUS_GPU_VERSION
milvus::knowhere::IndexMode::MODE_GPU,
#endif
milvus::knowhere::IndexMode::MODE_CPU));
TEST_P(IDMAPTest, idmap_basic) {
ASSERT_TRUE(!xb.empty()); ASSERT_TRUE(!xb.empty());
milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim}, milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim},
@ -67,6 +88,13 @@ TEST_F(IDMAPTest, idmap_basic) {
AssertAnns(result, nq, k); AssertAnns(result, nq, k);
// PrintResult(result, nq, k); // PrintResult(result, nq, k);
if (index_mode_ == milvus::knowhere::IndexMode::MODE_GPU) {
#ifdef MILVUS_GPU_VERSION
// cpu to gpu
index_ = std::dynamic_pointer_cast<milvus::knowhere::IDMAP>(index_->CopyCpuToGpu(DEVICEID, conf));
#endif
}
auto binaryset = index_->Serialize(); auto binaryset = index_->Serialize();
auto new_index = std::make_shared<milvus::knowhere::IDMAP>(); auto new_index = std::make_shared<milvus::knowhere::IDMAP>();
new_index->Load(binaryset); new_index->Load(binaryset);
@ -96,7 +124,7 @@ TEST_F(IDMAPTest, idmap_basic) {
AssertVec(result_bs_3, base_dataset, xid_dataset, 1, dim, CheckMode::CHECK_NOT_EQUAL); AssertVec(result_bs_3, base_dataset, xid_dataset, 1, dim, CheckMode::CHECK_NOT_EQUAL);
} }
TEST_F(IDMAPTest, idmap_serialize) { TEST_P(IDMAPTest, idmap_serialize) {
auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) {
FileIOWriter writer(filename); FileIOWriter writer(filename);
writer(static_cast<void*>(bin->data.get()), bin->size); writer(static_cast<void*>(bin->data.get()), bin->size);
@ -113,6 +141,14 @@ TEST_F(IDMAPTest, idmap_serialize) {
// serialize index // serialize index
index_->Train(base_dataset, conf); index_->Train(base_dataset, conf);
index_->Add(base_dataset, milvus::knowhere::Config()); index_->Add(base_dataset, milvus::knowhere::Config());
if (index_mode_ == milvus::knowhere::IndexMode::MODE_GPU) {
#ifdef MILVUS_GPU_VERSION
// cpu to gpu
index_ = std::dynamic_pointer_cast<milvus::knowhere::IDMAP>(index_->CopyCpuToGpu(DEVICEID, conf));
#endif
}
auto re_result = index_->Query(query_dataset, conf); auto re_result = index_->Query(query_dataset, conf);
AssertAnns(re_result, nq, k); AssertAnns(re_result, nq, k);
// PrintResult(re_result, nq, k); // PrintResult(re_result, nq, k);
@ -139,7 +175,7 @@ TEST_F(IDMAPTest, idmap_serialize) {
} }
#ifdef MILVUS_GPU_VERSION #ifdef MILVUS_GPU_VERSION
TEST_F(IDMAPTest, copy_test) { TEST_P(IDMAPTest, copy_test) {
ASSERT_TRUE(!xb.empty()); ASSERT_TRUE(!xb.empty());
milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim}, milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim},