#1654 GPU Index Flat Delete (#1736)

* Add Flat Index Delete

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* Fix log

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* Fix bitset

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* Fix reference

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* fix bug

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* fix bug

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* fix bug

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>
Branch: pull/1769/head
Author: Xiaohai Xu <xiaohai.xu@zilliz.com>, 2020-03-26 18:36:46 +08:00 (committed by GitHub)
Parent: 65ffaedae3
Commit: f93b464172
29 changed files with 565 additions and 65 deletions
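Before the per-file hunks, a minimal caller-side sketch of what this change enables (illustrative, not part of the diff): searching a GPU flat index while excluding logically deleted vectors. It assumes the fork's ConcurrentBitset (one bit per stored vector, set() marks an id as deleted) and the ConcurrentBitsetPtr parameter on search() that the hunks below thread through; the function and the deleted id are made up.

    #include <faiss/gpu/GpuIndexFlat.h>
    #include <faiss/gpu/StandardGpuResources.h>
    #include <faiss/utils/ConcurrentBitset.h>
    #include <memory>
    #include <vector>

    // Hypothetical caller, not part of this commit.
    void searchWithDeletions(const float* base, int nb,
                             const float* queries, int nq,
                             int d, int k) {
      faiss::gpu::StandardGpuResources res;
      faiss::gpu::GpuIndexFlatL2 index(&res, d);
      index.add(nb, base);

      // One bit per stored vector; a set bit means "treat this id as deleted".
      faiss::ConcurrentBitsetPtr bitset =
          std::make_shared<faiss::ConcurrentBitset>(nb);
      bitset->set(0);  // logically delete vector 0

      std::vector<float> distances((size_t) nq * k);
      std::vector<faiss::Index::idx_t> labels((size_t) nq * k);
      index.search(nq, queries, k, distances.data(), labels.data(), bitset);
    }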

View File

@ -64,6 +64,9 @@ void bruteForceKnn(GpuResources* resources,
// temporary memory for it
DeviceTensor<int, 2, true> tOutIntIndices(mem, {numQueries, k}, stream);
// Empty bitset
auto bitsetDevice = toDevice<uint8_t, 1>(resources, device, nullptr, stream, {0});
// Do the work
if (metric == faiss::MetricType::METRIC_L2) {
runL2Distance(resources,
@ -72,6 +75,7 @@ void bruteForceKnn(GpuResources* resources,
nullptr, // compute norms in temp memory
tQueries,
queriesRowMajor,
bitsetDevice,
k,
tOutDistances,
tOutIntIndices);
@ -81,6 +85,7 @@ void bruteForceKnn(GpuResources* resources,
vectorsRowMajor,
tQueries,
queriesRowMajor,
bitsetDevice,
k,
tOutDistances,
tOutIntIndices);

View File

@ -235,7 +235,8 @@ GpuIndex::search(Index::idx_t n,
if (dataSize >= minPagedSize_) {
searchFromCpuPaged_(n, x, k,
outDistances.data(),
outLabels.data());
outLabels.data(),
bitset);
usePaged = true;
}
}
@ -243,7 +244,8 @@ GpuIndex::search(Index::idx_t n,
if (!usePaged) {
searchNonPaged_(n, x, k,
outDistances.data(),
outLabels.data());
outLabels.data(),
bitset);
}
// Copy back if necessary
@ -256,7 +258,8 @@ GpuIndex::searchNonPaged_(int n,
const float* x,
int k,
float* outDistancesData,
Index::idx_t* outIndicesData) const {
Index::idx_t* outIndicesData,
ConcurrentBitsetPtr bitset) const {
auto stream = resources_->getDefaultStream(device_);
// Make sure arguments are on the device we desire; use temporary
@ -267,7 +270,7 @@ GpuIndex::searchNonPaged_(int n,
stream,
{n, (int) this->d});
searchImpl_(n, vecs.data(), k, outDistancesData, outIndicesData);
searchImpl_(n, vecs.data(), k, outDistancesData, outIndicesData, bitset);
}
void
@ -275,7 +278,8 @@ GpuIndex::searchFromCpuPaged_(int n,
const float* x,
int k,
float* outDistancesData,
Index::idx_t* outIndicesData) const {
Index::idx_t* outIndicesData,
ConcurrentBitsetPtr bitset) const {
Tensor<float, 2, true> outDistances(outDistancesData, {n, k});
Tensor<Index::idx_t, 2, true> outIndices(outIndicesData, {n, k});
@ -300,7 +304,8 @@ GpuIndex::searchFromCpuPaged_(int n,
x + (size_t) cur * this->d,
k,
outDistancesSlice.data(),
outIndicesSlice.data());
outIndicesSlice.data(),
bitset);
}
return;
@ -411,7 +416,8 @@ GpuIndex::searchFromCpuPaged_(int n,
bufGpus[cur3BufIndex]->data(),
k,
outDistancesSlice.data(),
outIndicesSlice.data());
outIndicesSlice.data(),
bitset);
// Create completion event
eventGpuExecuteDone[cur3BufIndex] =

View File

@ -103,7 +103,8 @@ class GpuIndex : public faiss::Index {
const float* x,
int k,
float* distances,
Index::idx_t* labels) const = 0;
Index::idx_t* labels,
ConcurrentBitsetPtr bitset = nullptr) const = 0;
private:
/// Handles paged adds if the add set is too large, passes to
@ -122,7 +123,8 @@ private:
const float* x,
int k,
float* outDistancesData,
Index::idx_t* outIndicesData) const;
Index::idx_t* outIndicesData,
ConcurrentBitsetPtr bitset = nullptr) const;
/// Calls searchImpl_ for a single page of GPU-resident data,
/// handling paging of the data and copies from the CPU
@ -130,7 +132,8 @@ private:
const float* x,
int k,
float* outDistancesData,
Index::idx_t* outIndicesData) const;
Index::idx_t* outIndicesData,
ConcurrentBitsetPtr bitset = nullptr) const;
protected:
/// Manages streams, cuBLAS handles and scratch memory for devices

View File

@ -203,7 +203,8 @@ GpuIndexFlat::searchImpl_(int n,
const float* x,
int k,
float* distances,
Index::idx_t* labels) const {
Index::idx_t* labels,
ConcurrentBitsetPtr bitset) const {
auto stream = resources_->getDefaultStream(device_);
// Input and output data are already resident on the GPU
@ -215,7 +216,17 @@ GpuIndexFlat::searchImpl_(int n,
DeviceTensor<int, 2, true> outIntLabels(
resources_->getMemoryManagerCurrentDevice(), {n, k}, stream);
data_->query(queries, k, outDistances, outIntLabels, true);
// Copy bitset to GPU
if (!bitset) {
auto bitsetDevice = toDevice<uint8_t, 1>(resources_, device_, nullptr, stream, {0});
data_->query(queries, bitsetDevice, k, outDistances, outIntLabels, true);
} else {
auto bitsetData = bitset->bitset();
auto bitsetDevice = toDevice<uint8_t, 1>(resources_, device_,
const_cast<uint8_t*>(bitsetData), stream,
{(int) bitset->size()});
data_->query(queries, bitsetDevice, k, outDistances, outIntLabels, true);
}
// Convert int to idx_t
convertTensor<int, faiss::Index::idx_t, 2>(stream,

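The if/else above is the one place in this commit where a caller-supplied ConcurrentBitset is actually copied to the device; everywhere else a size-0 tensor stands in for "no filtering". A minimal sketch of how the two branches could be folded into a single helper, assuming the toDevice/DeviceTensor utilities used in this hunk; makeDeviceBitset is a hypothetical name, not part of the commit.

    #include <faiss/gpu/GpuResources.h>
    #include <faiss/gpu/utils/CopyUtils.cuh>
    #include <faiss/gpu/utils/DeviceTensor.cuh>
    #include <faiss/utils/ConcurrentBitset.h>

    namespace faiss { namespace gpu {

    // Hypothetical helper (not in the commit): always yields a device-resident
    // mask; a size-0 tensor means the unfiltered kernels are used downstream.
    DeviceTensor<uint8_t, 1, true>
    makeDeviceBitset(GpuResources* resources, int device,
                     const ConcurrentBitsetPtr& bitset, cudaStream_t stream) {
      if (!bitset) {
        // No deletions: empty tensor, nothing is copied
        return toDevice<uint8_t, 1>(resources, device, nullptr, stream, {0});
      }
      // Copy the packed bytes of the ConcurrentBitset to the desired device
      return toDevice<uint8_t, 1>(resources, device,
                                  const_cast<uint8_t*>(bitset->bitset()),
                                  stream, {(int) bitset->size()});
    }

    } } // namespace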
View File

@ -126,7 +126,8 @@ class GpuIndexFlat : public GpuIndex {
const float* x,
int k,
float* distances,
faiss::Index::idx_t* labels) const override;
faiss::Index::idx_t* labels,
ConcurrentBitsetPtr bitset = nullptr) const override;
private:
/// Checks user settings for consistency

View File

@ -207,14 +207,18 @@ GpuIndexIVFFlat::addImpl_(int n,
FAISS_ASSERT(index_);
FAISS_ASSERT(n > 0);
auto stream = resources_->getDefaultStream(device_);
// Data is already resident on the GPU
Tensor<float, 2, true> data(const_cast<float*>(x), {n, (int) this->d});
auto bitset = toDevice<uint8_t, 1>(resources_, device_, nullptr, stream, {0});
static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch");
Tensor<long, 1, true> labels(const_cast<long*>(xids), {n});
// Not all vectors may be able to be added (some may contain NaNs etc)
index_->classifyAndAddVectors(data, labels);
index_->classifyAndAddVectors(data, labels, bitset);
// but keep the ntotal based on the total number of vectors that we attempted
// to add
@ -226,11 +230,14 @@ GpuIndexIVFFlat::searchImpl_(int n,
const float* x,
int k,
float* distances,
Index::idx_t* labels) const {
Index::idx_t* labels,
ConcurrentBitsetPtr bitset) const {
// Device is already set in GpuIndex::search
FAISS_ASSERT(index_);
FAISS_ASSERT(n > 0);
auto stream = resources_->getDefaultStream(device_);
// Data is already resident on the GPU
Tensor<float, 2, true> queries(const_cast<float*>(x), {n, (int) this->d});
Tensor<float, 2, true> outDistances(distances, {n, k});
@ -238,7 +245,9 @@ GpuIndexIVFFlat::searchImpl_(int n,
static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch");
Tensor<long, 2, true> outLabels(const_cast<long*>(labels), {n, k});
index_->query(queries, nprobe, k, outDistances, outLabels);
auto bitsetDevice = toDevice<uint8_t, 1>(resources_, device_, nullptr, stream, {0});
index_->query(queries, bitsetDevice, nprobe, k, outDistances, outLabels);
}

View File

@ -70,7 +70,8 @@ class GpuIndexIVFFlat : public GpuIndexIVF {
const float* x,
int k,
float* distances,
Index::idx_t* labels) const override;
Index::idx_t* labels,
ConcurrentBitsetPtr bitset = nullptr) const override;
private:
GpuIndexIVFFlatConfig ivfFlatConfig_;

View File

@ -330,14 +330,18 @@ GpuIndexIVFPQ::addImpl_(int n,
FAISS_ASSERT(index_);
FAISS_ASSERT(n > 0);
auto stream = resources_->getDefaultStream(device_);
// Data is already resident on the GPU
Tensor<float, 2, true> data(const_cast<float*>(x), {n, (int) this->d});
auto bitset = toDevice<uint8_t, 1>(resources_, device_, nullptr, stream, {0});
static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch");
Tensor<long, 1, true> labels(const_cast<long*>(xids), {n});
// Not all vectors may be able to be added (some may contain NaNs etc)
index_->classifyAndAddVectors(data, labels);
index_->classifyAndAddVectors(data, labels, bitset);
// but keep the ntotal based on the total number of vectors that we attempted
// to add
@ -349,11 +353,14 @@ GpuIndexIVFPQ::searchImpl_(int n,
const float* x,
int k,
float* distances,
Index::idx_t* labels) const {
Index::idx_t* labels,
ConcurrentBitsetPtr bitset) const {
// Device is already set in GpuIndex::search
FAISS_ASSERT(index_);
FAISS_ASSERT(n > 0);
auto stream = resources_->getDefaultStream(device_);
// Data is already resident on the GPU
Tensor<float, 2, true> queries(const_cast<float*>(x), {n, (int) this->d});
Tensor<float, 2, true> outDistances(distances, {n, k});
@ -361,7 +368,9 @@ GpuIndexIVFPQ::searchImpl_(int n,
static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch");
Tensor<long, 2, true> outLabels(const_cast<long*>(labels), {n, k});
index_->query(queries, nprobe, k, outDistances, outLabels);
auto bitsetDevice = toDevice<uint8_t, 1>(resources_, device_, nullptr, stream, {0});
index_->query(queries, bitsetDevice, nprobe, k, outDistances, outLabels);
}
int

View File

@ -116,7 +116,8 @@ class GpuIndexIVFPQ : public GpuIndexIVF {
const float* x,
int k,
float* distances,
Index::idx_t* labels) const override;
Index::idx_t* labels,
ConcurrentBitsetPtr bitset = nullptr) const override;
private:
void verifySettings_() const;

View File

@ -304,14 +304,18 @@ GpuIndexIVFSQHybrid::addImpl_(int n,
FAISS_ASSERT(index_);
FAISS_ASSERT(n > 0);
auto stream = resources_->getDefaultStream(device_);
// Data is already resident on the GPU
Tensor<float, 2, true> data(const_cast<float*>(x), {n, (int) this->d});
auto bitset = toDevice<uint8_t, 1>(resources_, device_, nullptr, stream, {0});
static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch");
Tensor<long, 1, true> labels(const_cast<long*>(xids), {n});
// Not all vectors may be able to be added (some may contain NaNs etc)
index_->classifyAndAddVectors(data, labels);
index_->classifyAndAddVectors(data, labels, bitset);
// but keep the ntotal based on the total number of vectors that we attempted
// to add
@ -323,11 +327,14 @@ GpuIndexIVFSQHybrid::searchImpl_(int n,
const float* x,
int k,
float* distances,
Index::idx_t* labels) const {
Index::idx_t* labels,
ConcurrentBitsetPtr bitset) const {
// Device is already set in GpuIndex::search
FAISS_ASSERT(index_);
FAISS_ASSERT(n > 0);
auto stream = resources_->getDefaultStream(device_);
// Data is already resident on the GPU
Tensor<float, 2, true> queries(const_cast<float*>(x), {n, (int) this->d});
Tensor<float, 2, true> outDistances(distances, {n, k});
@ -335,7 +342,9 @@ GpuIndexIVFSQHybrid::searchImpl_(int n,
static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch");
Tensor<long, 2, true> outLabels(const_cast<long*>(labels), {n, k});
index_->query(queries, nprobe, k, outDistances, outLabels);
auto bitsetDevice = toDevice<uint8_t, 1>(resources_, device_, nullptr, stream, {0});
index_->query(queries, bitsetDevice, nprobe, k, outDistances, outLabels);
}
} } // namespace

View File

@ -79,7 +79,8 @@ class GpuIndexIVFSQHybrid : public GpuIndexIVF {
const float* x,
int k,
float* distances,
Index::idx_t* labels) const override;
Index::idx_t* labels,
ConcurrentBitsetPtr bitset = nullptr) const override;
/// Called from train to handle SQ residual training
void trainResiduals_(Index::idx_t n, const float* x);

View File

@ -239,14 +239,18 @@ GpuIndexIVFScalarQuantizer::addImpl_(int n,
FAISS_ASSERT(index_);
FAISS_ASSERT(n > 0);
auto stream = resources_->getDefaultStream(device_);
// Data is already resident on the GPU
Tensor<float, 2, true> data(const_cast<float*>(x), {n, (int) this->d});
auto bitset = toDevice<uint8_t, 1>(resources_, device_, nullptr, stream, {0});
static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch");
Tensor<long, 1, true> labels(const_cast<long*>(xids), {n});
// Not all vectors may be able to be added (some may contain NaNs etc)
index_->classifyAndAddVectors(data, labels);
index_->classifyAndAddVectors(data, labels, bitset);
// but keep the ntotal based on the total number of vectors that we attempted
// to add
@ -258,11 +262,14 @@ GpuIndexIVFScalarQuantizer::searchImpl_(int n,
const float* x,
int k,
float* distances,
Index::idx_t* labels) const {
Index::idx_t* labels,
ConcurrentBitsetPtr bitset) const {
// Device is already set in GpuIndex::search
FAISS_ASSERT(index_);
FAISS_ASSERT(n > 0);
auto stream = resources_->getDefaultStream(device_);
// Data is already resident on the GPU
Tensor<float, 2, true> queries(const_cast<float*>(x), {n, (int) this->d});
Tensor<float, 2, true> outDistances(distances, {n, k});
@ -270,7 +277,9 @@ GpuIndexIVFScalarQuantizer::searchImpl_(int n,
static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch");
Tensor<long, 2, true> outLabels(const_cast<long*>(labels), {n, k});
index_->query(queries, nprobe, k, outDistances, outLabels);
auto bitsetDevice = toDevice<uint8_t, 1>(resources_, device_, nullptr, stream, {0});
index_->query(queries, bitsetDevice, nprobe, k, outDistances, outLabels);
}
} } // namespace

View File

@ -75,7 +75,8 @@ class GpuIndexIVFScalarQuantizer : public GpuIndexIVF {
const float* x,
int k,
float* distances,
Index::idx_t* labels) const override;
Index::idx_t* labels,
ConcurrentBitsetPtr bitset = nullptr) const override;
/// Called from train to handle SQ residual training
void trainResiduals_(Index::idx_t n, const float* x);

View File

@ -130,6 +130,7 @@ void runDistance(bool computeL2,
Tensor<T, 1, true>* centroidNorms,
Tensor<T, 2, true>& queries,
bool queriesRowMajor,
Tensor<uint8_t, 1, true>& bitset,
int k,
Tensor<T, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
@ -306,6 +307,7 @@ void runDistance(bool computeL2,
// Write into the final output
runL2SelectMin(distanceBufView,
*centroidNorms,
bitset,
outDistanceView,
outIndexView,
k,
@ -326,6 +328,7 @@ void runDistance(bool computeL2,
// Write into our intermediate output
runL2SelectMin(distanceBufView,
centroidNormsView,
bitset,
outDistanceBufColView,
outIndexBufColView,
k,
@ -346,12 +349,14 @@ void runDistance(bool computeL2,
if (tileCols == numCentroids) {
// Write into the final output
runBlockSelect(distanceBufView,
bitset,
outDistanceView,
outIndexView,
true, k, streams[curStream]);
} else {
// Write into the intermediate output
runBlockSelect(distanceBufView,
bitset,
outDistanceBufColView,
outIndexBufColView,
true, k, streams[curStream]);
@ -368,6 +373,7 @@ void runDistance(bool computeL2,
runBlockSelectPair(outDistanceBufRowView,
outIndexBufRowView,
bitset,
outDistanceView,
outIndexView,
computeL2 ? false : true, k, streams[curStream]);
@ -384,6 +390,7 @@ void runDistance(bool computeL2,
}
}
// Bitset added
template <typename T>
void runL2Distance(GpuResources* resources,
Tensor<T, 2, true>& centroids,
@ -391,6 +398,7 @@ void runL2Distance(GpuResources* resources,
Tensor<T, 1, true>* centroidNorms,
Tensor<T, 2, true>& queries,
bool queriesRowMajor,
Tensor<uint8_t, 1, true>& bitset,
int k,
Tensor<T, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
@ -403,6 +411,7 @@ void runL2Distance(GpuResources* resources,
centroidNorms,
queries,
queriesRowMajor,
bitset,
k,
outDistances,
outIndices,
@ -416,6 +425,7 @@ void runIPDistance(GpuResources* resources,
bool centroidsRowMajor,
Tensor<T, 2, true>& queries,
bool queriesRowMajor,
Tensor<uint8_t, 1, true>& bitset,
int k,
Tensor<T, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
@ -427,6 +437,7 @@ void runIPDistance(GpuResources* resources,
nullptr, // no centroid norms provided
queries,
queriesRowMajor,
bitset,
k,
outDistances,
outIndices,
@ -444,6 +455,7 @@ runIPDistance(GpuResources* resources,
bool vectorsRowMajor,
Tensor<float, 2, true>& queries,
bool queriesRowMajor,
Tensor<uint8_t, 1, true>& bitset,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices) {
@ -452,6 +464,7 @@ runIPDistance(GpuResources* resources,
vectorsRowMajor,
queries,
queriesRowMajor,
bitset,
k,
outDistances,
outIndices,
@ -464,6 +477,7 @@ runIPDistance(GpuResources* resources,
bool vectorsRowMajor,
Tensor<half, 2, true>& queries,
bool queriesRowMajor,
Tensor<uint8_t, 1, true>& bitset,
int k,
Tensor<half, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
@ -473,6 +487,7 @@ runIPDistance(GpuResources* resources,
vectorsRowMajor,
queries,
queriesRowMajor,
bitset,
k,
outDistances,
outIndices,
@ -486,6 +501,7 @@ runL2Distance(GpuResources* resources,
Tensor<float, 1, true>* vectorNorms,
Tensor<float, 2, true>& queries,
bool queriesRowMajor,
Tensor<uint8_t, 1, true>& bitset,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
@ -496,6 +512,7 @@ runL2Distance(GpuResources* resources,
vectorNorms,
queries,
queriesRowMajor,
bitset,
k,
outDistances,
outIndices,
@ -510,6 +527,7 @@ runL2Distance(GpuResources* resources,
Tensor<half, 1, true>* vectorNorms,
Tensor<half, 2, true>& queries,
bool queriesRowMajor,
Tensor<uint8_t, 1, true>& bitset,
int k,
Tensor<half, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
@ -521,6 +539,7 @@ runL2Distance(GpuResources* resources,
vectorNorms,
queries,
queriesRowMajor,
bitset,
k,
outDistances,
outIndices,

View File

@ -10,6 +10,7 @@
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <faiss/gpu/utils/Float16.cuh>
#include <faiss/utils/ConcurrentBitset.h>
namespace faiss { namespace gpu {
@ -65,4 +66,58 @@ void runL2Distance(GpuResources* resources,
bool useHgemm,
bool ignoreOutDistances = false);
// Bitset added
void runL2Distance(GpuResources* resources,
Tensor<float, 2, true>& vectors,
bool vectorsRowMajor,
// can be optionally pre-computed; nullptr if we
// have to compute it upon the call
Tensor<float, 1, true>* vectorNorms,
Tensor<float, 2, true>& queries,
bool queriesRowMajor,
Tensor<uint8_t, 1, true>& bitset,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
// Do we care about `outDistances`? If not, we can
// take shortcuts.
bool ignoreOutDistances = false);
/// Calculates brute-force inner product distance between `vectors`
/// and `queries`, returning the k closest results seen
void runIPDistance(GpuResources* resources,
Tensor<float, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<float, 2, true>& queries,
bool queriesRowMajor,
Tensor<uint8_t, 1, true>& bitset,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices);
void runIPDistance(GpuResources* resources,
Tensor<half, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<half, 2, true>& queries,
bool queriesRowMajor,
Tensor<uint8_t, 1, true>& bitset,
int k,
Tensor<half, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
bool useHgemm);
void runL2Distance(GpuResources* resources,
Tensor<half, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<half, 1, true>* vectorNorms,
Tensor<half, 2, true>& queries,
bool queriesRowMajor,
Tensor<uint8_t, 1, true>& bitset,
int k,
Tensor<half, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
bool useHgemm,
bool ignoreOutDistances = false);
} } // namespace

View File

@ -103,6 +103,7 @@ FlatIndex::getVectorsFloat32Copy(int from, int num, cudaStream_t stream) {
void
FlatIndex::query(Tensor<float, 2, true>& input,
Tensor<uint8_t, 1, true>& bitset,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
@ -119,7 +120,7 @@ FlatIndex::query(Tensor<float, 2, true>& input,
DeviceTensor<half, 2, true> outDistancesHalf(
mem, {outDistances.getSize(0), outDistances.getSize(1)}, stream);
query(inputHalf, k, outDistancesHalf, outIndices, exactDistance);
query(inputHalf, bitset, k, outDistancesHalf, outIndices, exactDistance);
if (exactDistance) {
// Convert outDistances back
@ -135,6 +136,7 @@ FlatIndex::query(Tensor<float, 2, true>& input,
&norms_,
input,
true, // input is row major
bitset,
k,
outDistances,
outIndices,
@ -145,6 +147,7 @@ FlatIndex::query(Tensor<float, 2, true>& input,
!storeTransposed_, // is vectors row major?
input,
true, // input is row major
bitset,
k,
outDistances,
outIndices);
@ -154,6 +157,7 @@ FlatIndex::query(Tensor<float, 2, true>& input,
void
FlatIndex::query(Tensor<half, 2, true>& input,
Tensor<uint8_t, 1, true>& bitset,
int k,
Tensor<half, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
@ -167,6 +171,7 @@ FlatIndex::query(Tensor<half, 2, true>& input,
&normsHalf_,
input,
true, // input is row major
bitset,
k,
outDistances,
outIndices,
@ -179,6 +184,7 @@ FlatIndex::query(Tensor<half, 2, true>& input,
!storeTransposed_, // is vectors row major?
input,
true, // input is row major
bitset,
k,
outDistances,
outIndices,

View File

@ -53,12 +53,14 @@ class FlatIndex {
cudaStream_t stream);
void query(Tensor<float, 2, true>& vecs,
Tensor<uint8_t, 1, true>& bitset,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
bool exactDistance);
void query(Tensor<half, 2, true>& vecs,
Tensor<uint8_t, 1, true>& bitset,
int k,
Tensor<half, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,

View File

@ -157,7 +157,8 @@ IVFFlat::addCodeVectorsFromCpu(int listId,
int
IVFFlat::classifyAndAddVectors(Tensor<float, 2, true>& vecs,
Tensor<long, 1, true>& indices) {
Tensor<long, 1, true>& indices,
Tensor<uint8_t, 1, true>& bitset) {
FAISS_ASSERT(vecs.getSize(0) == indices.getSize(0));
FAISS_ASSERT(vecs.getSize(1) == dim_);
@ -174,7 +175,7 @@ IVFFlat::classifyAndAddVectors(Tensor<float, 2, true>& vecs,
listIds2d(mem, {vecs.getSize(0), 1}, stream);
auto listIds = listIds2d.view<1>({vecs.getSize(0)});
quantizer_->query(vecs, 1, listDistance2d, listIds2d, false);
quantizer_->query(vecs, bitset, 1, listDistance2d, listIds2d, false);
// Calculate residuals for these vectors, if needed
DeviceTensor<float, 2, true>
@ -326,6 +327,7 @@ IVFFlat::classifyAndAddVectors(Tensor<float, 2, true>& vecs,
void
IVFFlat::query(Tensor<float, 2, true>& queries,
Tensor<uint8_t, 1, true>& bitset,
int nprobe,
int k,
Tensor<float, 2, true>& outDistances,
@ -352,6 +354,7 @@ IVFFlat::query(Tensor<float, 2, true>& queries,
// Find the `nprobe` closest lists; we can use int indices both
// internally and externally
quantizer_->query(queries,
bitset,
nprobe,
coarseDistances,
coarseIndices,

View File

@ -44,11 +44,14 @@ class IVFFlat : public IVFBase {
/// Returns the number of vectors successfully added. Vectors may
/// not be able to be added because they contain NaNs.
int classifyAndAddVectors(Tensor<float, 2, true>& vecs,
Tensor<long, 1, true>& indices);
Tensor<long, 1, true>& indices,
Tensor<uint8_t, 1, true>& bitset);
/// Find the approximate k nearest neighbors for `queries` against
/// our database
void query(Tensor<float, 2, true>& queries,
Tensor<uint8_t, 1, true>& bitset,
int nprobe,
int k,
Tensor<float, 2, true>& outDistances,

View File

@ -110,7 +110,8 @@ IVFPQ::setPrecomputedCodes(bool enable) {
int
IVFPQ::classifyAndAddVectors(Tensor<float, 2, true>& vecs,
Tensor<long, 1, true>& indices) {
Tensor<long, 1, true>& indices,
Tensor<uint8_t, 1, true>& bitset) {
FAISS_ASSERT(vecs.getSize(0) == indices.getSize(0));
FAISS_ASSERT(vecs.getSize(1) == dim_);
@ -128,7 +129,7 @@ IVFPQ::classifyAndAddVectors(Tensor<float, 2, true>& vecs,
DeviceTensor<int, 2, true> listIds2d(mem, {vecs.getSize(0), 1}, stream);
auto listIds = listIds2d.view<1>({vecs.getSize(0)});
quantizer_->query(vecs, 1, listDistance, listIds2d, false);
quantizer_->query(vecs, bitset, 1, listDistance, listIds2d, false);
// Copy the lists that we wish to append to back to the CPU
// FIXME: really this can be into pinned memory and a true async
@ -184,6 +185,7 @@ IVFPQ::classifyAndAddVectors(Tensor<float, 2, true>& vecs,
nullptr, // no precomputed norms
residualsTransposeView,
true, // residualsTransposeView is row major
bitset,
1,
closestSubQDistanceView,
closestSubQIndexView,
@ -506,6 +508,7 @@ IVFPQ::precomputeCodes_() {
void
IVFPQ::query(Tensor<float, 2, true>& queries,
Tensor<uint8_t, 1, true>& bitset,
int nprobe,
int k,
Tensor<float, 2, true>& outDistances,
@ -531,6 +534,7 @@ IVFPQ::query(Tensor<float, 2, true>& queries,
// Find the `nprobe` closest coarse centroids; we can use int
// indices both internally and externally
quantizer_->query(queries,
bitset,
nprobe,
coarseDistances,
coarseIndices,

View File

@ -52,11 +52,13 @@ class IVFPQ : public IVFBase {
/// Returns the number of vectors successfully added. Vectors may
/// not be able to be added because they contain NaNs.
int classifyAndAddVectors(Tensor<float, 2, true>& vecs,
Tensor<long, 1, true>& indices);
Tensor<long, 1, true>& indices,
Tensor<uint8_t, 1, true>& bitset);
/// Find the approximate k nearest neighbors for `queries` against
/// our database
void query(Tensor<float, 2, true>& queries,
Tensor<uint8_t, 1, true>& bitset,
int nprobe,
int k,
Tensor<float, 2, true>& outDistances,

View File

@ -159,9 +159,160 @@ __global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
}
}
// With bitset included
// L2 + select kernel for k == 1, implements re-use of ||c||^2
template <typename T, int kRowsPerBlock, int kBlockSize>
__global__ void l2SelectMin1(Tensor<T, 2, true> productDistances,
Tensor<T, 1, true> centroidDistances,
Tensor<uint8_t, 1, true> bitset,
Tensor<T, 2, true> outDistances,
Tensor<int, 2, true> outIndices) {
// Each block handles kRowsPerBlock rows of the distances (results)
Pair<T, int> threadMin[kRowsPerBlock];
__shared__ Pair<T, int> blockMin[kRowsPerBlock * (kBlockSize / kWarpSize)];
T distance[kRowsPerBlock];
#pragma unroll
for (int i = 0; i < kRowsPerBlock; ++i) {
threadMin[i].k = Limits<T>::getMax();
threadMin[i].v = -1;
}
// blockIdx.x: which chunk of rows we are responsible for updating
int rowStart = blockIdx.x * kRowsPerBlock;
// FIXME: if we have exact multiples, don't need this
bool endRow = (blockIdx.x == gridDim.x - 1);
if (endRow) {
if (productDistances.getSize(0) % kRowsPerBlock == 0) {
endRow = false;
}
}
if (endRow) {
for (int row = rowStart; row < productDistances.getSize(0); ++row) {
for (int col = threadIdx.x; col < productDistances.getSize(1);
col += blockDim.x) {
if (bitset.getSize(0) == 0 || !(bitset[col >> 3] & (0x1 << (col & 0x7)))) {
distance[0] = Math<T>::add(centroidDistances[col],
productDistances[row][col]);
if (Math<T>::lt(distance[0], threadMin[0].k)) {
threadMin[0].k = distance[0];
threadMin[0].v = col;
}
}
}
// Reduce within the block
threadMin[0] =
blockReduceAll<Pair<T, int>, Min<Pair<T, int> >, false, false>(
threadMin[0], Min<Pair<T, int> >(), blockMin);
if (threadIdx.x == 0) {
outDistances[row][0] = threadMin[0].k;
outIndices[row][0] = threadMin[0].v;
}
// so we can use the shared memory again
__syncthreads();
threadMin[0].k = Limits<T>::getMax();
threadMin[0].v = -1;
}
} else {
for (int col = threadIdx.x; col < productDistances.getSize(1);
col += blockDim.x) {
// Skip columns flagged as deleted, as in the end-row branch above
if (bitset.getSize(0) != 0 && (bitset[col >> 3] & (0x1 << (col & 0x7)))) {
continue;
}
T centroidDistance = centroidDistances[col];
#pragma unroll
for (int row = 0; row < kRowsPerBlock; ++row) {
distance[row] = productDistances[rowStart + row][col];
}
#pragma unroll
for (int row = 0; row < kRowsPerBlock; ++row) {
distance[row] = Math<T>::add(distance[row], centroidDistance);
}
#pragma unroll
for (int row = 0; row < kRowsPerBlock; ++row) {
if (Math<T>::lt(distance[row], threadMin[row].k)) {
threadMin[row].k = distance[row];
threadMin[row].v = col;
}
}
}
// Reduce within the block
blockReduceAll<kRowsPerBlock, Pair<T, int>, Min<Pair<T, int> >, false, false>(
threadMin, Min<Pair<T, int> >(), blockMin);
if (threadIdx.x == 0) {
#pragma unroll
for (int row = 0; row < kRowsPerBlock; ++row) {
outDistances[rowStart + row][0] = threadMin[row].k;
outIndices[rowStart + row][0] = threadMin[row].v;
}
}
}
}
// With bitset included
// L2 + select kernel for k > 1, no re-use of ||c||^2
template <typename T, int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
Tensor<T, 1, true> centroidDistances,
Tensor<uint8_t, 1, true> bitset,
Tensor<T, 2, true> outDistances,
Tensor<int, 2, true> outIndices,
int k, T initK) {
// Each block handles a single row of the distances (results)
constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
__shared__ T smemK[kNumWarps * NumWarpQ];
__shared__ int smemV[kNumWarps * NumWarpQ];
BlockSelect<T, int, false, Comparator<T>,
NumWarpQ, NumThreadQ, ThreadsPerBlock>
heap(initK, -1, smemK, smemV, k);
int row = blockIdx.x;
// Whole warps must participate in the selection
int limit = utils::roundDown(productDistances.getSize(1), kWarpSize);
int i = threadIdx.x;
for (; i < limit; i += blockDim.x) {
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) {
T v = Math<T>::add(centroidDistances[i],
productDistances[row][i]);
heap.add(v, i);
}
}
if (i < productDistances.getSize(1)) {
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) {
T v = Math<T>::add(centroidDistances[i],
productDistances[row][i]);
heap.addThreadQ(v, i);
}
}
heap.reduce();
for (int i = threadIdx.x; i < k; i += blockDim.x) {
outDistances[row][i] = smemK[i];
outIndices[row][i] = smemV[i];
}
}
template <typename T>
void runL2SelectMin(Tensor<T, 2, true>& productDistances,
Tensor<T, 1, true>& centroidDistances,
Tensor<uint8_t, 1, true>& bitset,
Tensor<T, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
int k,
@ -181,7 +332,7 @@ void runL2SelectMin(Tensor<T, 2, true>& productDistances,
auto grid = dim3(utils::divUp(outDistances.getSize(0), kRowsPerBlock));
l2SelectMin1<T, kRowsPerBlock, kThreadsPerBlock>
<<<grid, block, 0, stream>>>(productDistances, centroidDistances,
<<<grid, block, 0, stream>>>(productDistances, centroidDistances, bitset,
outDistances, outIndices);
} else {
auto grid = dim3(outDistances.getSize(0));
@ -194,28 +345,63 @@ void runL2SelectMin(Tensor<T, 2, true>& productDistances,
k, Limits<T>::getMax()); \
} while (0)
// block size 128 for everything <= 1024
if (k <= 32) {
RUN_L2_SELECT(128, 32, 2);
} else if (k <= 64) {
RUN_L2_SELECT(128, 64, 3);
} else if (k <= 128) {
RUN_L2_SELECT(128, 128, 3);
} else if (k <= 256) {
RUN_L2_SELECT(128, 256, 4);
} else if (k <= 512) {
RUN_L2_SELECT(128, 512, 8);
} else if (k <= 1024) {
RUN_L2_SELECT(128, 1024, 8);
#define RUN_L2_SELECT_BITSET(BLOCK, NUM_WARP_Q, NUM_THREAD_Q) \
do { \
l2SelectMinK<T, NUM_WARP_Q, NUM_THREAD_Q, BLOCK> \
<<<grid, BLOCK, 0, stream>>>(productDistances, centroidDistances, \
bitset, outDistances, outIndices, \
k, Limits<T>::getMax()); \
} while (0)
#if GPU_MAX_SELECTION_K >= 2048
} else if (k <= 2048) {
// smaller block for less shared memory
RUN_L2_SELECT(64, 2048, 8);
#endif
if (bitset.getSize(0) == 0) {
// block size 128 for everything <= 1024
if (k <= 32) {
RUN_L2_SELECT(128, 32, 2);
} else if (k <= 64) {
RUN_L2_SELECT(128, 64, 3);
} else if (k <= 128) {
RUN_L2_SELECT(128, 128, 3);
} else if (k <= 256) {
RUN_L2_SELECT(128, 256, 4);
} else if (k <= 512) {
RUN_L2_SELECT(128, 512, 8);
} else if (k <= 1024) {
RUN_L2_SELECT(128, 1024, 8);
#if GPU_MAX_SELECTION_K >= 2048
} else if (k <= 2048) {
// smaller block for less shared memory
RUN_L2_SELECT(64, 2048, 8);
#endif
} else {
FAISS_ASSERT(false);
}
} else {
FAISS_ASSERT(false);
// With bitset
if (k <= 32) {
RUN_L2_SELECT_BITSET(128, 32, 2);
} else if (k <= 64) {
RUN_L2_SELECT_BITSET(128, 64, 3);
} else if (k <= 128) {
RUN_L2_SELECT_BITSET(128, 128, 3);
} else if (k <= 256) {
RUN_L2_SELECT_BITSET(128, 256, 4);
} else if (k <= 512) {
RUN_L2_SELECT_BITSET(128, 512, 8);
} else if (k <= 1024) {
RUN_L2_SELECT_BITSET(128, 1024, 8);
#if GPU_MAX_SELECTION_K >= 2048
} else if (k <= 2048) {
// smaller block for less shared memory
RUN_L2_SELECT_BITSET(64, 2048, 8);
#endif
} else {
FAISS_ASSERT(false);
}
}
}
@ -224,12 +410,14 @@ void runL2SelectMin(Tensor<T, 2, true>& productDistances,
void runL2SelectMin(Tensor<float, 2, true>& productDistances,
Tensor<float, 1, true>& centroidDistances,
Tensor<uint8_t, 1, true>& bitset,
Tensor<float, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
int k,
cudaStream_t stream) {
runL2SelectMin<float>(productDistances,
centroidDistances,
bitset,
outDistances,
outIndices,
k,
@ -238,12 +426,14 @@ void runL2SelectMin(Tensor<float, 2, true>& productDistances,
void runL2SelectMin(Tensor<half, 2, true>& productDistances,
Tensor<half, 1, true>& centroidDistances,
Tensor<uint8_t, 1, true>& bitset,
Tensor<half, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
int k,
cudaStream_t stream) {
runL2SelectMin<half>(productDistances,
centroidDistances,
bitset,
outDistances,
outIndices,
k,

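For reference, a small host-side analogue of what the masked kernels above compute (illustrative only, not part of the commit): for one query row, the k smallest centroidDistances[col] + productDistances[row][col] over columns whose bit is not set, using the same bitset[col >> 3] & (0x1 << (col & 0x7)) byte layout; an empty mask disables filtering.

    #include <algorithm>
    #include <cstdint>
    #include <utility>
    #include <vector>

    // CPU reference for one row of the masked L2 select-min.
    std::vector<std::pair<float, int>>
    l2SelectMinRef(const std::vector<float>& productRow,     // one row of distances
                   const std::vector<float>& centroidDists,  // one entry per column
                   const std::vector<uint8_t>& bitset,       // packed mask, may be empty
                   int k) {
      std::vector<std::pair<float, int>> cand;
      for (int col = 0; col < (int) productRow.size(); ++col) {
        bool masked = !bitset.empty() && (bitset[col >> 3] & (0x1 << (col & 0x7)));
        if (!masked) {
          cand.emplace_back(centroidDists[col] + productRow[col], col);
        }
      }
      int keep = std::min(k, (int) cand.size());
      std::partial_sort(cand.begin(), cand.begin() + keep, cand.end());
      cand.resize(keep);
      return cand;
    }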
View File

@ -26,4 +26,20 @@ void runL2SelectMin(Tensor<half, 2, true>& productDistances,
int k,
cudaStream_t stream);
void runL2SelectMin(Tensor<float, 2, true>& productDistances,
Tensor<float, 1, true>& centroidDistances,
Tensor<uint8_t, 1, true>& bitset,
Tensor<float, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
int k,
cudaStream_t stream);
void runL2SelectMin(Tensor<half, 2, true>& productDistances,
Tensor<half, 1, true>& centroidDistances,
Tensor<uint8_t, 1, true>& bitset,
Tensor<half, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
int k,
cudaStream_t stream);
} } // namespace

View File

@ -51,6 +51,7 @@ int main(int argc, char** argv) {
limitK = GPU_MAX_SELECTION_K;
}
faiss::gpu::DeviceTensor<uint8_t, 1, true> bitset(nullptr, {0});
for (int k = startK; k <= limitK; k *= 2) {
faiss::gpu::DeviceTensor<float, 2, true> gpuOutVal({FLAGS_rows, k});
faiss::gpu::DeviceTensor<int, 2, true> gpuOutInd({FLAGS_rows, k});
@ -60,7 +61,7 @@ int main(int argc, char** argv) {
faiss::gpu::runWarpSelect(gpuVal, gpuOutVal, gpuOutInd,
FLAGS_dir, k, 0);
} else {
faiss::gpu::runBlockSelect(gpuVal, gpuOutVal, gpuOutInd,
faiss::gpu::runBlockSelect(gpuVal, bitset, gpuOutVal, gpuOutInd,
FLAGS_dir, k, 0);
}
}

View File

@ -29,6 +29,8 @@ void testForSize(int rows, int cols, int k, bool dir, bool warp) {
}
}
faiss::gpu::DeviceTensor<uint8_t, 1, true> bitset(nullptr, {0});
// row -> (val -> idx)
std::unordered_map<int, std::vector<std::pair<int, float>>> hostOutValAndInd;
for (int r = 0; r < rows; ++r) {
@ -59,7 +61,8 @@ void testForSize(int rows, int cols, int k, bool dir, bool warp) {
if (warp) {
faiss::gpu::runWarpSelect(gpuVal, gpuOutVal, gpuOutInd, dir, k, 0);
} else {
faiss::gpu::runBlockSelect(gpuVal, gpuOutVal, gpuOutInd, dir, k, 0);
faiss::gpu::runBlockSelect(gpuVal, bitset, gpuOutVal, gpuOutInd, dir, k, 0);
}
// Copy back to CPU

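The benchmark and test above still pass an empty bitset, so only the unfiltered path is exercised. Driving the filtered kernels requires a mask packed with the same byte/bit layout the kernels test; a minimal host-side sketch (makeBitset is an illustrative name, not part of the commit):

    #include <cstdint>
    #include <vector>

    // Pack a list of ids to exclude into the layout the kernels test with
    // bitset[i >> 3] & (0x1 << (i & 0x7)).
    std::vector<uint8_t> makeBitset(int numVectors,
                                    const std::vector<int>& deletedIds) {
      std::vector<uint8_t> bits((numVectors + 7) / 8, 0);
      for (int id : deletedIds) {
        bits[id >> 3] |= (uint8_t) (0x1 << (id & 0x7));
      }
      return bits;
    }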
View File

@ -43,6 +43,7 @@ BLOCK_SELECT_DECL(float, false, 2048);
#endif
void runBlockSelect(Tensor<float, 2, true>& in,
Tensor<uint8_t, 1, true>& bitset,
Tensor<float, 2, true>& outK,
Tensor<int, 2, true>& outV,
bool dir, int k, cudaStream_t stream) {
@ -93,6 +94,7 @@ void runBlockSelect(Tensor<float, 2, true>& in,
void runBlockSelectPair(Tensor<float, 2, true>& inK,
Tensor<int, 2, true>& inV,
Tensor<uint8_t, 1, true>& bitset,
Tensor<float, 2, true>& outK,
Tensor<int, 2, true>& outV,
bool dir, int k, cudaStream_t stream) {

View File

@ -43,6 +43,7 @@ BLOCK_SELECT_DECL(half, false, 2048);
#endif
void runBlockSelect(Tensor<half, 2, true>& in,
Tensor<uint8_t, 1, true>& bitset,
Tensor<half, 2, true>& outK,
Tensor<int, 2, true>& outV,
bool dir, int k, cudaStream_t stream) {
@ -93,6 +94,7 @@ void runBlockSelect(Tensor<half, 2, true>& in,
void runBlockSelectPair(Tensor<half, 2, true>& inK,
Tensor<int, 2, true>& inV,
Tensor<uint8_t, 1, true>& bitset,
Tensor<half, 2, true>& outK,
Tensor<int, 2, true>& outV,
bool dir, int k, cudaStream_t stream) {

View File

@ -110,24 +110,138 @@ __global__ void blockSelectPair(Tensor<K, 2, true> inK,
}
}
// Bitset included
template <typename K,
typename IndexType,
bool Dir,
int NumWarpQ,
int NumThreadQ,
int ThreadsPerBlock>
__global__ void blockSelect(Tensor<K, 2, true> in,
Tensor<uint8_t, 1, true> bitset,
Tensor<K, 2, true> outK,
Tensor<IndexType, 2, true> outV,
K initK,
IndexType initV,
int k) {
constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
__shared__ K smemK[kNumWarps * NumWarpQ];
__shared__ IndexType smemV[kNumWarps * NumWarpQ];
BlockSelect<K, IndexType, Dir, Comparator<K>,
NumWarpQ, NumThreadQ, ThreadsPerBlock>
heap(initK, initV, smemK, smemV, k);
// Grid is exactly sized to rows available
int row = blockIdx.x;
int i = threadIdx.x;
K* inStart = in[row][i].data();
// Whole warps must participate in the selection
int limit = utils::roundDown(in.getSize(1), kWarpSize);
for (; i < limit; i += ThreadsPerBlock) {
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) {
heap.add(*inStart, (IndexType) i);
}
// Advance the input pointer in lock-step with i, even for masked columns
inStart += ThreadsPerBlock;
}
// Handle last remainder fraction of a warp of elements
if (i < in.getSize(1)) {
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) {
heap.addThreadQ(*inStart, (IndexType) i);
}
}
heap.reduce();
for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
outK[row][i] = smemK[i];
outV[row][i] = smemV[i];
}
}
template <typename K,
typename IndexType,
bool Dir,
int NumWarpQ,
int NumThreadQ,
int ThreadsPerBlock>
__global__ void blockSelectPair(Tensor<K, 2, true> inK,
Tensor<IndexType, 2, true> inV,
Tensor<uint8_t, 1, true> bitset,
Tensor<K, 2, true> outK,
Tensor<IndexType, 2, true> outV,
K initK,
IndexType initV,
int k) {
constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
__shared__ K smemK[kNumWarps * NumWarpQ];
__shared__ IndexType smemV[kNumWarps * NumWarpQ];
BlockSelect<K, IndexType, Dir, Comparator<K>,
NumWarpQ, NumThreadQ, ThreadsPerBlock>
heap(initK, initV, smemK, smemV, k);
// Grid is exactly sized to rows available
int row = blockIdx.x;
int i = threadIdx.x;
K* inKStart = inK[row][i].data();
IndexType* inVStart = inV[row][i].data();
// Whole warps must participate in the selection
int limit = utils::roundDown(inK.getSize(1), kWarpSize);
for (; i < limit; i += ThreadsPerBlock) {
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) {
heap.add(*inKStart, *inVStart);
}
// Advance both input pointers in lock-step with i, even for masked columns
inKStart += ThreadsPerBlock;
inVStart += ThreadsPerBlock;
}
// Handle last remainder fraction of a warp of elements
if (i < inK.getSize(1)) {
if (!(bitset[i >> 3] & (0x1 << (i & 0x7)))) {
heap.addThreadQ(*inKStart, *inVStart);
}
}
heap.reduce();
for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
outK[row][i] = smemK[i];
outV[row][i] = smemV[i];
}
}
void runBlockSelect(Tensor<float, 2, true>& in,
Tensor<uint8_t, 1, true>& bitset,
Tensor<float, 2, true>& outKeys,
Tensor<int, 2, true>& outIndices,
bool dir, int k, cudaStream_t stream);
void runBlockSelectPair(Tensor<float, 2, true>& inKeys,
Tensor<int, 2, true>& inIndices,
Tensor<uint8_t, 1, true>& bitset,
Tensor<float, 2, true>& outKeys,
Tensor<int, 2, true>& outIndices,
bool dir, int k, cudaStream_t stream);
void runBlockSelect(Tensor<half, 2, true>& in,
Tensor<uint8_t, 1, true>& bitset,
Tensor<half, 2, true>& outKeys,
Tensor<int, 2, true>& outIndices,
bool dir, int k, cudaStream_t stream);
void runBlockSelectPair(Tensor<half, 2, true>& inKeys,
Tensor<int, 2, true>& inIndices,
Tensor<uint8_t, 1, true>& bitset,
Tensor<half, 2, true>& outKeys,
Tensor<int, 2, true>& outIndices,
bool dir, int k, cudaStream_t stream);

View File

@ -13,6 +13,7 @@
#define BLOCK_SELECT_DECL(TYPE, DIR, WARP_Q) \
extern void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
Tensor<TYPE, 2, true>& in, \
Tensor<uint8_t, 1, true>& bitset, \
Tensor<TYPE, 2, true>& outK, \
Tensor<int, 2, true>& outV, \
bool dir, \
@ -22,15 +23,17 @@
extern void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
Tensor<TYPE, 2, true>& inK, \
Tensor<int, 2, true>& inV, \
Tensor<uint8_t, 1, true>& bitset, \
Tensor<TYPE, 2, true>& outK, \
Tensor<int, 2, true>& outV, \
bool dir, \
int k, \
cudaStream_t stream)
cudaStream_t stream);
#define BLOCK_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q) \
void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
Tensor<TYPE, 2, true>& in, \
Tensor<uint8_t, 1, true>& bitset, \
Tensor<TYPE, 2, true>& outK, \
Tensor<int, 2, true>& outV, \
bool dir, \
@ -52,14 +55,19 @@
auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax(); \
auto vInit = -1; \
\
blockSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads> \
<<<grid, block, 0, stream>>>(in, outK, outV, kInit, vInit, k); \
if (bitset.getSize(0) == 0) \
blockSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads> \
<<<grid, block, 0, stream>>>(in, outK, outV, kInit, vInit, k); \
else \
blockSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads> \
<<<grid, block, 0, stream>>>(in, bitset, outK, outV, kInit, vInit, k); \
CUDA_TEST_ERROR(); \
} \
\
void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
Tensor<TYPE, 2, true>& inK, \
Tensor<int, 2, true>& inV, \
Tensor<uint8_t, 1, true>& bitset, \
Tensor<TYPE, 2, true>& outK, \
Tensor<int, 2, true>& outV, \
bool dir, \
@ -79,16 +87,20 @@
auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax(); \
auto vInit = -1; \
\
blockSelectPair<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads> \
<<<grid, block, 0, stream>>>(inK, inV, outK, outV, kInit, vInit, k); \
if (bitset.getSize(0) == 0) \
blockSelectPair<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads> \
<<<grid, block, 0, stream>>>(inK, inV, outK, outV, kInit, vInit, k); \
else \
blockSelectPair<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads> \
<<<grid, block, 0, stream>>>(inK, inV, bitset, outK, outV, kInit, vInit, k); \
CUDA_TEST_ERROR(); \
}
#define BLOCK_SELECT_CALL(TYPE, DIR, WARP_Q) \
runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
in, outK, outV, dir, k, stream)
in, bitset, outK, outV, dir, k, stream)
#define BLOCK_SELECT_PAIR_CALL(TYPE, DIR, WARP_Q) \
runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
inK, inV, outK, outV, dir, k, stream)
inK, inV, bitset, outK, outV, dir, k, stream)