Set the bitset of searching (#3858)

Signed-off-by: cqy <yaya645@126.com>
Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
pull/3916/head
cqy123456 2020-09-25 11:33:34 +08:00 committed by shengjun.li
parent 7785f44ef4
commit 7bfcec642f
49 changed files with 170 additions and 179 deletions

View File

@ -248,7 +248,7 @@ MapAndCopyResult(const knowhere::DatasetPtr& dataset, const std::vector<idx_t>&
Status
ExecutionEngineImpl::VecSearch(milvus::engine::ExecutionEngineContext& context,
const query::VectorQueryPtr& vector_param, knowhere::VecIndexPtr& vec_index,
bool hybrid) {
const faiss::ConcurrentBitsetPtr& bitset, bool hybrid) {
TimeRecorder rc(LogOut("[%s][%ld] ExecutionEngineImpl::VecSearch", "search", 0));
if (vec_index == nullptr) {
@ -284,11 +284,10 @@ ExecutionEngineImpl::VecSearch(milvus::engine::ExecutionEngineContext& context,
} else {
dataset = knowhere::GenDataset(nq, vec_index->Dim(), query_vector.binary_data.data());
}
auto result = vec_index->Query(dataset, conf);
auto result = vec_index->Query(dataset, conf, bitset);
MapAndCopyResult(result, vec_index->GetUids(), nq, topk, context.query_result_->result_distances_.data(),
context.query_result_->result_ids_.data());
if (hybrid) {
// HybridUnset();
}
@ -341,7 +340,6 @@ ExecutionEngineImpl::Search(ExecutionEngineContext& context) {
list->set(i);
}
}
vec_index->SetBlacklist(list);
auto& vector_param = context.query_ptr_->vectors.at(vector_placeholder);
if (!vector_param->query_vector.float_data.empty()) {
@ -350,7 +348,7 @@ ExecutionEngineImpl::Search(ExecutionEngineContext& context) {
vector_param->nq = vector_param->query_vector.binary_data.size() * 8 / vec_index->Dim();
}
status = VecSearch(context, context.query_ptr_->vectors.at(vector_placeholder), vec_index);
status = VecSearch(context, context.query_ptr_->vectors.at(vector_placeholder), vec_index, list);
if (!status.ok()) {
return status;
}

View File

@ -43,7 +43,7 @@ class ExecutionEngineImpl : public ExecutionEngine {
private:
Status
VecSearch(ExecutionEngineContext& context, const query::VectorQueryPtr& vector_param,
knowhere::VecIndexPtr& vec_index, bool hybrid = false);
knowhere::VecIndexPtr& vec_index, const faiss::ConcurrentBitsetPtr& bitset, bool hybrid = false);
knowhere::VecIndexPtr
CreateVecIndex(const std::string& index_name, knowhere::IndexMode mode);

View File

@ -105,7 +105,7 @@ IndexAnnoy::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
}
DatasetPtr
IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) {
IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -116,7 +116,6 @@ IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto all_num = rows * k;
auto p_id = static_cast<int64_t*>(malloc(all_num * sizeof(int64_t)));
auto p_dist = static_cast<float*>(malloc(all_num * sizeof(float)));
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
#pragma omp parallel for
for (unsigned int i = 0; i < rows; ++i) {
@ -125,7 +124,7 @@ IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) {
std::vector<float> distances;
distances.reserve(k);
index_->get_nns_by_vector(static_cast<const float*>(p_data) + i * dim, k, search_k, &result, &distances,
blacklist);
bitset);
int64_t result_num = result.size();
auto local_p_id = p_id + k * i;

View File

@ -54,7 +54,7 @@ class IndexAnnoy : public VecIndex {
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
int64_t
Count() override;

View File

@ -40,7 +40,7 @@ BinaryIDMAP::Load(const BinarySet& index_binary) {
}
DatasetPtr
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
@ -53,7 +53,7 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));
QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config);
QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config, bitset);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
@ -142,13 +142,13 @@ BinaryIDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config)
void
BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
const Config& config) {
const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
// assign the metric type
auto bin_flat_index = dynamic_cast<faiss::IndexBinaryIDMap*>(index_.get())->index;
bin_flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
auto i_distances = reinterpret_cast<int32_t*>(distances);
bin_flat_index->search(n, data, k, i_distances, labels, bitset_);
bin_flat_index->search(n, data, k, i_distances, labels, bitset);
// if hamming, it need transform int32 to float
if (bin_flat_index->metric_type == faiss::METRIC_Hamming) {

View File

@ -48,7 +48,7 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override;
int64_t
Count() override;
@ -69,7 +69,8 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
protected:
virtual void
QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config);
QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config,
const faiss::ConcurrentBitsetPtr& bitset);
protected:
std::mutex mutex_;

View File

@ -43,7 +43,7 @@ BinaryIVF::Load(const BinarySet& index_binary) {
}
DatasetPtr
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -59,7 +59,7 @@ BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));
QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config);
QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config, bitset);
auto ret_ds = std::make_shared<Dataset>();
@ -126,15 +126,15 @@ BinaryIVF::GenParams(const Config& config) {
}
void
BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
const Config& config) {
BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
auto params = GenParams(config);
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
ivf_index->nprobe = params->nprobe;
stdclock::time_point before = stdclock::now();
auto i_distances = reinterpret_cast<int32_t*>(distances);
index_->search(n, data, k, i_distances, labels, bitset_);
index_->search(n, data, k, i_distances, labels, bitset);
stdclock::time_point after = stdclock::now();
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();

View File

@ -60,7 +60,7 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
int64_t
Count() override;
@ -76,7 +76,8 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
GenParams(const Config& config);
virtual void
QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config);
QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config,
const faiss::ConcurrentBitsetPtr& bitset);
protected:
std::mutex mutex_;

View File

@ -136,7 +136,7 @@ IndexHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) {
}
DatasetPtr
IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -153,7 +153,6 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
using P = std::pair<float, int64_t>;
auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; };
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
#pragma omp parallel for
for (unsigned int i = 0; i < rows; ++i) {
std::vector<P> ret;
@ -166,7 +165,7 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
// } else {
// ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get<int64_t>(), compare);
// }
ret = index_->searchKnn(single_query, k, compare, blacklist);
ret = index_->searchKnn(single_query, k, compare, bitset);
while (ret.size() < k) {
ret.emplace_back(std::make_pair(-1, -1));

View File

@ -46,7 +46,7 @@ class IndexHNSW : public VecIndex {
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
int64_t
Count() override;

View File

@ -95,7 +95,7 @@ IDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
}
DatasetPtr
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
@ -108,7 +108,7 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config);
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config, bitset);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
@ -223,11 +223,12 @@ IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
#endif
void
IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
// assign the metric type
auto flat_index = dynamic_cast<faiss::IndexIDMap*>(index_.get())->index;
flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
index_->search(n, data, k, distances, labels, bitset_);
index_->search(n, data, k, distances, labels, bitset);
}
} // namespace knowhere

View File

@ -46,7 +46,7 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override;
#if 0
DatasetPtr
@ -80,7 +80,7 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
protected:
virtual void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&);
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr&);
protected:
std::mutex mutex_;

View File

@ -97,7 +97,7 @@ IVF::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
}
DatasetPtr
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -115,7 +115,7 @@ IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config);
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config, bitset);
// std::stringstream ss_res_id, ss_res_dist;
// for (int i = 0; i < 10; ++i) {
@ -296,7 +296,7 @@ IVF::GenGraph(const float* data, const int64_t k, GraphType& graph, const Config
res.resize(K * b_size);
const float* xq = data + batch_size * dim * i;
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config);
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config, nullptr);
for (int j = 0; j < b_size; ++j) {
auto& node = graph[batch_size * i + j];
@ -318,7 +318,8 @@ IVF::GenParams(const Config& config) {
}
void
IVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
IVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
auto params = GenParams(config);
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
ivf_index->nprobe = std::min(params->nprobe, ivf_index->invlists->nlist);
@ -328,7 +329,7 @@ IVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_
} else {
ivf_index->parallel_mode = 0;
}
ivf_index->search(n, data, k, distances, labels, bitset_);
ivf_index->search(n, data, k, distances, labels, bitset);
stdclock::time_point after = stdclock::now();
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
LOG_KNOWHERE_DEBUG_ << "IVF search cost: " << search_cost

View File

@ -51,7 +51,7 @@ class IVF : public VecIndex, public FaissBaseIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override;
#if 0
DatasetPtr
@ -86,7 +86,7 @@ class IVF : public VecIndex, public FaissBaseIndex {
GenParams(const Config&);
virtual void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&);
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr&);
void
SealImpl() override;

View File

@ -73,7 +73,7 @@ NSG::Load(const BinarySet& index_binary) {
}
DatasetPtr
NSG::Query(const DatasetPtr& dataset_ptr, const Config& config) {
NSG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -87,15 +87,13 @@ NSG::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
impl::SearchParams s_params;
s_params.search_length = config[IndexParams::search_length];
s_params.k = config[meta::TOPK];
{
std::lock_guard<std::mutex> lk(mutex_);
index_->Search((float*)p_data, nullptr, rows, dim, config[meta::TOPK].get<int64_t>(), p_dist, p_id,
s_params, blacklist);
s_params, bitset);
}
auto ret_ds = std::make_shared<Dataset>();

View File

@ -59,7 +59,7 @@ class NSG : public VecIndex {
}
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override;
int64_t
Count() override;

View File

@ -79,7 +79,7 @@ IndexRHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) {
}
DatasetPtr
IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -96,10 +96,9 @@ IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
}
auto real_index = dynamic_cast<faiss::IndexRHNSW*>(index_.get());
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
real_index->hnsw.efSearch = (config[IndexParams::ef]);
real_index->search(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, blacklist);
real_index->search(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, bitset);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);

View File

@ -52,7 +52,7 @@ class IndexRHNSW : public VecIndex, public FaissBaseIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
int64_t
Count() override;

View File

@ -176,7 +176,7 @@ CPUSPTAGRNG::SetParameters(const Config& config) {
}
DatasetPtr
CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config) {
CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
SetParameters(config);
float* p_data = (float*)dataset_ptr->Get<const void*>(meta::TENSOR);

View File

@ -52,7 +52,7 @@ class CPUSPTAGRNG : public VecIndex {
}
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override;
int64_t
Count() override;

View File

@ -46,7 +46,7 @@ class VecIndex : public Index {
AddWithoutIds(const DatasetPtr& dataset, const Config& config) = 0;
virtual DatasetPtr
Query(const DatasetPtr& dataset, const Config& config) = 0;
Query(const DatasetPtr& dataset, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) = 0;
#if 0
virtual DatasetPtr
@ -144,9 +144,11 @@ class VecIndex : public Index {
protected:
IndexType index_type_ = "";
IndexMode index_mode_ = IndexMode::MODE_CPU;
faiss::ConcurrentBitsetPtr bitset_ = nullptr;
std::vector<IDType> uids_;
int64_t index_size_ = -1;
private:
faiss::ConcurrentBitsetPtr bitset_ = nullptr;
};
using VecIndexPtr = std::shared_ptr<VecIndex>;

View File

@ -104,13 +104,14 @@ GPUIDMAP::GetRawIds() {
}
void
GPUIDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
GPUIDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
ResScope rs(res_, gpu_id_);
// assign the metric type
auto flat_index = dynamic_cast<faiss::IndexIDMap*>(index_.get())->index;
flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
index_->search(n, data, k, distances, labels, bitset_);
index_->search(n, data, k, distances, labels, bitset);
}
void
@ -134,7 +135,7 @@ GPUIDMAP::GenGraph(const float* data, const int64_t k, GraphType& graph, const C
res.resize(K * b_size);
const float* xq = data + batch_size * dim * i;
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config);
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config, nullptr);
for (int j = 0; j < b_size; ++j) {
auto& node = graph[batch_size * i + j];

View File

@ -55,7 +55,8 @@ class GPUIDMAP : public IDMAP, public GPUIndex {
LoadImpl(const BinarySet&, const IndexType&) override;
void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&,
const faiss::ConcurrentBitsetPtr& bitset) override;
};
using GPUIDMAPPtr = std::shared_ptr<GPUIDMAP>;

View File

@ -137,7 +137,8 @@ GPUIVF::LoadImpl(const BinarySet& binary_set, const IndexType& type) {
}
void
GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
std::lock_guard<std::mutex> lk(mutex_);
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
@ -152,7 +153,7 @@ GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int
for (int64_t i = 0; i < n; i += block_size) {
int64_t search_size = (n - i > block_size) ? block_size : (n - i);
device_index->search(search_size, reinterpret_cast<const float*>(data) + i * dim, k, distances + i * k,
labels + i * k, bitset_);
labels + i * k, bitset);
}
} else {
KNOWHERE_THROW_MSG("Not a GpuIndexIVF type.");

View File

@ -51,7 +51,8 @@ class GPUIVF : public IVF, public GPUIndex {
LoadImpl(const BinarySet&, const IndexType&) override;
void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&,
const faiss::ConcurrentBitsetPtr& bitset) override;
};
using GPUIVFPtr = std::shared_ptr<GPUIVF>;

View File

@ -241,21 +241,21 @@ IVFSQHybrid::LoadImpl(const BinarySet& binary_set, const IndexType& type) {
}
void
IVFSQHybrid::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels,
const Config& config) {
IVFSQHybrid::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
if (gpu_mode_ == 2) {
GPUIVF::QueryImpl(n, data, k, distances, labels, config);
GPUIVF::QueryImpl(n, data, k, distances, labels, config, bitset);
// index_->search(n, (float*)data, k, distances, labels);
} else if (gpu_mode_ == 1) { // hybrid
auto gpu_id = quantizer_->gpu_id;
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) {
ResScope rs(res, gpu_id, true);
IVF::QueryImpl(n, data, k, distances, labels, config);
IVF::QueryImpl(n, data, k, distances, labels, config, bitset);
} else {
KNOWHERE_THROW_MSG("Hybrid Search Error, can't get gpu: " + std::to_string(gpu_id) + "resource");
}
} else if (gpu_mode_ == 0) {
IVF::QueryImpl(n, data, k, distances, labels, config);
IVF::QueryImpl(n, data, k, distances, labels, config, bitset);
}
}

View File

@ -88,7 +88,8 @@ class IVFSQHybrid : public GPUIVFSQ {
LoadImpl(const BinarySet&, const IndexType&) override;
void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&,
const faiss::ConcurrentBitsetPtr& bitset) override;
protected:
int64_t gpu_mode_ = 0; // 0: CPU, 1: Hybrid, 2: GPU

View File

@ -138,7 +138,7 @@ IVF_NM::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
}
DatasetPtr
IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -156,7 +156,7 @@ IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config);
QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config, bitset);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
@ -283,7 +283,7 @@ IVF_NM::GenGraph(const float* data, const int64_t k, GraphType& graph, const Con
res.resize(K * b_size);
const float* xq = data + batch_size * dim * i;
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config);
QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config, nullptr);
for (int j = 0; j < b_size; ++j) {
auto& node = graph[batch_size * i + j];
@ -305,7 +305,8 @@ IVF_NM::GenParams(const Config& config) {
}
void
IVF_NM::QueryImpl(int64_t n, const float* query, int64_t k, float* distances, int64_t* labels, const Config& config) {
IVF_NM::QueryImpl(int64_t n, const float* query, int64_t k, float* distances, int64_t* labels, const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
auto params = GenParams(config);
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
ivf_index->nprobe = params->nprobe;
@ -324,7 +325,7 @@ IVF_NM::QueryImpl(int64_t n, const float* query, int64_t k, float* distances, in
#endif
ivf_index->search_without_codes(n, reinterpret_cast<const float*>(query), data, prefix_sum, is_sq8, k, distances,
labels, bitset_);
labels, bitset);
stdclock::time_point after = stdclock::now();
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
LOG_KNOWHERE_DEBUG_ << "IVF_NM search cost: " << search_cost

View File

@ -51,7 +51,7 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
AddWithoutIds(const DatasetPtr&, const Config&) override;
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override;
#if 0
DatasetPtr
@ -86,7 +86,8 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
GenParams(const Config&);
virtual void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&);
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&,
const faiss::ConcurrentBitsetPtr& bitset);
void
SealImpl() override;

View File

@ -74,7 +74,7 @@ NSG_NM::Load(const BinarySet& index_binary) {
}
DatasetPtr
NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
@ -89,8 +89,6 @@ NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = static_cast<int64_t*>(malloc(p_id_size));
auto p_dist = static_cast<float*>(malloc(p_dist_size));
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
impl::SearchParams s_params;
s_params.search_length = config[IndexParams::search_length];
s_params.k = config[meta::TOPK];
@ -98,7 +96,7 @@ NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) {
std::lock_guard<std::mutex> lk(mutex_);
// index_->ori_data_ = (float*) data_.get();
index_->Search(reinterpret_cast<const float*>(p_data), reinterpret_cast<float*>(data_.get()), rows, dim,
topK, p_dist, p_id, s_params, blacklist);
topK, p_dist, p_id, s_params, bitset);
}
auto ret_ds = std::make_shared<Dataset>();

View File

@ -59,7 +59,7 @@ class NSG_NM : public VecIndex {
}
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override;
int64_t
Count() override;

View File

@ -118,7 +118,8 @@ GPUIVF_NM::SerializeImpl(const IndexType& type) {
}
void
GPUIVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
GPUIVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config,
const faiss::ConcurrentBitsetPtr& bitset) {
std::lock_guard<std::mutex> lk(mutex_);
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
@ -132,7 +133,7 @@ GPUIVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances,
int64_t dim = device_index->d;
for (int64_t i = 0; i < n; i += block_size) {
int64_t search_size = (n - i > block_size) ? block_size : (n - i);
device_index->search(search_size, data + i * dim, k, distances + i * k, labels + i * k, bitset_);
device_index->search(search_size, data + i * dim, k, distances + i * k, labels + i * k, bitset);
}
} else {
KNOWHERE_THROW_MSG("Not a GpuIndexIVF type.");

View File

@ -51,7 +51,8 @@ class GPUIVF_NM : public IVF, public GPUIndex {
SerializeImpl(const IndexType&) override;
void
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&,
const faiss::ConcurrentBitsetPtr& bitset) override;
protected:
uint8_t* arranged_data;

View File

@ -53,7 +53,7 @@ TEST_P(AnnoyTest, annoy_basic) {
// null faiss index
{
ASSERT_ANY_THROW(index_->Train(base_dataset, conf));
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
ASSERT_ANY_THROW(index_->Query(query_dataset, conf, nullptr));
ASSERT_ANY_THROW(index_->Serialize(conf));
ASSERT_ANY_THROW(index_->Add(base_dataset, conf));
ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf));
@ -65,7 +65,7 @@ TEST_P(AnnoyTest, annoy_basic) {
ASSERT_EQ(index_->Count(), nb);
ASSERT_EQ(index_->Dim(), dim);
auto result = index_->Query(query_dataset, conf);
auto result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, k);
/*
@ -104,11 +104,10 @@ TEST_P(AnnoyTest, annoy_delete) {
bitset->set(i);
}
auto result1 = index_->Query(query_dataset, conf);
auto result1 = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result1, nq, k);
index_->SetBlacklist(bitset);
auto result2 = index_->Query(query_dataset, conf);
auto result2 = index_->Query(query_dataset, conf, bitset);
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
/*
@ -200,7 +199,7 @@ TEST_P(AnnoyTest, annoy_serialize) {
index_->Load(binaryset);
ASSERT_EQ(index_->Count(), nb);
ASSERT_EQ(index_->Dim(), dim);
auto result = index_->Query(query_dataset, conf);
auto result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
}
}

View File

@ -52,7 +52,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
// null faiss index
{
ASSERT_ANY_THROW(index_->Serialize(conf));
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
ASSERT_ANY_THROW(index_->Query(query_dataset, conf, nullptr));
ASSERT_ANY_THROW(index_->Add(nullptr, conf));
ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf));
}
@ -63,14 +63,14 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
EXPECT_EQ(index_->Dim(), dim);
ASSERT_TRUE(index_->GetRawVectors() != nullptr);
ASSERT_TRUE(index_->GetRawIds() != nullptr);
auto result = index_->Query(query_dataset, conf);
auto result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, k);
// PrintResult(result, nq, k);
auto binaryset = index_->Serialize(conf);
auto new_index = std::make_shared<milvus::knowhere::BinaryIDMAP>();
new_index->Load(binaryset);
auto result2 = new_index->Query(query_dataset, conf);
auto result2 = new_index->Query(query_dataset, conf, nullptr);
AssertAnns(result2, nq, k);
// PrintResult(re_result, nq, k);
@ -78,9 +78,8 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
for (int64_t i = 0; i < nq; ++i) {
concurrent_bitset_ptr->set(i);
}
index_->SetBlacklist(concurrent_bitset_ptr);
auto result_bs_1 = index_->Query(query_dataset, conf);
auto result_bs_1 = index_->Query(query_dataset, conf, concurrent_bitset_ptr);
AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL);
// auto result4 = index_->SearchById(id_dataset, conf);
@ -107,7 +106,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_serialize) {
// serialize index
index_->Train(base_dataset, conf);
index_->AddWithoutIds(base_dataset, milvus::knowhere::Config());
auto re_result = index_->Query(query_dataset, conf);
auto re_result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(re_result, nq, k);
// PrintResult(re_result, nq, k);
EXPECT_EQ(index_->Count(), nb);
@ -126,7 +125,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_serialize) {
index_->Load(binaryset);
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
auto result = index_->Query(query_dataset, conf);
auto result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, k);
// PrintResult(result, nq, k);
}

View File

@ -63,7 +63,7 @@ TEST_P(BinaryIVFTest, binaryivf_basic) {
// null faiss index
{
ASSERT_ANY_THROW(index_->Serialize(conf));
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
ASSERT_ANY_THROW(index_->Query(query_dataset, conf, nullptr));
ASSERT_ANY_THROW(index_->Add(nullptr, conf));
ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf));
}
@ -72,7 +72,7 @@ TEST_P(BinaryIVFTest, binaryivf_basic) {
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
auto result = index_->Query(query_dataset, conf);
auto result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
@ -80,13 +80,12 @@ TEST_P(BinaryIVFTest, binaryivf_basic) {
for (int64_t i = 0; i < nq; ++i) {
concurrent_bitset_ptr->set(i);
}
index_->SetBlacklist(concurrent_bitset_ptr);
auto result2 = index_->Query(query_dataset, conf);
auto result2 = index_->Query(query_dataset, conf, concurrent_bitset_ptr);
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
#if 0
auto result3 = index_->QueryById(id_dataset, conf);
auto result3 = index_->QueryById(id_dataset, conf, nullptr);
AssertAnns(result3, nq, k, CheckMode::CHECK_NOT_EQUAL);
auto result4 = index_->GetVectorById(xid_dataset, conf);
@ -145,7 +144,7 @@ TEST_P(BinaryIVFTest, binaryivf_serialize) {
index_->Load(binaryset);
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
auto result = index_->Query(query_dataset, conf);
auto result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
}

View File

@ -67,7 +67,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
{
for (int i = 0; i < 3; ++i) {
auto gpu_idx = cpu_idx->CopyCpuToGpu(DEVICEID, conf);
auto result = gpu_idx->Query(query_dataset, conf);
auto result = gpu_idx->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
}
@ -83,7 +83,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
auto pair = cpu_idx->CopyCpuToGpuWithQuantizer(DEVICEID, conf);
auto gpu_idx = pair.first;
auto result = gpu_idx->Query(query_dataset, conf);
auto result = gpu_idx->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
@ -93,7 +93,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
hybrid_idx->Load(binaryset);
auto quantization = hybrid_idx->LoadQuantizer(quantizer_conf);
auto new_idx = hybrid_idx->LoadData(quantization, quantizer_conf);
auto result = new_idx->Query(query_dataset, conf);
auto result = new_idx->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
}
@ -112,7 +112,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
hybrid_idx->Load(binaryset);
hybrid_idx->SetQuantizer(quantization);
auto result = hybrid_idx->Query(query_dataset, conf);
auto result = hybrid_idx->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
hybrid_idx->UnsetQuantizer();

View File

@ -74,7 +74,7 @@ TEST_F(GPURESTEST, copyandsearch) {
auto conf = ParamGenerator::GetInstance().Gen(index_type_);
index_->Train(base_dataset, conf);
index_->Add(base_dataset, conf);
auto result = index_->Query(query_dataset, conf);
auto result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, k);
index_->SetIndexSize(nb * dim * sizeof(float));
@ -88,7 +88,7 @@ TEST_F(GPURESTEST, copyandsearch) {
auto search_func = [&] {
// TimeRecorder tc("search&load");
for (int i = 0; i < search_count; ++i) {
search_idx->Query(query_dataset, conf);
search_idx->Query(query_dataset, conf, nullptr);
// if (i > search_count - 6 || i == 0)
// tc.RecordSection("search once");
}
@ -107,7 +107,7 @@ TEST_F(GPURESTEST, copyandsearch) {
milvus::knowhere::TimeRecorder tc("Basic");
milvus::knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, milvus::knowhere::Config());
tc.RecordSection("Copy to gpu once");
search_idx->Query(query_dataset, conf);
search_idx->Query(query_dataset, conf, nullptr);
tc.RecordSection("Search once");
search_func();
tc.RecordSection("Search total cost");
@ -145,7 +145,7 @@ TEST_F(GPURESTEST, trainandsearch) {
};
auto search_stage = [&](milvus::knowhere::VecIndexPtr& search_idx) {
for (int i = 0; i < search_count; ++i) {
auto result = search_idx->Query(query_dataset, conf);
auto result = search_idx->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, k);
}
};

View File

@ -78,7 +78,7 @@ TEST_P(HNSWTest, HNSW_basic) {
index_->Load(bs);
auto result = index_->Query(query_dataset, conf);
auto result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, k);
}
@ -108,11 +108,10 @@ TEST_P(HNSWTest, HNSW_delete) {
index_->Load(bs);
auto result1 = index_->Query(query_dataset, conf);
auto result1 = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result1, nq, k);
index_->SetBlacklist(bitset);
auto result2 = index_->Query(query_dataset, conf);
auto result2 = index_->Query(query_dataset, conf, bitset);
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
/*

View File

@ -73,7 +73,7 @@ TEST_P(IDMAPTest, idmap_basic) {
// null faiss index
{
ASSERT_ANY_THROW(index_->Serialize(conf));
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
ASSERT_ANY_THROW(index_->Query(query_dataset, conf, nullptr));
ASSERT_ANY_THROW(index_->Add(nullptr, conf));
ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf));
}
@ -84,7 +84,7 @@ TEST_P(IDMAPTest, idmap_basic) {
EXPECT_EQ(index_->Dim(), dim);
ASSERT_TRUE(index_->GetRawVectors() != nullptr);
ASSERT_TRUE(index_->GetRawIds() != nullptr);
auto result = index_->Query(query_dataset, conf);
auto result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, k);
// PrintResult(result, nq, k);
@ -98,7 +98,7 @@ TEST_P(IDMAPTest, idmap_basic) {
auto binaryset = index_->Serialize(conf);
auto new_index = std::make_shared<milvus::knowhere::IDMAP>();
new_index->Load(binaryset);
auto result2 = new_index->Query(query_dataset, conf);
auto result2 = new_index->Query(query_dataset, conf, nullptr);
AssertAnns(result2, nq, k);
// PrintResult(re_result, nq, k);
@ -114,9 +114,8 @@ TEST_P(IDMAPTest, idmap_basic) {
for (int64_t i = 0; i < nq; ++i) {
concurrent_bitset_ptr->set(i);
}
index_->SetBlacklist(concurrent_bitset_ptr);
auto result_bs_1 = index_->Query(query_dataset, conf);
auto result_bs_1 = index_->Query(query_dataset, conf, concurrent_bitset_ptr);
AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL);
#if 0
@ -153,7 +152,7 @@ TEST_P(IDMAPTest, idmap_serialize) {
#endif
}
auto re_result = index_->Query(query_dataset, conf);
auto re_result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(re_result, nq, k);
// PrintResult(re_result, nq, k);
EXPECT_EQ(index_->Count(), nb);
@ -172,7 +171,7 @@ TEST_P(IDMAPTest, idmap_serialize) {
index_->Load(binaryset);
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
auto result = index_->Query(query_dataset, conf);
auto result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, k);
// PrintResult(result, nq, k);
}
@ -192,7 +191,7 @@ TEST_P(IDMAPTest, idmap_copy) {
EXPECT_EQ(index_->Dim(), dim);
ASSERT_TRUE(index_->GetRawVectors() != nullptr);
ASSERT_TRUE(index_->GetRawIds() != nullptr);
auto result = index_->Query(query_dataset, conf);
auto result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, k);
// PrintResult(result, nq, k);
@ -207,7 +206,7 @@ TEST_P(IDMAPTest, idmap_copy) {
// cpu to gpu
ASSERT_ANY_THROW(milvus::knowhere::cloner::CopyCpuToGpu(index_, -1, conf));
auto clone_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, conf);
auto clone_result = clone_index->Query(query_dataset, conf);
auto clone_result = clone_index->Query(query_dataset, conf,nullptr);
AssertAnns(clone_result, nq, k);
ASSERT_THROW({ std::static_pointer_cast<milvus::knowhere::GPUIDMAP>(clone_index)->GetRawVectors(); },
milvus::knowhere::KnowhereException);
@ -221,7 +220,7 @@ TEST_P(IDMAPTest, idmap_copy) {
auto binary = clone_index->Serialize(conf);
clone_index->Load(binary);
auto new_result = clone_index->Query(query_dataset, conf);
auto new_result = clone_index->Query(query_dataset, conf, nullptr);
AssertAnns(new_result, nq, k);
// auto clone_gpu_idx = clone_index->Clone();
@ -230,7 +229,7 @@ TEST_P(IDMAPTest, idmap_copy) {
// gpu to cpu
auto host_index = milvus::knowhere::cloner::CopyGpuToCpu(clone_index, conf);
auto host_result = host_index->Query(query_dataset, conf);
auto host_result = host_index->Query(query_dataset, conf, nullptr);
AssertAnns(host_result, nq, k);
ASSERT_TRUE(std::static_pointer_cast<milvus::knowhere::IDMAP>(host_index)->GetRawVectors() != nullptr);
ASSERT_TRUE(std::static_pointer_cast<milvus::knowhere::IDMAP>(host_index)->GetRawIds() != nullptr);
@ -239,7 +238,7 @@ TEST_P(IDMAPTest, idmap_copy) {
auto device_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, conf);
auto new_device_index =
std::static_pointer_cast<milvus::knowhere::GPUIDMAP>(device_index)->CopyGpuToGpu(DEVICEID, conf);
auto device_result = new_device_index->Query(query_dataset, conf);
auto device_result = new_device_index->Query(query_dataset, conf, nullptr);
AssertAnns(device_result, nq, k);
}
}

View File

@ -104,7 +104,7 @@ TEST_P(IVFTest, ivf_basic_cpu) {
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
auto result = index_->Query(query_dataset, conf_);
auto result = index_->Query(query_dataset, conf_, nullptr);
AssertAnns(result, nq, k);
// PrintResult(result, nq, k);
@ -127,9 +127,8 @@ TEST_P(IVFTest, ivf_basic_cpu) {
for (int64_t i = 0; i < nq; ++i) {
concurrent_bitset_ptr->set(i);
}
index_->SetBlacklist(concurrent_bitset_ptr);
auto result_bs_1 = index_->Query(query_dataset, conf_);
auto result_bs_1 = index_->Query(query_dataset, conf_, concurrent_bitset_ptr);
AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL);
// PrintResult(result, nq, k);
@ -163,7 +162,7 @@ TEST_P(IVFTest, ivf_basic_gpu) {
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
auto result = index_->Query(query_dataset, conf_);
auto result = index_->Query(query_dataset, conf_, nullptr);
AssertAnns(result, nq, k);
// PrintResult(result, nq, k);
@ -171,9 +170,8 @@ TEST_P(IVFTest, ivf_basic_gpu) {
for (int64_t i = 0; i < nq; ++i) {
concurrent_bitset_ptr->set(i);
}
index_->SetBlacklist(concurrent_bitset_ptr);
auto result_bs_1 = index_->Query(query_dataset, conf_);
auto result_bs_1 = index_->Query(query_dataset, conf_, concurrent_bitset_ptr);
AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL);
// PrintResult(result, nq, k);
@ -210,7 +208,7 @@ TEST_P(IVFTest, ivf_serialize) {
index_->Load(binaryset);
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
auto result = index_->Query(query_dataset, conf_);
auto result = index_->Query(query_dataset, conf_, nullptr);
AssertAnns(result, nq, conf_[milvus::knowhere::meta::TOPK]);
}
}
@ -228,7 +226,7 @@ TEST_P(IVFTest, clone_test) {
/* set peseodo index size, avoid throw exception */
index_->SetIndexSize(nq * dim * sizeof(float));
auto result = index_->Query(query_dataset, conf_);
auto result = index_->Query(query_dataset, conf_, nullptr);
AssertAnns(result, nq, conf_[milvus::knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
@ -248,7 +246,7 @@ TEST_P(IVFTest, clone_test) {
if (index_mode_ == milvus::knowhere::IndexMode::MODE_GPU) {
EXPECT_NO_THROW({
auto clone_index = milvus::knowhere::cloner::CopyGpuToCpu(index_, milvus::knowhere::Config());
auto clone_result = clone_index->Query(query_dataset, conf_);
auto clone_result = clone_index->Query(query_dataset, conf_, nullptr);
AssertEqual(result, clone_result);
std::cout << "clone G <=> C [" << index_type_ << "] success" << std::endl;
});
@ -267,7 +265,7 @@ TEST_P(IVFTest, clone_test) {
if (index_type_ != milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H) {
EXPECT_NO_THROW({
auto clone_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, milvus::knowhere::Config());
auto clone_result = clone_index->Query(query_dataset, conf_);
auto clone_result = clone_index->Query(query_dataset, conf_, nullptr);
AssertEqual(result, clone_result);
std::cout << "clone C <=> G [" << index_type_ << "] success" << std::endl;
});
@ -284,7 +282,7 @@ TEST_P(IVFTest, gpu_seal_test) {
}
assert(!xb.empty());
ASSERT_ANY_THROW(index_->Query(query_dataset, conf_));
ASSERT_ANY_THROW(index_->Query(query_dataset, conf_, nullptr));
ASSERT_ANY_THROW(index_->Seal());
index_->Train(base_dataset, conf_);
@ -295,15 +293,15 @@ TEST_P(IVFTest, gpu_seal_test) {
/* set peseodo index size, avoid throw exception */
index_->SetIndexSize(nq * dim * sizeof(float));
auto result = index_->Query(query_dataset, conf_);
auto result = index_->Query(query_dataset, conf_, nullptr);
AssertAnns(result, nq, conf_[milvus::knowhere::meta::TOPK]);
fiu_init(0);
fiu_enable("IVF.Search.throw_std_exception", 1, nullptr, 0);
ASSERT_ANY_THROW(index_->Query(query_dataset, conf_));
ASSERT_ANY_THROW(index_->Query(query_dataset, conf_, nullptr));
fiu_disable("IVF.Search.throw_std_exception");
fiu_enable("IVF.Search.throw_faiss_exception", 1, nullptr, 0);
ASSERT_ANY_THROW(index_->Query(query_dataset, conf_));
ASSERT_ANY_THROW(index_->Query(query_dataset, conf_, nullptr));
fiu_disable("IVF.Search.throw_faiss_exception");
auto cpu_idx = milvus::knowhere::cloner::CopyGpuToCpu(index_, milvus::knowhere::Config());
@ -344,7 +342,7 @@ TEST_P(IVFTest, invalid_gpu_source) {
fiu_disable("GPUIVF.SerializeImpl.throw_exception");
fiu_enable("GPUIVF.search_impl.invald_index", 1, nullptr, 0);
ASSERT_ANY_THROW(index_->Query(base_dataset, invalid_conf));
ASSERT_ANY_THROW(index_->Query(base_dataset, invalid_conf, nullptr));
fiu_disable("GPUIVF.search_impl.invald_index");
auto ivf_index = std::dynamic_pointer_cast<milvus::knowhere::GPUIVF>(index_);

View File

@ -100,7 +100,7 @@ TEST_P(IVFNMCPUTest, ivf_basic_cpu) {
bs.Append(RAW_DATA, bptr);
index_->Load(bs);
auto result = index_->Query(query_dataset, conf_);
auto result = index_->Query(query_dataset, conf_, nullptr);
AssertAnns(result, nq, k);
#ifdef MILVUS_GPU_VERSION
@ -108,7 +108,7 @@ TEST_P(IVFNMCPUTest, ivf_basic_cpu) {
{
EXPECT_NO_THROW({
auto clone_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, conf_);
auto clone_result = clone_index->Query(query_dataset, conf_);
auto clone_result = clone_index->Query(query_dataset, conf_, nullptr);
AssertAnns(clone_result, nq, k);
std::cout << "clone C <=> G [" << index_type_ << "] success" << std::endl;
});
@ -120,9 +120,8 @@ TEST_P(IVFNMCPUTest, ivf_basic_cpu) {
for (int64_t i = 0; i < nq; ++i) {
concurrent_bitset_ptr->set(i);
}
index_->SetBlacklist(concurrent_bitset_ptr);
auto result_bs_1 = index_->Query(query_dataset, conf_);
auto result_bs_1 = index_->Query(query_dataset, conf_, concurrent_bitset_ptr);
AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL);
#ifdef MILVUS_GPU_VERSION

View File

@ -101,7 +101,7 @@ TEST_F(IVFNMGPUTest, ivf_basic_gpu) {
SERIALIZE_AND_LOAD(index_);
auto result = index_->Query(query_dataset, conf_);
auto result = index_->Query(query_dataset, conf_, nullptr);
AssertAnns(result, nq, k);
auto AssertEqual = [&](milvus::knowhere::DatasetPtr p1, milvus::knowhere::DatasetPtr p2) {
@ -118,7 +118,7 @@ TEST_F(IVFNMGPUTest, ivf_basic_gpu) {
EXPECT_NO_THROW({
auto clone_index = milvus::knowhere::cloner::CopyGpuToCpu(index_, conf_);
SERIALIZE_AND_LOAD(clone_index);
auto clone_result = clone_index->Query(query_dataset, conf_);
auto clone_result = clone_index->Query(query_dataset, conf_, nullptr);
AssertEqual(result, clone_result);
std::cout << "clone G <=> C [" << index_type_ << "] success" << std::endl;
});
@ -128,9 +128,8 @@ TEST_F(IVFNMGPUTest, ivf_basic_gpu) {
for (int64_t i = 0; i < nq; ++i) {
concurrent_bitset_ptr->set(i);
}
index_->SetBlacklist(concurrent_bitset_ptr);
auto result_bs_1 = index_->Query(query_dataset, conf_);
auto result_bs_1 = index_->Query(query_dataset, conf_, concurrent_bitset_ptr);
AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL);
milvus::knowhere::FaissGpuResourceMgr::GetInstance().Dump();

View File

@ -80,7 +80,7 @@ TEST_F(NSGInterfaceTest, basic_test) {
// untrained index
{
ASSERT_ANY_THROW(index_->Serialize(search_conf));
ASSERT_ANY_THROW(index_->Query(query_dataset, search_conf));
ASSERT_ANY_THROW(index_->Query(query_dataset, search_conf, nullptr));
ASSERT_ANY_THROW(index_->Add(base_dataset, search_conf));
ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, search_conf));
}
@ -101,7 +101,7 @@ TEST_F(NSGInterfaceTest, basic_test) {
index_->Load(bs);
auto result = index_->Query(query_dataset, search_conf);
auto result = index_->Query(query_dataset, search_conf, nullptr);
AssertAnns(result, nq, k);
/* test NSG GPU train */
@ -122,7 +122,7 @@ TEST_F(NSGInterfaceTest, basic_test) {
new_index_1->Load(bs);
auto new_result_1 = new_index_1->Query(query_dataset, search_conf);
auto new_result_1 = new_index_1->Query(query_dataset, search_conf, nullptr);
AssertAnns(new_result_1, nq, k);
ASSERT_EQ(index_->Count(), nb);
@ -163,7 +163,7 @@ TEST_F(NSGInterfaceTest, delete_test) {
index_->Load(bs);
auto result = index_->Query(query_dataset, search_conf);
auto result = index_->Query(query_dataset, search_conf, nullptr);
AssertAnns(result, nq, k);
ASSERT_EQ(index_->Count(), nb);
@ -176,9 +176,6 @@ TEST_F(NSGInterfaceTest, delete_test) {
auto I_before = result->Get<int64_t*>(milvus::knowhere::meta::IDS);
// search xq with delete
index_->SetBlacklist(bitset);
// Serialize and Load before Query
bs = index_->Serialize(search_conf);
@ -191,7 +188,7 @@ TEST_F(NSGInterfaceTest, delete_test) {
bs.Append(RAW_DATA, bptr);
index_->Load(bs);
auto result_after = index_->Query(query_dataset, search_conf);
auto result_after = index_->Query(query_dataset, search_conf, bitset);
AssertAnns(result_after, nq, k, CheckMode::CHECK_NOT_EQUAL);
auto I_after = result_after->Get<int64_t*>(milvus::knowhere::meta::IDS);

View File

@ -52,7 +52,7 @@ TEST_P(RHNSWFlatTest, HNSW_basic) {
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
auto result1 = index_->Query(query_dataset, conf);
auto result1 = index_->Query(query_dataset, conf, nullptr);
// AssertAnns(result1, nq, k);
// Serialize and Load before Query
@ -62,7 +62,7 @@ TEST_P(RHNSWFlatTest, HNSW_basic) {
tmp_index->Load(bs);
auto result2 = tmp_index->Query(query_dataset, conf);
auto result2 = tmp_index->Query(query_dataset, conf, nullptr);
// AssertAnns(result2, nq, k);
}
@ -79,11 +79,10 @@ TEST_P(RHNSWFlatTest, HNSW_delete) {
bitset->set(i);
}
auto result1 = index_->Query(query_dataset, conf);
auto result1 = index_->Query(query_dataset, conf, nullptr);
// AssertAnns(result1, nq, k);
index_->SetBlacklist(bitset);
auto result2 = index_->Query(query_dataset, conf);
auto result2 = index_->Query(query_dataset, conf, bitset);
// AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
/*
@ -152,7 +151,7 @@ TEST_P(RHNSWFlatTest, HNSW_serialize) {
new_idx->Load(binaryset);
EXPECT_EQ(new_idx->Count(), nb);
EXPECT_EQ(new_idx->Dim(), dim);
auto result = new_idx->Query(query_dataset, conf);
auto result = new_idx->Query(query_dataset, conf, nullptr);
// AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
}
}

View File

@ -54,14 +54,14 @@ TEST_P(RHNSWPQTest, HNSW_basic) {
// Serialize and Load before Query
milvus::knowhere::BinarySet bs = index_->Serialize(conf);
auto result1 = index_->Query(query_dataset, conf);
auto result1 = index_->Query(query_dataset, conf, nullptr);
// AssertAnns(result1, nq, k);
auto tmp_index = std::make_shared<milvus::knowhere::IndexRHNSWPQ>();
tmp_index->Load(bs);
auto result2 = tmp_index->Query(query_dataset, conf);
auto result2 = tmp_index->Query(query_dataset, conf, nullptr);
// AssertAnns(result2, nq, k);
}
@ -78,11 +78,10 @@ TEST_P(RHNSWPQTest, HNSW_delete) {
bitset->set(i);
}
auto result1 = index_->Query(query_dataset, conf);
auto result1 = index_->Query(query_dataset, conf, nullptr);
// AssertAnns(result1, nq, k);
index_->SetBlacklist(bitset);
auto result2 = index_->Query(query_dataset, conf);
auto result2 = index_->Query(query_dataset, conf, bitset);
// AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
/*
@ -142,7 +141,7 @@ TEST_P(RHNSWPQTest, HNSW_serialize) {
new_idx->Load(binaryset);
EXPECT_EQ(new_idx->Count(), nb);
EXPECT_EQ(new_idx->Dim(), dim);
auto result = new_idx->Query(query_dataset, conf);
auto result = new_idx->Query(query_dataset, conf, nullptr);
// AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
}
}

View File

@ -55,14 +55,14 @@ TEST_P(RHNSWSQ8Test, HNSW_basic) {
// Serialize and Load before Query
milvus::knowhere::BinarySet bs = index_->Serialize(conf);
auto result1 = index_->Query(query_dataset, conf);
auto result1 = index_->Query(query_dataset, conf, nullptr);
// AssertAnns(result1, nq, k);
auto tmp_index = std::make_shared<milvus::knowhere::IndexRHNSWSQ>();
tmp_index->Load(bs);
auto result2 = tmp_index->Query(query_dataset, conf);
auto result2 = tmp_index->Query(query_dataset, conf, nullptr);
// AssertAnns(result2, nq, k);
}
@ -79,11 +79,10 @@ TEST_P(RHNSWSQ8Test, HNSW_delete) {
bitset->set(i);
}
auto result1 = index_->Query(query_dataset, conf);
auto result1 = index_->Query(query_dataset, conf, nullptr);
// AssertAnns(result1, nq, k);
index_->SetBlacklist(bitset);
auto result2 = index_->Query(query_dataset, conf);
auto result2 = index_->Query(query_dataset, conf, bitset);
// AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
/*
@ -143,7 +142,7 @@ TEST_P(RHNSWSQ8Test, HNSW_serialize) {
new_idx->Load(binaryset);
EXPECT_EQ(new_idx->Count(), nb);
EXPECT_EQ(new_idx->Dim(), dim);
auto result = new_idx->Query(query_dataset, conf);
auto result = new_idx->Query(query_dataset, conf, nullptr);
// AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
}
}

View File

@ -68,7 +68,7 @@ TEST_P(SPTAGTest, sptag_basic) {
index_->BuildAll(base_dataset, conf);
// index_->Add(base_dataset, conf);
auto result = index_->Query(query_dataset, conf);
auto result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, k);
{
@ -100,7 +100,7 @@ TEST_P(SPTAGTest, sptag_serialize) {
auto binaryset = index_->Serialize();
auto new_index = std::make_shared<milvus::knowhere::CPUSPTAGRNG>(IndexType);
new_index->Load(binaryset);
auto result = new_index->Query(query_dataset, conf);
auto result = new_index->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, k);
PrintResult(result, nq, k);
ASSERT_EQ(new_index->Count(), nb);
@ -136,7 +136,7 @@ TEST_P(SPTAGTest, sptag_serialize) {
auto new_index = std::make_shared<milvus::knowhere::CPUSPTAGRNG>(IndexType);
new_index->Load(load_data_list);
auto result = new_index->Query(query_dataset, conf);
auto result = new_index->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, k);
PrintResult(result, nq, k);
}

View File

@ -82,7 +82,7 @@ TEST_P(VecIndexTest, basic) {
EXPECT_EQ(index_->index_type(), index_type_);
EXPECT_EQ(index_->index_mode(), index_mode_);
auto result = index_->Query(query_dataset, conf);
auto result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
PrintResult(result, nq, k);
}
@ -93,7 +93,7 @@ TEST_P(VecIndexTest, serialize) {
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->index_type(), index_type_);
EXPECT_EQ(index_->index_mode(), index_mode_);
auto result = index_->Query(query_dataset, conf);
auto result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
auto binaryset = index_->Serialize();
@ -103,7 +103,7 @@ TEST_P(VecIndexTest, serialize) {
EXPECT_EQ(index_->Count(), new_index->Count());
EXPECT_EQ(index_->index_type(), new_index->index_type());
EXPECT_EQ(index_->index_mode(), new_index->index_mode());
auto new_result = new_index_->Query(query_dataset, conf);
auto new_result = new_index_->Query(query_dataset, conf, nullptr);
AssertAnns(new_result, nq, conf[milvus::knowhere::meta::TOPK]);
}