mirror of https://github.com/milvus-io/milvus.git
Single query by distance (#4648)
* IndexFlat::QueryByDistance unittest passed Signed-off-by: cmli <chengming.li@zilliz.com> * change the datastructure of DynamicResultSet Signed-off-by: cmli <chengming.li@zilliz.com> * unittest of IndexFlat::QueryByDistance and IndexBinaryFalt::QueryByDistance has passed Signed-off-by: cmli <chengming.li@zilliz.com> * update the datastructure of DynamicResultSet... Signed-off-by: cmli <chengming.li@zilliz.com> * update the interface and the data structure of dynamic result set, compile passed Signed-off-by: cmli <chengming.li@zilliz.com> * fix unittest, add test for new data structure, to be debug... Signed-off-by: cmli <chengming.li@zilliz.com> * unittest passed Signed-off-by: cmli <chengming.li@zilliz.com> * add MapUids Signed-off-by: cmli <chengming.li@zilliz.com> * update code by the advise from review Signed-off-by: cmli <chengming.li@zilliz.com> Co-authored-by: cmli <chengming.li@zilliz.com>pull/4659/head
parent
e644680dd6
commit
d64b1dd165
|
@ -57,6 +57,7 @@ set(vector_index_srcs
|
||||||
knowhere/index/vector_index/adapter/VectorAdapter.cpp
|
knowhere/index/vector_index/adapter/VectorAdapter.cpp
|
||||||
knowhere/index/vector_index/helpers/FaissIO.cpp
|
knowhere/index/vector_index/helpers/FaissIO.cpp
|
||||||
knowhere/index/vector_index/helpers/IndexParameter.cpp
|
knowhere/index/vector_index/helpers/IndexParameter.cpp
|
||||||
|
knowhere/index/vector_index/helpers/DynamicResultSet.cpp
|
||||||
knowhere/index/vector_index/impl/nsg/Distance.cpp
|
knowhere/index/vector_index/impl/nsg/Distance.cpp
|
||||||
knowhere/index/vector_index/impl/nsg/NSG.cpp
|
knowhere/index/vector_index/impl/nsg/NSG.cpp
|
||||||
knowhere/index/vector_index/impl/nsg/NSGHelper.cpp
|
knowhere/index/vector_index/impl/nsg/NSGHelper.cpp
|
||||||
|
|
|
@ -65,6 +65,38 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const fa
|
||||||
return ret_ds;
|
return ret_ds;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DynamicResultSegment
|
||||||
|
BinaryIDMAP::QueryByDistance(const milvus::knowhere::DatasetPtr& dataset, const milvus::knowhere::Config& config,
|
||||||
|
const faiss::BitsetView& bitset) {
|
||||||
|
if (!index_) {
|
||||||
|
KNOWHERE_THROW_MSG("index not initialize");
|
||||||
|
}
|
||||||
|
GET_TENSOR_DATA(dataset)
|
||||||
|
if (rows != 1) {
|
||||||
|
KNOWHERE_THROW_MSG("QueryByDistance only accept nq = 1!");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto default_type = index_->metric_type;
|
||||||
|
if (config.contains(Metric::TYPE)) {
|
||||||
|
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||||
|
}
|
||||||
|
std::vector<faiss::RangeSearchPartialResult*> res;
|
||||||
|
DynamicResultSegment result;
|
||||||
|
auto radius = config[IndexParams::range_search_radius].get<int>();
|
||||||
|
auto buffer_size = config.contains(IndexParams::range_search_buffer_size)
|
||||||
|
? config[IndexParams::range_search_buffer_size].get<size_t>()
|
||||||
|
: 16384;
|
||||||
|
auto real_idx = dynamic_cast<faiss::IndexBinaryFlat*>(index_.get());
|
||||||
|
if (real_idx == nullptr) {
|
||||||
|
KNOWHERE_THROW_MSG("Cannot dynamic_cast the index to faiss::IndexBinaryFlat type!");
|
||||||
|
}
|
||||||
|
real_idx->range_search(rows, reinterpret_cast<const uint8_t*>(p_data), radius, res, buffer_size, bitset);
|
||||||
|
ExchangeDataset(result, res);
|
||||||
|
MapUids(result);
|
||||||
|
index_->metric_type = default_type;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
int64_t
|
int64_t
|
||||||
BinaryIDMAP::Count() {
|
BinaryIDMAP::Count() {
|
||||||
if (!index_) {
|
if (!index_) {
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h"
|
#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h"
|
||||||
#include "knowhere/index/vector_index/VecIndex.h"
|
#include "knowhere/index/vector_index/VecIndex.h"
|
||||||
|
#include "knowhere/index/vector_index/helpers/DynamicResultSet.h"
|
||||||
|
|
||||||
namespace milvus {
|
namespace milvus {
|
||||||
namespace knowhere {
|
namespace knowhere {
|
||||||
|
@ -46,6 +47,9 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
|
||||||
DatasetPtr
|
DatasetPtr
|
||||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView& bitset) override;
|
Query(const DatasetPtr&, const Config&, const faiss::BitsetView& bitset) override;
|
||||||
|
|
||||||
|
DynamicResultSegment
|
||||||
|
QueryByDistance(const DatasetPtr& dataset, const Config& config, const faiss::BitsetView& bitset);
|
||||||
|
|
||||||
int64_t
|
int64_t
|
||||||
Count() override;
|
Count() override;
|
||||||
|
|
||||||
|
|
|
@ -98,6 +98,41 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::B
|
||||||
return ret_ds;
|
return ret_ds;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DynamicResultSegment
|
||||||
|
IDMAP::QueryByDistance(const milvus::knowhere::DatasetPtr& dataset, const milvus::knowhere::Config& config,
|
||||||
|
const faiss::BitsetView& bitset) {
|
||||||
|
if (!index_) {
|
||||||
|
KNOWHERE_THROW_MSG("index not initialize");
|
||||||
|
}
|
||||||
|
GET_TENSOR_DATA(dataset)
|
||||||
|
if (rows != 1) {
|
||||||
|
KNOWHERE_THROW_MSG("QueryByDistance only accept nq = 1!");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto default_type = index_->metric_type;
|
||||||
|
if (config.contains(Metric::TYPE)) {
|
||||||
|
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||||
|
}
|
||||||
|
std::vector<faiss::RangeSearchPartialResult*> res;
|
||||||
|
DynamicResultSegment result;
|
||||||
|
auto radius = config[IndexParams::range_search_radius].get<float>();
|
||||||
|
auto buffer_size = config.contains(IndexParams::range_search_buffer_size)
|
||||||
|
? config[IndexParams::range_search_buffer_size].get<size_t>()
|
||||||
|
: 16384;
|
||||||
|
auto real_idx = dynamic_cast<faiss::IndexFlat*>(index_.get());
|
||||||
|
if (real_idx == nullptr) {
|
||||||
|
KNOWHERE_THROW_MSG("Cannot dynamic_cast the index to faiss::IndexFlat type!");
|
||||||
|
}
|
||||||
|
if (index_->metric_type == faiss::MetricType::METRIC_L2) {
|
||||||
|
radius *= radius;
|
||||||
|
}
|
||||||
|
real_idx->range_search(rows, reinterpret_cast<const float*>(p_data), radius, res, buffer_size, bitset);
|
||||||
|
ExchangeDataset(result, res);
|
||||||
|
MapUids(result);
|
||||||
|
index_->metric_type = default_type;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
int64_t
|
int64_t
|
||||||
IDMAP::Count() {
|
IDMAP::Count() {
|
||||||
if (!index_) {
|
if (!index_) {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
#include "knowhere/index/vector_index/FaissBaseIndex.h"
|
#include "knowhere/index/vector_index/FaissBaseIndex.h"
|
||||||
#include "knowhere/index/vector_index/VecIndex.h"
|
#include "knowhere/index/vector_index/VecIndex.h"
|
||||||
|
#include "knowhere/index/vector_index/helpers/DynamicResultSet.h"
|
||||||
|
|
||||||
namespace milvus {
|
namespace milvus {
|
||||||
namespace knowhere {
|
namespace knowhere {
|
||||||
|
@ -45,6 +46,9 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
|
||||||
DatasetPtr
|
DatasetPtr
|
||||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView&) override;
|
Query(const DatasetPtr&, const Config&, const faiss::BitsetView&) override;
|
||||||
|
|
||||||
|
DynamicResultSegment
|
||||||
|
QueryByDistance(const DatasetPtr& dataset, const Config& config, const faiss::BitsetView& bitset);
|
||||||
|
|
||||||
int64_t
|
int64_t
|
||||||
Count() override;
|
Count() override;
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include "knowhere/index/Index.h"
|
#include "knowhere/index/Index.h"
|
||||||
#include "knowhere/index/IndexType.h"
|
#include "knowhere/index/IndexType.h"
|
||||||
#include "knowhere/index/vector_index/Statistics.h"
|
#include "knowhere/index/vector_index/Statistics.h"
|
||||||
|
#include "knowhere/index/vector_index/helpers/DynamicResultSet.h"
|
||||||
|
|
||||||
namespace milvus {
|
namespace milvus {
|
||||||
namespace knowhere {
|
namespace knowhere {
|
||||||
|
@ -94,6 +95,23 @@ class VecIndex : public Index {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void
|
||||||
|
MapUids(DynamicResultSegment& milvus_dataset) {
|
||||||
|
if (uids_) {
|
||||||
|
for (auto& mrspr : milvus_dataset) {
|
||||||
|
for (auto j = 0; j < mrspr->buffers.size(); ++j) {
|
||||||
|
auto buf = mrspr->buffers[j];
|
||||||
|
auto len = j + 1 == mrspr->buffers.size() ? mrspr->wp : mrspr->buffer_size;
|
||||||
|
for (auto i = 0; i < len; ++i) {
|
||||||
|
if (buf.ids[i] >= 0) {
|
||||||
|
buf.ids[i] = uids_->at(buf.ids[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
size_t
|
size_t
|
||||||
UidsSize() {
|
UidsSize() {
|
||||||
return uids_ ? uids_->size() * sizeof(IDType) : 0;
|
return uids_ ? uids_->size() * sizeof(IDType) : 0;
|
||||||
|
|
|
@ -0,0 +1,194 @@
|
||||||
|
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||||
|
// with the License. You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
|
#include <src/index/knowhere/knowhere/common/Exception.h>
|
||||||
|
#include <src/index/knowhere/knowhere/index/vector_index/helpers/DynamicResultSet.h>
|
||||||
|
#include <src/index/thirdparty/faiss/impl/AuxIndexStructures.h>
|
||||||
|
#include <cstring>
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
namespace milvus {
|
||||||
|
namespace knowhere {
|
||||||
|
|
||||||
|
/***********************************************************************
|
||||||
|
* DynamicResultSet
|
||||||
|
***********************************************************************/
|
||||||
|
|
||||||
|
void
|
||||||
|
DynamicResultSet::do_alloction() {
|
||||||
|
if (count <= 0) {
|
||||||
|
KNOWHERE_THROW_MSG("DynamicResultSet::do_alloction failed because of count <= 0");
|
||||||
|
}
|
||||||
|
labels = std::shared_ptr<idx_t[]>(new idx_t[count], std::default_delete<idx_t[]>());
|
||||||
|
distances = std::shared_ptr<float[]>(new float[count], std::default_delete<float[]>());
|
||||||
|
// labels = std::make_shared<idx_t []>(new idx_t[count], std::default_delete<idx_t[]>());
|
||||||
|
// distances = std::make_shared<float []>(new float[count], std::default_delete<float[]>());
|
||||||
|
}
|
||||||
|
|
||||||
|
void
|
||||||
|
DynamicResultSet::do_sort(ResultSetPostProcessType postProcessType) {
|
||||||
|
if (postProcessType == ResultSetPostProcessType::SortAsc) {
|
||||||
|
quick_sort<true>(0, count);
|
||||||
|
} else if (postProcessType == ResultSetPostProcessType::SortDesc) {
|
||||||
|
quick_sort<false>(0, count);
|
||||||
|
} else {
|
||||||
|
KNOWHERE_THROW_MSG("invalid sort type!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool asc>
|
||||||
|
void
|
||||||
|
DynamicResultSet::quick_sort(size_t lp, size_t rp) {
|
||||||
|
auto len = rp - lp;
|
||||||
|
if (len <= 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto pvot = lp + (len >> 1);
|
||||||
|
size_t low = lp;
|
||||||
|
size_t high = rp - 1;
|
||||||
|
auto pids = labels.get();
|
||||||
|
auto pdis = distances.get();
|
||||||
|
std::swap(pdis[pvot], pdis[high]);
|
||||||
|
std::swap(pids[pvot], pids[high]);
|
||||||
|
if (asc) {
|
||||||
|
while (low < high) {
|
||||||
|
while (low < high && pdis[low] <= pdis[high]) {
|
||||||
|
low++;
|
||||||
|
}
|
||||||
|
if (low == high) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::swap(pdis[low], pdis[high]);
|
||||||
|
std::swap(pids[low], pids[high]);
|
||||||
|
high--;
|
||||||
|
while (low < high && pdis[high] >= pdis[low]) {
|
||||||
|
high--;
|
||||||
|
}
|
||||||
|
if (low == high) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::swap(pdis[low], pdis[high]);
|
||||||
|
std::swap(pids[low], pids[high]);
|
||||||
|
low++;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
while (low < high) {
|
||||||
|
while (low < high && pdis[low] >= pdis[high]) {
|
||||||
|
low++;
|
||||||
|
}
|
||||||
|
if (low == high) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::swap(pdis[low], pdis[high]);
|
||||||
|
std::swap(pids[low], pids[high]);
|
||||||
|
high--;
|
||||||
|
while (low < high && pdis[high] <= pdis[low]) {
|
||||||
|
high--;
|
||||||
|
}
|
||||||
|
if (low == high) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::swap(pdis[low], pdis[high]);
|
||||||
|
std::swap(pids[low], pids[high]);
|
||||||
|
low++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
quick_sort<asc>(lp, low);
|
||||||
|
quick_sort<asc>(low, rp);
|
||||||
|
}
|
||||||
|
|
||||||
|
DynamicResultSet
|
||||||
|
DynamicResultCollector::Merge(size_t limit, ResultSetPostProcessType postProcessType) {
|
||||||
|
if (limit <= 0) {
|
||||||
|
KNOWHERE_THROW_MSG("limit must > 0!");
|
||||||
|
}
|
||||||
|
DynamicResultSet ret;
|
||||||
|
auto seg_num = seg_results.size();
|
||||||
|
std::vector<size_t> boundaries(seg_num + 1, 0);
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (auto i = 0; i < seg_num; ++i) {
|
||||||
|
for (auto& pseg : seg_results[i]) {
|
||||||
|
boundaries[i] += (pseg->buffer_size * pseg->buffers.size() - pseg->buffer_size + pseg->wp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (size_t i = 0, ofs = 0; i <= seg_num; ++i) {
|
||||||
|
auto bn = boundaries[i];
|
||||||
|
boundaries[i] = ofs;
|
||||||
|
ofs += bn;
|
||||||
|
// boundaries[i] += boundaries[i - 1];
|
||||||
|
}
|
||||||
|
ret.count = boundaries[seg_num] <= limit ? boundaries[seg_num] : limit;
|
||||||
|
ret.do_alloction();
|
||||||
|
|
||||||
|
// abandon redundancy answers randomly
|
||||||
|
// abandon strategy: keep the top limit sequentially
|
||||||
|
int32_t pos = 1;
|
||||||
|
for (int i = 1; i < boundaries.size(); ++i) {
|
||||||
|
if (boundaries[i] >= ret.count) {
|
||||||
|
pos = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pos--; // last segment id
|
||||||
|
// full copy
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (auto i = 0; i < pos - 1; ++i) {
|
||||||
|
for (auto& pseg : seg_results[i]) {
|
||||||
|
auto len = pseg->buffers.size() * pseg->buffer_size - pseg->buffer_size + pseg->wp;
|
||||||
|
pseg->copy_range(0, len, ret.labels.get() + boundaries[i], ret.distances.get() + boundaries[i]);
|
||||||
|
boundaries[i] += len;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// partial copy
|
||||||
|
auto last_len = ret.count - boundaries[pos];
|
||||||
|
for (auto& pseg : seg_results[pos]) {
|
||||||
|
auto len = pseg->buffers.size() * pseg->buffer_size - pseg->buffer_size + pseg->wp;
|
||||||
|
auto ncopy = last_len > len ? len : last_len;
|
||||||
|
pseg->copy_range(0, ncopy, ret.labels.get() + boundaries[pos], ret.distances.get() + boundaries[pos]);
|
||||||
|
boundaries[pos] += ncopy;
|
||||||
|
last_len -= ncopy;
|
||||||
|
if (last_len <= 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (postProcessType != ResultSetPostProcessType::None) {
|
||||||
|
ret.do_sort(postProcessType);
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
void
|
||||||
|
DynamicResultCollector::Append(milvus::knowhere::DynamicResultSegment&& seg_result) {
|
||||||
|
seg_results.push_back(std::move(seg_result));
|
||||||
|
}
|
||||||
|
|
||||||
|
void
|
||||||
|
ExchangeDataset(DynamicResultSegment& milvus_dataset, std::vector<faiss::RangeSearchPartialResult*>& faiss_dataset) {
|
||||||
|
for (auto& prspr : faiss_dataset) {
|
||||||
|
auto mrspr = new DynamicResultFragment(prspr->res->buffer_size);
|
||||||
|
mrspr->wp = prspr->wp;
|
||||||
|
mrspr->buffers.resize(prspr->buffers.size());
|
||||||
|
for (auto i = 0; i < prspr->buffers.size(); ++i) {
|
||||||
|
mrspr->buffers[i].ids = prspr->buffers[i].ids;
|
||||||
|
mrspr->buffers[i].dis = prspr->buffers[i].dis;
|
||||||
|
prspr->buffers[i].ids = nullptr;
|
||||||
|
prspr->buffers[i].dis = nullptr;
|
||||||
|
}
|
||||||
|
delete prspr->res;
|
||||||
|
milvus_dataset.push_back(mrspr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace knowhere
|
||||||
|
} // namespace milvus
|
|
@ -0,0 +1,60 @@
|
||||||
|
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||||
|
// with the License. You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "faiss/impl/AuxIndexStructures.h"
|
||||||
|
#include "knowhere/common/Typedef.h"
|
||||||
|
|
||||||
|
namespace milvus {
|
||||||
|
namespace knowhere {
|
||||||
|
|
||||||
|
enum class ResultSetPostProcessType { None = 0, SortDesc, SortAsc };
|
||||||
|
|
||||||
|
using idx_t = int64_t;
|
||||||
|
struct DynamicResultSet {
|
||||||
|
std::shared_ptr<idx_t[]> labels; /// result for query i is labels[lims[i]:lims[i + 1]]
|
||||||
|
std::shared_ptr<float[]> distances; /// corresponding distances, not sorted
|
||||||
|
size_t count; /// size of the result buffer's size, when reaches this size, auto start a new buffer
|
||||||
|
void
|
||||||
|
do_alloction();
|
||||||
|
void
|
||||||
|
do_sort(ResultSetPostProcessType postProcessType = ResultSetPostProcessType::SortAsc);
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <bool asc>
|
||||||
|
void
|
||||||
|
quick_sort(size_t lp, size_t rp);
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef faiss::BufferList DynamicResultFragment;
|
||||||
|
typedef DynamicResultFragment* DynamicResultFragmentPtr;
|
||||||
|
typedef std::vector<DynamicResultFragmentPtr> DynamicResultSegment;
|
||||||
|
|
||||||
|
struct DynamicResultCollector {
|
||||||
|
public:
|
||||||
|
DynamicResultSet
|
||||||
|
Merge(size_t limit = 10000, ResultSetPostProcessType postProcessType = ResultSetPostProcessType::None);
|
||||||
|
void
|
||||||
|
Append(DynamicResultSegment&& seg_result);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<DynamicResultSegment> seg_results; /// unmerged results of every segments
|
||||||
|
};
|
||||||
|
|
||||||
|
void
|
||||||
|
ExchangeDataset(DynamicResultSegment& milvus_dataset, std::vector<faiss::RangeSearchPartialResult*>& faiss_dataset);
|
||||||
|
|
||||||
|
} // namespace knowhere
|
||||||
|
} // namespace milvus
|
|
@ -28,6 +28,10 @@ constexpr const char* DEVICEID = "gpu_id";
|
||||||
}; // namespace meta
|
}; // namespace meta
|
||||||
|
|
||||||
namespace IndexParams {
|
namespace IndexParams {
|
||||||
|
// Range Search Params
|
||||||
|
constexpr const char* range_search_radius = "range_search_radius";
|
||||||
|
constexpr const char* range_search_buffer_size = "range_search_buffer_size";
|
||||||
|
|
||||||
// IVF Params
|
// IVF Params
|
||||||
constexpr const char* nprobe = "nprobe";
|
constexpr const char* nprobe = "nprobe";
|
||||||
constexpr const char* nlist = "nlist";
|
constexpr const char* nlist = "nlist";
|
||||||
|
|
|
@ -126,7 +126,17 @@ void IndexBinaryFlat::range_search(idx_t n, const uint8_t *x, int radius,
|
||||||
RangeSearchResult *result,
|
RangeSearchResult *result,
|
||||||
const BitsetView& bitset) const
|
const BitsetView& bitset) const
|
||||||
{
|
{
|
||||||
hamming_range_search (x, xb.data(), n, ntotal, radius, code_size, result);
|
FAISS_THROW_MSG("This interface is abandoned yet.");
|
||||||
|
}
|
||||||
|
|
||||||
|
void IndexBinaryFlat::range_search(faiss::IndexBinary::idx_t n,
|
||||||
|
const uint8_t* x,
|
||||||
|
int radius,
|
||||||
|
std::vector<faiss::RangeSearchPartialResult*>& result,
|
||||||
|
size_t buffer_size,
|
||||||
|
const faiss::BitsetView& bitset)
|
||||||
|
{
|
||||||
|
hamming_range_search (x, xb.data(), n, ntotal, radius, code_size, result, buffer_size, bitset);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace faiss
|
} // namespace faiss
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include <faiss/IndexBinary.h>
|
#include <faiss/IndexBinary.h>
|
||||||
|
#include <faiss/impl/AuxIndexStructures.h>
|
||||||
|
|
||||||
namespace faiss {
|
namespace faiss {
|
||||||
|
|
||||||
|
@ -45,6 +46,11 @@ struct IndexBinaryFlat : IndexBinary {
|
||||||
RangeSearchResult *result,
|
RangeSearchResult *result,
|
||||||
const BitsetView& bitset = nullptr) const override;
|
const BitsetView& bitset = nullptr) const override;
|
||||||
|
|
||||||
|
void range_search(idx_t n, const uint8_t *x, int radius,
|
||||||
|
std::vector<RangeSearchPartialResult*> &result,
|
||||||
|
size_t buffer_size,
|
||||||
|
const BitsetView& bitset = nullptr); // const override
|
||||||
|
|
||||||
void reconstruct(idx_t key, uint8_t *recons) const override;
|
void reconstruct(idx_t key, uint8_t *recons) const override;
|
||||||
|
|
||||||
/** Remove some ids. Note that because of the indexing structure,
|
/** Remove some ids. Note that because of the indexing structure,
|
||||||
|
|
|
@ -93,19 +93,28 @@ void IndexFlat::range_search (idx_t n, const float *x, float radius,
|
||||||
RangeSearchResult *result,
|
RangeSearchResult *result,
|
||||||
const BitsetView& bitset) const
|
const BitsetView& bitset) const
|
||||||
{
|
{
|
||||||
switch (metric_type) {
|
FAISS_THROW_MSG("This interface is abandoned yet.");
|
||||||
case METRIC_INNER_PRODUCT:
|
|
||||||
range_search_inner_product (x, xb.data(), d, n, ntotal,
|
|
||||||
radius, result);
|
|
||||||
break;
|
|
||||||
case METRIC_L2:
|
|
||||||
range_search_L2sqr (x, xb.data(), d, n, ntotal, radius, result);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
FAISS_THROW_MSG("metric type not supported");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void IndexFlat::range_search(faiss::Index::idx_t n,
|
||||||
|
const float* x,
|
||||||
|
float radius,
|
||||||
|
std::vector<faiss::RangeSearchPartialResult*>& result,
|
||||||
|
size_t buffer_size,
|
||||||
|
const faiss::BitsetView& bitset) {
|
||||||
|
|
||||||
|
switch (metric_type) {
|
||||||
|
case METRIC_INNER_PRODUCT:
|
||||||
|
range_search_inner_product (x, xb.data(), d, n, ntotal,
|
||||||
|
radius, result, buffer_size, bitset);
|
||||||
|
break;
|
||||||
|
case METRIC_L2:
|
||||||
|
range_search_L2sqr (x, xb.data(), d, n, ntotal, radius, result, buffer_size, bitset);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
FAISS_THROW_MSG("metric type not supported");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void IndexFlat::compute_distance_subset (
|
void IndexFlat::compute_distance_subset (
|
||||||
idx_t n,
|
idx_t n,
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include <faiss/Index.h>
|
#include <faiss/Index.h>
|
||||||
|
#include <faiss/impl/AuxIndexStructures.h>
|
||||||
|
|
||||||
|
|
||||||
namespace faiss {
|
namespace faiss {
|
||||||
|
@ -50,6 +51,14 @@ struct IndexFlat: Index {
|
||||||
RangeSearchResult* result,
|
RangeSearchResult* result,
|
||||||
const BitsetView& bitset = nullptr) const override;
|
const BitsetView& bitset = nullptr) const override;
|
||||||
|
|
||||||
|
void range_search(
|
||||||
|
idx_t n,
|
||||||
|
const float* x,
|
||||||
|
float radius,
|
||||||
|
std::vector<RangeSearchPartialResult*> &result,
|
||||||
|
size_t buffer_size,
|
||||||
|
const BitsetView& bitset = nullptr); // const override
|
||||||
|
|
||||||
void reconstruct(idx_t key, float* recons) const override;
|
void reconstruct(idx_t key, float* recons) const override;
|
||||||
|
|
||||||
/** compute distance with a subset of vectors
|
/** compute distance with a subset of vectors
|
||||||
|
|
|
@ -48,9 +48,12 @@ void RangeSearchResult::do_allocation () {
|
||||||
}
|
}
|
||||||
|
|
||||||
RangeSearchResult::~RangeSearchResult () {
|
RangeSearchResult::~RangeSearchResult () {
|
||||||
delete [] labels;
|
if (labels)
|
||||||
delete [] distances;
|
delete [] labels;
|
||||||
delete [] lims;
|
if (distances)
|
||||||
|
delete [] distances;
|
||||||
|
if (lims)
|
||||||
|
delete [] lims;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -71,8 +74,10 @@ BufferList::BufferList (size_t buffer_size):
|
||||||
BufferList::~BufferList ()
|
BufferList::~BufferList ()
|
||||||
{
|
{
|
||||||
for (int i = 0; i < buffers.size(); i++) {
|
for (int i = 0; i < buffers.size(); i++) {
|
||||||
delete [] buffers[i].ids;
|
if (buffers[i].ids)
|
||||||
delete [] buffers[i].dis;
|
delete [] buffers[i].ids;
|
||||||
|
if (buffers[i].dis)
|
||||||
|
delete [] buffers[i].dis;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -807,13 +807,15 @@ void knn_L2sqr_by_idx (const float * x,
|
||||||
/** Find the nearest neighbors for nx queries in a set of ny vectors
|
/** Find the nearest neighbors for nx queries in a set of ny vectors
|
||||||
* compute_l2 = compute pairwise squared L2 distance rather than inner prod
|
* compute_l2 = compute pairwise squared L2 distance rather than inner prod
|
||||||
*/
|
*/
|
||||||
template <bool compute_l2>
|
template <bool compute_l2>
|
||||||
static void range_search_blas (
|
static void range_search_blas (
|
||||||
const float * x,
|
const float * x,
|
||||||
const float * y,
|
const float * y,
|
||||||
size_t d, size_t nx, size_t ny,
|
size_t d, size_t nx, size_t ny,
|
||||||
float radius,
|
float radius,
|
||||||
RangeSearchResult *result)
|
std::vector<RangeSearchPartialResult*> &res,
|
||||||
|
size_t buffer_size,
|
||||||
|
const BitsetView &bitset)
|
||||||
{
|
{
|
||||||
|
|
||||||
// BLAS does not like empty matrices
|
// BLAS does not like empty matrices
|
||||||
|
@ -837,13 +839,13 @@ static void range_search_blas (
|
||||||
fvec_norms_L2sqr (y_norms, y, d, ny);
|
fvec_norms_L2sqr (y_norms, y, d, ny);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector <RangeSearchPartialResult *> partial_results;
|
|
||||||
|
|
||||||
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
||||||
size_t j1 = j0 + bs_y;
|
size_t j1 = j0 + bs_y;
|
||||||
if (j1 > ny) j1 = ny;
|
if (j1 > ny) j1 = ny;
|
||||||
RangeSearchPartialResult * pres = new RangeSearchPartialResult (result);
|
RangeSearchResult *tmp_res = new RangeSearchResult(nx);
|
||||||
partial_results.push_back (pres);
|
tmp_res->buffer_size = buffer_size;
|
||||||
|
RangeSearchPartialResult * pres = new RangeSearchPartialResult (tmp_res);
|
||||||
|
res.push_back (pres);
|
||||||
|
|
||||||
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
||||||
size_t i1 = i0 + bs_x;
|
size_t i1 = i0 + bs_x;
|
||||||
|
@ -867,14 +869,16 @@ static void range_search_blas (
|
||||||
|
|
||||||
for (size_t j = j0; j < j1; j++) {
|
for (size_t j = j0; j < j1; j++) {
|
||||||
float ip = *ip_line++;
|
float ip = *ip_line++;
|
||||||
if (compute_l2) {
|
if (bitset.empty() || !bitset.test((faiss::ConcurrentBitset::id_type_t)(j))) {
|
||||||
float dis = x_norms[i] + y_norms[j] - 2 * ip;
|
if (compute_l2) {
|
||||||
if (dis < radius) {
|
float dis = x_norms[i] + y_norms[j] - 2 * ip;
|
||||||
qres.add (dis, j);
|
if (dis < radius) {
|
||||||
}
|
qres.add (dis, j);
|
||||||
} else {
|
}
|
||||||
if (ip > radius) {
|
} else {
|
||||||
qres.add (ip, j);
|
if (ip > radius) {
|
||||||
|
qres.add (ip, j);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -883,7 +887,7 @@ static void range_search_blas (
|
||||||
InterruptCallback::check ();
|
InterruptCallback::check ();
|
||||||
}
|
}
|
||||||
|
|
||||||
RangeSearchPartialResult::merge (partial_results);
|
// RangeSearchPartialResult::merge (partial_results);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -892,12 +896,16 @@ static void range_search_sse (const float * x,
|
||||||
const float * y,
|
const float * y,
|
||||||
size_t d, size_t nx, size_t ny,
|
size_t d, size_t nx, size_t ny,
|
||||||
float radius,
|
float radius,
|
||||||
RangeSearchResult *res)
|
std::vector<RangeSearchPartialResult*> &res,
|
||||||
|
size_t buffer_size,
|
||||||
|
const BitsetView &bitset)
|
||||||
{
|
{
|
||||||
|
|
||||||
#pragma omp parallel
|
#pragma omp parallel
|
||||||
{
|
{
|
||||||
RangeSearchPartialResult pres (res);
|
RangeSearchResult *tmp_res = new RangeSearchResult(nx);
|
||||||
|
tmp_res->buffer_size = buffer_size;
|
||||||
|
auto pres = new RangeSearchPartialResult(tmp_res);
|
||||||
|
|
||||||
#pragma omp for
|
#pragma omp for
|
||||||
for (size_t i = 0; i < nx; i++) {
|
for (size_t i = 0; i < nx; i++) {
|
||||||
|
@ -905,9 +913,60 @@ static void range_search_sse (const float * x,
|
||||||
const float * y_ = y;
|
const float * y_ = y;
|
||||||
size_t j;
|
size_t j;
|
||||||
|
|
||||||
RangeQueryResult & qres = pres.new_result (i);
|
RangeQueryResult & qres = pres->new_result (i);
|
||||||
|
|
||||||
for (j = 0; j < ny; j++) {
|
for (j = 0; j < ny; j++) {
|
||||||
|
if (bitset.empty() || !bitset.test((faiss::ConcurrentBitset::id_type_t)(j))) {
|
||||||
|
if (compute_l2) {
|
||||||
|
float disij = fvec_L2sqr (x_, y_, d);
|
||||||
|
if (disij < radius) {
|
||||||
|
qres.add (disij, j);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
float ip = fvec_inner_product (x_, y_, d);
|
||||||
|
if (ip > radius) {
|
||||||
|
qres.add (ip, j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
y_ += d;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
#pragma omp critical
|
||||||
|
res.push_back(pres);
|
||||||
|
}
|
||||||
|
|
||||||
|
// check just at the end because the use case is typically just
|
||||||
|
// when the nb of queries is low.
|
||||||
|
InterruptCallback::check();
|
||||||
|
}
|
||||||
|
|
||||||
|
// range search by sse when nq = 1, namely single query situation
|
||||||
|
template <bool compute_l2>
|
||||||
|
static void range_search_sse_sq (const float * x,
|
||||||
|
const float * y,
|
||||||
|
size_t d, size_t nx, size_t ny,
|
||||||
|
float radius,
|
||||||
|
std::vector<RangeSearchPartialResult*> &res,
|
||||||
|
size_t buffer_size,
|
||||||
|
const BitsetView &bitset)
|
||||||
|
{
|
||||||
|
|
||||||
|
#pragma omp parallel
|
||||||
|
{
|
||||||
|
RangeSearchResult *tmp_res = new RangeSearchResult(nx);
|
||||||
|
tmp_res->buffer_size = buffer_size;
|
||||||
|
auto pres = new RangeSearchPartialResult(tmp_res);
|
||||||
|
|
||||||
|
const float * x_ = x;
|
||||||
|
size_t j;
|
||||||
|
RangeQueryResult & qres = pres->new_result (0);
|
||||||
|
|
||||||
|
#pragma omp for
|
||||||
|
for (j = 0; j < ny; j++) {
|
||||||
|
const float * y_ = y + j * d;
|
||||||
|
if (bitset.empty() || !bitset.test((faiss::ConcurrentBitset::id_type_t)(j))) {
|
||||||
if (compute_l2) {
|
if (compute_l2) {
|
||||||
float disij = fvec_L2sqr (x_, y_, d);
|
float disij = fvec_L2sqr (x_, y_, d);
|
||||||
if (disij < radius) {
|
if (disij < radius) {
|
||||||
|
@ -919,11 +978,10 @@ static void range_search_sse (const float * x,
|
||||||
qres.add (ip, j);
|
qres.add (ip, j);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
y_ += d;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
pres.finalize ();
|
#pragma omp critical
|
||||||
|
res.push_back(pres);
|
||||||
}
|
}
|
||||||
|
|
||||||
// check just at the end because the use case is typically just
|
// check just at the end because the use case is typically just
|
||||||
|
@ -932,21 +990,24 @@ static void range_search_sse (const float * x,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void range_search_L2sqr (
|
void range_search_L2sqr (
|
||||||
const float * x,
|
const float * x,
|
||||||
const float * y,
|
const float * y,
|
||||||
size_t d, size_t nx, size_t ny,
|
size_t d, size_t nx, size_t ny,
|
||||||
float radius,
|
float radius,
|
||||||
RangeSearchResult *res)
|
std::vector<RangeSearchPartialResult*> &res,
|
||||||
|
size_t buffer_size,
|
||||||
|
const BitsetView &bitset)
|
||||||
{
|
{
|
||||||
|
|
||||||
if (nx < distance_compute_blas_threshold) {
|
if (nx < distance_compute_blas_threshold) {
|
||||||
range_search_sse<true> (x, y, d, nx, ny, radius, res);
|
if (nx == 1) {
|
||||||
|
range_search_sse_sq<true> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
|
||||||
|
} else {
|
||||||
|
range_search_sse<true> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
range_search_blas<true> (x, y, d, nx, ny, radius, res);
|
range_search_blas<true> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -955,17 +1016,21 @@ void range_search_inner_product (
|
||||||
const float * y,
|
const float * y,
|
||||||
size_t d, size_t nx, size_t ny,
|
size_t d, size_t nx, size_t ny,
|
||||||
float radius,
|
float radius,
|
||||||
RangeSearchResult *res)
|
std::vector<RangeSearchPartialResult*> &res,
|
||||||
|
size_t buffer_size,
|
||||||
|
const BitsetView &bitset)
|
||||||
{
|
{
|
||||||
|
|
||||||
if (nx < distance_compute_blas_threshold) {
|
if (nx < distance_compute_blas_threshold) {
|
||||||
range_search_sse<false> (x, y, d, nx, ny, radius, res);
|
if (nx == 1)
|
||||||
|
range_search_sse_sq<false> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
|
||||||
|
else
|
||||||
|
range_search_sse<false> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
|
||||||
} else {
|
} else {
|
||||||
range_search_blas<false> (x, y, d, nx, ny, radius, res);
|
range_search_blas<false> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void pairwise_L2sqr (int64_t d,
|
void pairwise_L2sqr (int64_t d,
|
||||||
int64_t nq, const float *xq,
|
int64_t nq, const float *xq,
|
||||||
int64_t nb, const float *xb,
|
int64_t nb, const float *xb,
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#include <faiss/utils/Heap.h>
|
#include <faiss/utils/Heap.h>
|
||||||
#include <faiss/utils/ConcurrentBitset.h>
|
#include <faiss/utils/ConcurrentBitset.h>
|
||||||
#include <faiss/utils/BitsetView.h>
|
#include <faiss/utils/BitsetView.h>
|
||||||
|
#include <faiss/impl/AuxIndexStructures.h>
|
||||||
|
|
||||||
|
|
||||||
namespace faiss {
|
namespace faiss {
|
||||||
|
@ -241,15 +242,19 @@ void range_search_L2sqr (
|
||||||
const float * y,
|
const float * y,
|
||||||
size_t d, size_t nx, size_t ny,
|
size_t d, size_t nx, size_t ny,
|
||||||
float radius,
|
float radius,
|
||||||
RangeSearchResult *result);
|
std::vector<RangeSearchPartialResult*> &result,
|
||||||
|
size_t buffer_size,
|
||||||
|
const BitsetView &bitset = nullptr);
|
||||||
|
|
||||||
/// same as range_search_L2sqr for the inner product similarity
|
/// same as range_search_L2sqr for the inner product similarity
|
||||||
void range_search_inner_product (
|
void range_search_inner_product (
|
||||||
const float * x,
|
const float * x,
|
||||||
const float * y,
|
const float * y,
|
||||||
size_t d, size_t nx, size_t ny,
|
size_t d, size_t nx, size_t ny,
|
||||||
float radius,
|
float radius,
|
||||||
RangeSearchResult *result);
|
std::vector<RangeSearchPartialResult*> &result,
|
||||||
|
size_t buffer_size,
|
||||||
|
const BitsetView &bitset = nullptr);
|
||||||
|
|
||||||
|
|
||||||
/***************************************************************************
|
/***************************************************************************
|
||||||
|
|
|
@ -744,6 +744,7 @@ void hammings_knn_mc(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class HammingComputer>
|
template <class HammingComputer>
|
||||||
static
|
static
|
||||||
void hamming_range_search_template (
|
void hamming_range_search_template (
|
||||||
|
@ -753,28 +754,33 @@ void hamming_range_search_template (
|
||||||
size_t nb,
|
size_t nb,
|
||||||
int radius,
|
int radius,
|
||||||
size_t code_size,
|
size_t code_size,
|
||||||
RangeSearchResult *res)
|
std::vector<RangeSearchPartialResult*> &result,
|
||||||
|
size_t buffer_size,
|
||||||
|
const BitsetView &bitset)
|
||||||
{
|
{
|
||||||
|
|
||||||
#pragma omp parallel
|
#pragma omp parallel
|
||||||
{
|
{
|
||||||
RangeSearchPartialResult pres (res);
|
RangeSearchResult *tmp_res = new RangeSearchResult(na);
|
||||||
|
tmp_res->buffer_size = buffer_size;
|
||||||
|
auto pres = new RangeSearchPartialResult(tmp_res);
|
||||||
|
|
||||||
|
HammingComputer hc (a, code_size);
|
||||||
|
const uint8_t * yi = b;
|
||||||
|
RangeQueryResult & qres = pres->new_result (0);
|
||||||
|
|
||||||
#pragma omp for
|
#pragma omp for
|
||||||
for (size_t i = 0; i < na; i++) {
|
for (size_t j = 0; j < nb; j++) {
|
||||||
HammingComputer hc (a + i * code_size, code_size);
|
if (bitset.empty() || !bitset.test((ConcurrentBitset::id_type_t)j)) {
|
||||||
const uint8_t * yi = b;
|
int dis = hc.hamming (yi + j * code_size);
|
||||||
RangeQueryResult & qres = pres.new_result (i);
|
|
||||||
|
|
||||||
for (size_t j = 0; j < nb; j++) {
|
|
||||||
int dis = hc.hamming (yi);
|
|
||||||
if (dis < radius) {
|
if (dis < radius) {
|
||||||
qres.add(dis, j);
|
qres.add(dis, j);
|
||||||
}
|
}
|
||||||
yi += code_size;
|
|
||||||
}
|
}
|
||||||
|
// yi += code_size;
|
||||||
}
|
}
|
||||||
pres.finalize ();
|
#pragma omp critical
|
||||||
|
result.push_back(pres);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -785,10 +791,12 @@ void hamming_range_search (
|
||||||
size_t nb,
|
size_t nb,
|
||||||
int radius,
|
int radius,
|
||||||
size_t code_size,
|
size_t code_size,
|
||||||
RangeSearchResult *result)
|
std::vector<faiss::RangeSearchPartialResult*>& result,
|
||||||
|
size_t buffer_size,
|
||||||
|
const BitsetView& bitset)
|
||||||
{
|
{
|
||||||
|
|
||||||
#define HC(name) hamming_range_search_template<name> (a, b, na, nb, radius, code_size, result)
|
#define HC(name) hamming_range_search_template<name> (a, b, na, nb, radius, code_size, result, buffer_size, bitset)
|
||||||
|
|
||||||
switch(code_size) {
|
switch(code_size) {
|
||||||
case 4: HC(HammingComputer4); break;
|
case 4: HC(HammingComputer4); break;
|
||||||
|
@ -805,8 +813,6 @@ void hamming_range_search (
|
||||||
#undef HC
|
#undef HC
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/* Count number of matches given a max threshold */
|
/* Count number of matches given a max threshold */
|
||||||
void hamming_count_thres (
|
void hamming_count_thres (
|
||||||
const uint8_t * bs1,
|
const uint8_t * bs1,
|
||||||
|
|
|
@ -30,6 +30,7 @@
|
||||||
#include <faiss/utils/Heap.h>
|
#include <faiss/utils/Heap.h>
|
||||||
#include <faiss/utils/ConcurrentBitset.h>
|
#include <faiss/utils/ConcurrentBitset.h>
|
||||||
#include <faiss/utils/BitsetView.h>
|
#include <faiss/utils/BitsetView.h>
|
||||||
|
#include <faiss/impl/AuxIndexStructures.h>
|
||||||
|
|
||||||
/* The Hamming distance type */
|
/* The Hamming distance type */
|
||||||
typedef int32_t hamdis_t;
|
typedef int32_t hamdis_t;
|
||||||
|
@ -192,8 +193,9 @@ void hamming_range_search (
|
||||||
size_t nb,
|
size_t nb,
|
||||||
int radius,
|
int radius,
|
||||||
size_t ncodes,
|
size_t ncodes,
|
||||||
RangeSearchResult *result);
|
std::vector<faiss::RangeSearchPartialResult*>& result,
|
||||||
|
size_t buffer_size,
|
||||||
|
const BitsetView& bitset);
|
||||||
|
|
||||||
/* Counting the number of matches or of cross-matches (without returning them)
|
/* Counting the number of matches or of cross-matches (without returning them)
|
||||||
For use with function that assume pre-allocated memory */
|
For use with function that assume pre-allocated memory */
|
||||||
|
|
|
@ -33,6 +33,7 @@ set(util_srcs
|
||||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.cpp
|
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.cpp
|
||||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp
|
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp
|
||||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp
|
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp
|
||||||
|
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/DynamicResultSet.cpp
|
||||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/Statistics.cpp
|
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/Statistics.cpp
|
||||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/IndexType.cpp
|
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/IndexType.cpp
|
||||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Exception.cpp
|
${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Exception.cpp
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
#include <src/index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.h>
|
||||||
|
|
||||||
#include "knowhere/common/Exception.h"
|
#include "knowhere/common/Exception.h"
|
||||||
#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"
|
#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"
|
||||||
|
@ -39,6 +40,7 @@ class BinaryIDMAPTest : public DataGen, public TestWithParam<std::string> {
|
||||||
INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIDMAPTest,
|
INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIDMAPTest,
|
||||||
Values(std::string("JACCARD"), std::string("TANIMOTO"), std::string("HAMMING")));
|
Values(std::string("JACCARD"), std::string("TANIMOTO"), std::string("HAMMING")));
|
||||||
|
|
||||||
|
/*
|
||||||
TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
|
TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
|
||||||
ASSERT_TRUE(!xb_bin.empty());
|
ASSERT_TRUE(!xb_bin.empty());
|
||||||
|
|
||||||
|
@ -157,3 +159,89 @@ TEST_P(BinaryIDMAPTest, binaryidmap_slice) {
|
||||||
// PrintResult(result, nq, k);
|
// PrintResult(result, nq, k);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
TEST_P(BinaryIDMAPTest, binaryidmap_range_search) {
|
||||||
|
std::string MetricType = GetParam();
|
||||||
|
milvus::knowhere::Config conf{
|
||||||
|
{milvus::knowhere::meta::DIM, dim},
|
||||||
|
{milvus::knowhere::IndexParams::range_search_radius, radius},
|
||||||
|
{milvus::knowhere::IndexParams::range_search_buffer_size, buffer_size},
|
||||||
|
{milvus::knowhere::Metric::TYPE, MetricType},
|
||||||
|
};
|
||||||
|
|
||||||
|
int hamming_radius = 10;
|
||||||
|
auto hamming_dis = [] (const int64_t *pa, const int64_t *pb) -> int {
|
||||||
|
return __builtin_popcountl((*pa) ^ (*pb));
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<std::vector<bool>> idmap(nq, std::vector<bool>(nb, false));
|
||||||
|
std::vector<size_t> bf_cnt(nq, 0);
|
||||||
|
|
||||||
|
auto bruteforce = [&] () {
|
||||||
|
for (auto i = 0; i < nq; ++ i) {
|
||||||
|
const int64_t *pq = reinterpret_cast<int64_t*>(xq_bin.data()) + i;
|
||||||
|
const int64_t *pb = reinterpret_cast<int64_t*>(xb_bin.data());
|
||||||
|
for (auto j = 0; j < nb; ++ j) {
|
||||||
|
auto dist = hamming_dis(pq, pb + j);
|
||||||
|
if (dist < hamming_radius) {
|
||||||
|
idmap[i][j] = true;
|
||||||
|
bf_cnt[i] ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
bruteforce();
|
||||||
|
|
||||||
|
auto compare_res = [&] (std::vector<milvus::knowhere::DynamicResultSegment> &results) {
|
||||||
|
for (auto i = 0; i < nq; ++ i) {
|
||||||
|
int query_i_cnt = 0;
|
||||||
|
int correct_cnt = 0;
|
||||||
|
for (auto &res_space: results[i]) {
|
||||||
|
auto qnr = res_space->buffer_size * res_space->buffers.size() - res_space->buffer_size + res_space->wp;
|
||||||
|
for (auto j = 0; j < qnr; ++ j) {
|
||||||
|
auto bno = j / res_space->buffer_size;
|
||||||
|
auto pos = j % res_space->buffer_size;
|
||||||
|
ASSERT_EQ(idmap[i][res_space->buffers[bno].ids[pos]], true);
|
||||||
|
if (idmap[i][res_space->buffers[bno].ids[pos]]) {
|
||||||
|
correct_cnt ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ASSERT_EQ(correct_cnt, bf_cnt[i]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
// serialize index
|
||||||
|
index_->Train(base_dataset, conf);
|
||||||
|
index_->AddWithoutIds(base_dataset, milvus::knowhere::Config());
|
||||||
|
EXPECT_EQ(index_->Count(), nb);
|
||||||
|
EXPECT_EQ(index_->Dim(), dim);
|
||||||
|
|
||||||
|
std::vector<milvus::knowhere::DynamicResultSegment> results;
|
||||||
|
for (auto i = 0; i < nq; ++ i) {
|
||||||
|
auto qd = milvus::knowhere::GenDataset(1, dim, xq_bin.data() + i * dim / 8);
|
||||||
|
results.push_back(index_->QueryByDistance(qd, conf, nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
compare_res(results);
|
||||||
|
//
|
||||||
|
auto binaryset = index_->Serialize(conf);
|
||||||
|
index_->Load(binaryset);
|
||||||
|
|
||||||
|
EXPECT_EQ(index_->Count(), nb);
|
||||||
|
EXPECT_EQ(index_->Dim(), dim);
|
||||||
|
{
|
||||||
|
std::vector<milvus::knowhere::DynamicResultSegment> rresults;
|
||||||
|
for (auto i = 0; i < nq; ++ i) {
|
||||||
|
auto qd = milvus::knowhere::GenDataset(1, dim, xq_bin.data() + i * dim / 8);
|
||||||
|
rresults.push_back(index_->QueryByDistance(qd, conf, nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
compare_res(rresults);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
#include <fiu/fiu-local.h>
|
#include <fiu/fiu-local.h>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
#include <src/index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.h>
|
||||||
|
|
||||||
#include "knowhere/common/Exception.h"
|
#include "knowhere/common/Exception.h"
|
||||||
#include "knowhere/index/IndexType.h"
|
#include "knowhere/index/IndexType.h"
|
||||||
|
@ -209,6 +210,240 @@ TEST_P(IDMAPTest, idmap_slice) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(IDMAPTest, idmap_range_search_l2) {
|
||||||
|
milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim},
|
||||||
|
{milvus::knowhere::IndexParams::range_search_radius, radius},
|
||||||
|
{milvus::knowhere::IndexParams::range_search_buffer_size, buffer_size},
|
||||||
|
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}};
|
||||||
|
|
||||||
|
auto l2dis = [](const float *pa, const float *pb, size_t dim) -> float {
|
||||||
|
float ret = 0;
|
||||||
|
for (auto i = 0; i < dim; ++ i) {
|
||||||
|
auto dif = (pa[i] - pb[i]);
|
||||||
|
ret += dif * dif;
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<std::vector<bool>> idmap(nq, std::vector<bool>(nb, false));
|
||||||
|
std::vector<size_t> bf_cnt(nq, 0);
|
||||||
|
|
||||||
|
auto bruteforce = [&] () {
|
||||||
|
auto rds = radius * radius;
|
||||||
|
for (auto i = 0; i < nq; ++ i) {
|
||||||
|
const float *pq = xq.data() + i * dim;
|
||||||
|
for (auto j = 0; j < nb; ++ j) {
|
||||||
|
const float *pb = xb.data() + j * dim;
|
||||||
|
auto dist = l2dis(pq, pb, dim);
|
||||||
|
if (dist < rds) {
|
||||||
|
idmap[i][j] = true;
|
||||||
|
bf_cnt[i] ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
bruteforce();
|
||||||
|
|
||||||
|
auto compare_res = [&] (std::vector<milvus::knowhere::DynamicResultSegment> &results) {
|
||||||
|
{ // compare the result
|
||||||
|
// std::cout << "size of result: " << results.size() << std::endl;
|
||||||
|
for (auto i = 0; i < nq; ++ i) {
|
||||||
|
int correct_cnt = 0;
|
||||||
|
// std::cout << "query id = " << i << ", result[i].size = " << results[i].size() << std::endl;
|
||||||
|
for (auto &res_space: results[i]) {
|
||||||
|
// std::cout << "buffer size = " << res_space->buffer_size << ", wp = " << res_space->wp << ", size of buffers = " << res_space->buffers.size() << std::endl;
|
||||||
|
auto qnr = res_space->buffer_size * res_space->buffers.size() - res_space->buffer_size + res_space->wp;
|
||||||
|
// std::cout << "qnr = " << qnr << std::endl;
|
||||||
|
for (auto j = 0; j < qnr; ++ j) {
|
||||||
|
auto bno = j / res_space->buffer_size;
|
||||||
|
auto pos = j % res_space->buffer_size;
|
||||||
|
ASSERT_EQ(idmap[i][res_space->buffers[bno].ids[pos]], true);
|
||||||
|
if (idmap[i][res_space->buffers[bno].ids[pos]]) {
|
||||||
|
correct_cnt ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ASSERT_EQ(correct_cnt, bf_cnt[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
index_->Train(base_dataset, conf);
|
||||||
|
index_->AddWithoutIds(base_dataset, milvus::knowhere::Config());
|
||||||
|
|
||||||
|
std::vector<milvus::knowhere::DynamicResultSegment> results;
|
||||||
|
for (auto i = 0; i < nq; ++ i) {
|
||||||
|
auto qd = milvus::knowhere::GenDataset(1, dim, xq.data() + i * dim);
|
||||||
|
results.push_back(index_->QueryByDistance(qd, conf, nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
compare_res(results);
|
||||||
|
|
||||||
|
auto binaryset = index_->Serialize(conf);
|
||||||
|
index_->Load(binaryset);
|
||||||
|
|
||||||
|
EXPECT_EQ(index_->Count(), nb);
|
||||||
|
EXPECT_EQ(index_->Dim(), dim);
|
||||||
|
{ // query again and compare the result
|
||||||
|
std::vector<milvus::knowhere::DynamicResultSegment> rresults;
|
||||||
|
for (auto i = 0; i < nq; ++ i) {
|
||||||
|
auto qd = milvus::knowhere::GenDataset(1, dim, xq.data() + i * dim);
|
||||||
|
rresults.push_back(index_->QueryByDistance(qd, conf, nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
compare_res(rresults);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(IDMAPTest, idmap_range_search_ip) {
|
||||||
|
milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim},
|
||||||
|
{milvus::knowhere::IndexParams::range_search_radius, radius},
|
||||||
|
{milvus::knowhere::IndexParams::range_search_buffer_size, buffer_size},
|
||||||
|
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::IP}};
|
||||||
|
|
||||||
|
auto ipdis = [](const float *pa, const float *pb, size_t dim) -> float {
|
||||||
|
float ret = 0;
|
||||||
|
for (auto i = 0; i < dim; ++ i) {
|
||||||
|
ret += pa[i] * pb[i];
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<std::vector<bool>> idmap(nq, std::vector<bool>(nb, false));
|
||||||
|
std::vector<size_t> bf_cnt(nq, 0);
|
||||||
|
|
||||||
|
auto bruteforce = [&] () {
|
||||||
|
for (auto i = 0; i < nq; ++ i) {
|
||||||
|
const float *pq = xq.data() + i * dim;
|
||||||
|
for (auto j = 0; j < nb; ++ j) {
|
||||||
|
const float *pb = xb.data() + j * dim;
|
||||||
|
auto dist = ipdis(pq, pb, dim);
|
||||||
|
if (dist > radius) {
|
||||||
|
idmap[i][j] = true;
|
||||||
|
bf_cnt[i] ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
bruteforce();
|
||||||
|
|
||||||
|
auto compare_res = [&] (std::vector<milvus::knowhere::DynamicResultSegment> &results) {
|
||||||
|
{ // compare the result
|
||||||
|
for (auto i = 0; i < nq; ++ i) {
|
||||||
|
int correct_cnt = 0;
|
||||||
|
for (auto &res_space: results[i]) {
|
||||||
|
auto qnr = res_space->buffer_size * res_space->buffers.size() - res_space->buffer_size + res_space->wp;
|
||||||
|
for (auto j = 0; j < qnr; ++ j) {
|
||||||
|
auto bno = j / res_space->buffer_size;
|
||||||
|
auto pos = j % res_space->buffer_size;
|
||||||
|
ASSERT_EQ(idmap[i][res_space->buffers[bno].ids[pos]], true);
|
||||||
|
if (idmap[i][res_space->buffers[bno].ids[pos]]) {
|
||||||
|
correct_cnt ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ASSERT_EQ(correct_cnt, bf_cnt[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
index_->Train(base_dataset, conf);
|
||||||
|
index_->AddWithoutIds(base_dataset, milvus::knowhere::Config());
|
||||||
|
|
||||||
|
std::vector<milvus::knowhere::DynamicResultSegment> results;
|
||||||
|
for (auto i = 0; i < nq; ++ i) {
|
||||||
|
auto qd = milvus::knowhere::GenDataset(1, dim, xq.data() + i * dim);
|
||||||
|
results.push_back(index_->QueryByDistance(qd, conf, nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
compare_res(results);
|
||||||
|
|
||||||
|
auto binaryset = index_->Serialize(conf);
|
||||||
|
index_->Load(binaryset);
|
||||||
|
|
||||||
|
EXPECT_EQ(index_->Count(), nb);
|
||||||
|
EXPECT_EQ(index_->Dim(), dim);
|
||||||
|
{ // query again and compare the result
|
||||||
|
std::vector<milvus::knowhere::DynamicResultSegment> rresults;
|
||||||
|
for (auto i = 0; i < nq; ++ i) {
|
||||||
|
auto qd = milvus::knowhere::GenDataset(1, dim, xq.data() + i * dim);
|
||||||
|
rresults.push_back(index_->QueryByDistance(qd, conf, nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
compare_res(rresults);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(IDMAPTest, idmap_dynamic_result_set) {
|
||||||
|
milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim},
|
||||||
|
{milvus::knowhere::IndexParams::range_search_radius, radius},
|
||||||
|
{milvus::knowhere::IndexParams::range_search_buffer_size, buffer_size},
|
||||||
|
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}};
|
||||||
|
|
||||||
|
auto l2dis = [](const float *pa, const float *pb, size_t dim) -> float {
|
||||||
|
float ret = 0;
|
||||||
|
for (auto i = 0; i < dim; ++ i) {
|
||||||
|
auto dif = (pa[i] - pb[i]);
|
||||||
|
ret += dif * dif;
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<std::vector<bool>> idmap(nq, std::vector<bool>(nb, false));
|
||||||
|
std::vector<size_t> bf_cnt(nq, 0);
|
||||||
|
|
||||||
|
auto bruteforce = [&] () {
|
||||||
|
auto rds = radius * radius;
|
||||||
|
for (auto i = 0; i < nq; ++ i) {
|
||||||
|
const float *pq = xq.data() + i * dim;
|
||||||
|
for (auto j = 0; j < nb; ++ j) {
|
||||||
|
const float *pb = xb.data() + j * dim;
|
||||||
|
auto dist = l2dis(pq, pb, dim);
|
||||||
|
if (dist < rds) {
|
||||||
|
idmap[i][j] = true;
|
||||||
|
bf_cnt[i] ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
bruteforce();
|
||||||
|
|
||||||
|
auto check_rst = [&] (milvus::knowhere::DynamicResultSet &rst, milvus::knowhere::ResultSetPostProcessType rspt) {
|
||||||
|
{ // compare the result
|
||||||
|
for (auto i = 0; i < rst.count - 1; ++ i) {
|
||||||
|
if (rspt == milvus::knowhere::ResultSetPostProcessType::SortAsc)
|
||||||
|
ASSERT_LE(rst.distances.get() + i, rst.distances.get() + i + 1);
|
||||||
|
else if (rspt == milvus::knowhere::ResultSetPostProcessType::SortDesc)
|
||||||
|
ASSERT_GE(rst.distances.get() + i, rst.distances.get() + i + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
milvus::knowhere::DynamicResultCollector collector;
|
||||||
|
index_->Train(base_dataset, conf);
|
||||||
|
index_->AddWithoutIds(base_dataset, milvus::knowhere::Config());
|
||||||
|
|
||||||
|
for (auto i = 0; i < 3; ++ i) {
|
||||||
|
auto qd = milvus::knowhere::GenDataset(1, dim, xq.data());
|
||||||
|
collector.Append(index_->QueryByDistance(qd, conf, nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto rst = collector.Merge(1000, milvus::knowhere::ResultSetPostProcessType::SortAsc);
|
||||||
|
ASSERT_LE(rst.count, 1000);
|
||||||
|
|
||||||
|
check_rst(rst, milvus::knowhere::ResultSetPostProcessType::SortAsc);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef MILVUS_GPU_VERSION
|
#ifdef MILVUS_GPU_VERSION
|
||||||
TEST_P(IDMAPTest, idmap_copy) {
|
TEST_P(IDMAPTest, idmap_copy) {
|
||||||
ASSERT_TRUE(!xb.empty());
|
ASSERT_TRUE(!xb.empty());
|
||||||
|
|
|
@ -41,6 +41,8 @@ class DataGen {
|
||||||
int nq = 10;
|
int nq = 10;
|
||||||
int dim = 64;
|
int dim = 64;
|
||||||
int k = 10;
|
int k = 10;
|
||||||
|
int buffer_size = 16384;
|
||||||
|
float radius = 2.8;
|
||||||
std::vector<float> xb;
|
std::vector<float> xb;
|
||||||
std::vector<float> xq;
|
std::vector<float> xq;
|
||||||
std::vector<uint8_t> xb_bin;
|
std::vector<uint8_t> xb_bin;
|
||||||
|
|
Loading…
Reference in New Issue