mirror of https://github.com/milvus-io/milvus.git
Single query by distance (#4648)
* IndexFlat::QueryByDistance unittest passed Signed-off-by: cmli <chengming.li@zilliz.com> * change the datastructure of DynamicResultSet Signed-off-by: cmli <chengming.li@zilliz.com> * unittest of IndexFlat::QueryByDistance and IndexBinaryFalt::QueryByDistance has passed Signed-off-by: cmli <chengming.li@zilliz.com> * update the datastructure of DynamicResultSet... Signed-off-by: cmli <chengming.li@zilliz.com> * update the interface and the data structure of dynamic result set, compile passed Signed-off-by: cmli <chengming.li@zilliz.com> * fix unittest, add test for new data structure, to be debug... Signed-off-by: cmli <chengming.li@zilliz.com> * unittest passed Signed-off-by: cmli <chengming.li@zilliz.com> * add MapUids Signed-off-by: cmli <chengming.li@zilliz.com> * update code by the advise from review Signed-off-by: cmli <chengming.li@zilliz.com> Co-authored-by: cmli <chengming.li@zilliz.com>pull/4659/head
parent
e644680dd6
commit
d64b1dd165
|
@ -57,6 +57,7 @@ set(vector_index_srcs
|
|||
knowhere/index/vector_index/adapter/VectorAdapter.cpp
|
||||
knowhere/index/vector_index/helpers/FaissIO.cpp
|
||||
knowhere/index/vector_index/helpers/IndexParameter.cpp
|
||||
knowhere/index/vector_index/helpers/DynamicResultSet.cpp
|
||||
knowhere/index/vector_index/impl/nsg/Distance.cpp
|
||||
knowhere/index/vector_index/impl/nsg/NSG.cpp
|
||||
knowhere/index/vector_index/impl/nsg/NSGHelper.cpp
|
||||
|
|
|
@ -65,6 +65,38 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const fa
|
|||
return ret_ds;
|
||||
}
|
||||
|
||||
/// Range search for a single binary query vector.
///
/// Reads the search radius (`IndexParams::range_search_radius`, as int) and the optional
/// result buffer size (`IndexParams::range_search_buffer_size`, default 16384) from `config`,
/// runs the flat index's range search and maps the resulting ids through MapUids().
/// If `config` carries `Metric::TYPE`, that metric is used for this call only.
///
/// @param dataset  query dataset; must contain exactly one row
/// @param config   search parameters, see above
/// @param bitset   forwarded to the underlying range search
/// @return the per-thread result fragments produced by the search
/// @throws (KNOWHERE_THROW_MSG) when the index is not initialised, nq != 1, or the index is
///         not a faiss::IndexBinaryFlat
DynamicResultSegment
BinaryIDMAP::QueryByDistance(const milvus::knowhere::DatasetPtr& dataset, const milvus::knowhere::Config& config,
                             const faiss::BitsetView& bitset) {
    if (!index_) {
        KNOWHERE_THROW_MSG("index not initialize");
    }
    GET_TENSOR_DATA(dataset)
    if (rows != 1) {
        KNOWHERE_THROW_MSG("QueryByDistance only accept nq = 1!");
    }

    // Resolve the metric for this call first, but do not touch the index yet: everything
    // below up to the search itself may still throw.
    const auto default_type = index_->metric_type;
    auto search_type = default_type;
    if (config.contains(Metric::TYPE)) {
        search_type = GetMetricType(config[Metric::TYPE].get<std::string>());
    }
    std::vector<faiss::RangeSearchPartialResult*> res;
    DynamicResultSegment result;
    auto radius = config[IndexParams::range_search_radius].get<int>();
    auto buffer_size = config.contains(IndexParams::range_search_buffer_size)
                           ? config[IndexParams::range_search_buffer_size].get<size_t>()
                           : 16384;
    auto real_idx = dynamic_cast<faiss::IndexBinaryFlat*>(index_.get());
    if (real_idx == nullptr) {
        KNOWHERE_THROW_MSG("Cannot dynamic_cast the index to faiss::IndexBinaryFlat type!");
    }

    // Override the metric only around the search and put the original back on every exit
    // path, so a failing search cannot leave the index with the per-call metric.
    index_->metric_type = search_type;
    try {
        real_idx->range_search(rows, reinterpret_cast<const uint8_t*>(p_data), radius, res, buffer_size, bitset);
        ExchangeDataset(result, res);
        MapUids(result);
    } catch (...) {
        index_->metric_type = default_type;
        throw;
    }
    index_->metric_type = default_type;
    return result;
}
|
||||
|
||||
int64_t
|
||||
BinaryIDMAP::Count() {
|
||||
if (!index_) {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
#include "knowhere/index/vector_index/helpers/DynamicResultSet.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
@ -46,6 +47,9 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
|
|||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView& bitset) override;
|
||||
|
||||
DynamicResultSegment
|
||||
QueryByDistance(const DatasetPtr& dataset, const Config& config, const faiss::BitsetView& bitset);
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
||||
|
|
|
@ -98,6 +98,41 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::B
|
|||
return ret_ds;
|
||||
}
|
||||
|
||||
/// Range search for a single float query vector.
///
/// Reads the search radius (`IndexParams::range_search_radius`, as float) and the optional
/// result buffer size (`IndexParams::range_search_buffer_size`, default 16384) from `config`,
/// runs the flat index's range search and maps the resulting ids through MapUids().
/// If `config` carries `Metric::TYPE`, that metric is used for this call only.
/// For the L2 metric the radius is squared before it is handed to the search.
///
/// @param dataset  query dataset; must contain exactly one row
/// @param config   search parameters, see above
/// @param bitset   forwarded to the underlying range search
/// @return the result fragments produced by the search
/// @throws (KNOWHERE_THROW_MSG) when the index is not initialised, nq != 1, or the index is
///         not a faiss::IndexFlat
DynamicResultSegment
IDMAP::QueryByDistance(const milvus::knowhere::DatasetPtr& dataset, const milvus::knowhere::Config& config,
                       const faiss::BitsetView& bitset) {
    if (!index_) {
        KNOWHERE_THROW_MSG("index not initialize");
    }
    GET_TENSOR_DATA(dataset)
    if (rows != 1) {
        KNOWHERE_THROW_MSG("QueryByDistance only accept nq = 1!");
    }

    // Resolve the metric for this call first, but do not touch the index yet: everything
    // below up to the search itself may still throw.
    const auto default_type = index_->metric_type;
    auto search_type = default_type;
    if (config.contains(Metric::TYPE)) {
        search_type = GetMetricType(config[Metric::TYPE].get<std::string>());
    }
    std::vector<faiss::RangeSearchPartialResult*> res;
    DynamicResultSegment result;
    auto radius = config[IndexParams::range_search_radius].get<float>();
    auto buffer_size = config.contains(IndexParams::range_search_buffer_size)
                           ? config[IndexParams::range_search_buffer_size].get<size_t>()
                           : 16384;
    auto real_idx = dynamic_cast<faiss::IndexFlat*>(index_.get());
    if (real_idx == nullptr) {
        KNOWHERE_THROW_MSG("Cannot dynamic_cast the index to faiss::IndexFlat type!");
    }
    if (search_type == faiss::MetricType::METRIC_L2) {
        // the search compares against squared L2 distances
        radius *= radius;
    }

    // Override the metric only around the search and put the original back on every exit
    // path, so a failing search cannot leave the index with the per-call metric.
    index_->metric_type = search_type;
    try {
        real_idx->range_search(rows, reinterpret_cast<const float*>(p_data), radius, res, buffer_size, bitset);
        ExchangeDataset(result, res);
        MapUids(result);
    } catch (...) {
        index_->metric_type = default_type;
        throw;
    }
    index_->metric_type = default_type;
    return result;
}
|
||||
|
||||
int64_t
|
||||
IDMAP::Count() {
|
||||
if (!index_) {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "knowhere/index/vector_index/FaissBaseIndex.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
#include "knowhere/index/vector_index/helpers/DynamicResultSet.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
@ -45,6 +46,9 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
|
|||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&, const faiss::BitsetView&) override;
|
||||
|
||||
DynamicResultSegment
|
||||
QueryByDistance(const DatasetPtr& dataset, const Config& config, const faiss::BitsetView& bitset);
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "knowhere/index/Index.h"
|
||||
#include "knowhere/index/IndexType.h"
|
||||
#include "knowhere/index/vector_index/Statistics.h"
|
||||
#include "knowhere/index/vector_index/helpers/DynamicResultSet.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
@ -94,6 +95,23 @@ class VecIndex : public Index {
|
|||
}
|
||||
}
|
||||
|
||||
void
|
||||
MapUids(DynamicResultSegment& milvus_dataset) {
|
||||
if (uids_) {
|
||||
for (auto& mrspr : milvus_dataset) {
|
||||
for (auto j = 0; j < mrspr->buffers.size(); ++j) {
|
||||
auto buf = mrspr->buffers[j];
|
||||
auto len = j + 1 == mrspr->buffers.size() ? mrspr->wp : mrspr->buffer_size;
|
||||
for (auto i = 0; i < len; ++i) {
|
||||
if (buf.ids[i] >= 0) {
|
||||
buf.ids[i] = uids_->at(buf.ids[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t
|
||||
UidsSize() {
|
||||
return uids_ ? uids_->size() * sizeof(IDType) : 0;
|
||||
|
|
|
@ -0,0 +1,194 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include <src/index/knowhere/knowhere/common/Exception.h>
|
||||
#include <src/index/knowhere/knowhere/index/vector_index/helpers/DynamicResultSet.h>
|
||||
#include <src/index/thirdparty/faiss/impl/AuxIndexStructures.h>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
/***********************************************************************
|
||||
* DynamicResultSet
|
||||
***********************************************************************/
|
||||
|
||||
void
|
||||
DynamicResultSet::do_alloction() {
|
||||
if (count <= 0) {
|
||||
KNOWHERE_THROW_MSG("DynamicResultSet::do_alloction failed because of count <= 0");
|
||||
}
|
||||
labels = std::shared_ptr<idx_t[]>(new idx_t[count], std::default_delete<idx_t[]>());
|
||||
distances = std::shared_ptr<float[]>(new float[count], std::default_delete<float[]>());
|
||||
// labels = std::make_shared<idx_t []>(new idx_t[count], std::default_delete<idx_t[]>());
|
||||
// distances = std::make_shared<float []>(new float[count], std::default_delete<float[]>());
|
||||
}
|
||||
|
||||
void
|
||||
DynamicResultSet::do_sort(ResultSetPostProcessType postProcessType) {
|
||||
if (postProcessType == ResultSetPostProcessType::SortAsc) {
|
||||
quick_sort<true>(0, count);
|
||||
} else if (postProcessType == ResultSetPostProcessType::SortDesc) {
|
||||
quick_sort<false>(0, count);
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("invalid sort type!");
|
||||
}
|
||||
}
|
||||
|
||||
template <bool asc>
|
||||
void
|
||||
DynamicResultSet::quick_sort(size_t lp, size_t rp) {
|
||||
auto len = rp - lp;
|
||||
if (len <= 1) {
|
||||
return;
|
||||
}
|
||||
auto pvot = lp + (len >> 1);
|
||||
size_t low = lp;
|
||||
size_t high = rp - 1;
|
||||
auto pids = labels.get();
|
||||
auto pdis = distances.get();
|
||||
std::swap(pdis[pvot], pdis[high]);
|
||||
std::swap(pids[pvot], pids[high]);
|
||||
if (asc) {
|
||||
while (low < high) {
|
||||
while (low < high && pdis[low] <= pdis[high]) {
|
||||
low++;
|
||||
}
|
||||
if (low == high) {
|
||||
break;
|
||||
}
|
||||
std::swap(pdis[low], pdis[high]);
|
||||
std::swap(pids[low], pids[high]);
|
||||
high--;
|
||||
while (low < high && pdis[high] >= pdis[low]) {
|
||||
high--;
|
||||
}
|
||||
if (low == high) {
|
||||
break;
|
||||
}
|
||||
std::swap(pdis[low], pdis[high]);
|
||||
std::swap(pids[low], pids[high]);
|
||||
low++;
|
||||
}
|
||||
} else {
|
||||
while (low < high) {
|
||||
while (low < high && pdis[low] >= pdis[high]) {
|
||||
low++;
|
||||
}
|
||||
if (low == high) {
|
||||
break;
|
||||
}
|
||||
std::swap(pdis[low], pdis[high]);
|
||||
std::swap(pids[low], pids[high]);
|
||||
high--;
|
||||
while (low < high && pdis[high] <= pdis[low]) {
|
||||
high--;
|
||||
}
|
||||
if (low == high) {
|
||||
break;
|
||||
}
|
||||
std::swap(pdis[low], pdis[high]);
|
||||
std::swap(pids[low], pids[high]);
|
||||
low++;
|
||||
}
|
||||
}
|
||||
quick_sort<asc>(lp, low);
|
||||
quick_sort<asc>(low, rp);
|
||||
}
|
||||
|
||||
/// Concatenates the collected per-segment results into one flat result set.
///
/// Entries are taken segment by segment in the order the segments were appended; once `limit`
/// entries have been gathered the remaining ones are dropped. The merged set is sorted only
/// when `postProcessType` asks for it.
///
/// @param limit            maximum number of entries to keep, must be > 0
/// @param postProcessType  None keeps concatenation order, SortAsc / SortDesc sort by distance
/// @return merged result set holding min(total number of hits, limit) entries
/// @throws (KNOWHERE_THROW_MSG) when `limit` is 0; do_alloction() throws when no hit was collected
DynamicResultSet
DynamicResultCollector::Merge(size_t limit, ResultSetPostProcessType postProcessType) {
    if (limit <= 0) {
        KNOWHERE_THROW_MSG("limit must > 0!");
    }
    DynamicResultSet ret;
    const auto seg_num = static_cast<int64_t>(seg_results.size());

    // Number of entries in one fragment: every buffer is full except the last, filled up to wp.
    auto fragment_len = [](const DynamicResultFragmentPtr& frag) -> size_t {
        return frag->buffers.size() * frag->buffer_size - frag->buffer_size + frag->wp;
    };

    // boundaries[i] first holds the size of segment i ...
    std::vector<size_t> boundaries(seg_num + 1, 0);
#pragma omp parallel for
    for (int64_t i = 0; i < seg_num; ++i) {
        for (auto& pseg : seg_results[i]) {
            boundaries[i] += fragment_len(pseg);
        }
    }
    // ... and is then turned into the output offset of segment i; boundaries[seg_num] is the total.
    size_t ofs = 0;
    for (int64_t i = 0; i <= seg_num; ++i) {
        const auto seg_len = boundaries[i];
        boundaries[i] = ofs;
        ofs += seg_len;
    }
    ret.count = boundaries[seg_num] <= limit ? boundaries[seg_num] : limit;
    ret.do_alloction();

    // abandon strategy: keep the first `limit` entries sequentially.
    // last_seg is the segment in which the kept range ends.
    int64_t last_seg = 0;
    for (int64_t i = 1; i <= seg_num; ++i) {
        if (boundaries[i] >= ret.count) {
            last_seg = i - 1;
            break;
        }
    }

    // full copy: every segment before last_seg is kept entirely.
    // (The bound used to be `last_seg - 1`, which skipped segment last_seg - 1 and left its
    // slice of the output uninitialised.)
#pragma omp parallel for
    for (int64_t i = 0; i < last_seg; ++i) {
        for (auto& pseg : seg_results[i]) {
            const auto len = fragment_len(pseg);
            pseg->copy_range(0, len, ret.labels.get() + boundaries[i], ret.distances.get() + boundaries[i]);
            boundaries[i] += len;  // boundaries[i] doubles as the write cursor of segment i
        }
    }

    // partial copy: take only as many entries from last_seg as still fit
    auto remaining = ret.count - boundaries[last_seg];
    for (auto& pseg : seg_results[last_seg]) {
        const auto len = fragment_len(pseg);
        const auto ncopy = remaining > len ? len : remaining;
        pseg->copy_range(0, ncopy, ret.labels.get() + boundaries[last_seg],
                         ret.distances.get() + boundaries[last_seg]);
        boundaries[last_seg] += ncopy;
        remaining -= ncopy;
        if (remaining == 0) {
            break;
        }
    }

    if (postProcessType != ResultSetPostProcessType::None) {
        ret.do_sort(postProcessType);
    }
    return ret;
}
|
||||
|
||||
void
|
||||
DynamicResultCollector::Append(milvus::knowhere::DynamicResultSegment&& seg_result) {
|
||||
seg_results.push_back(std::move(seg_result));
|
||||
}
|
||||
|
||||
void
|
||||
ExchangeDataset(DynamicResultSegment& milvus_dataset, std::vector<faiss::RangeSearchPartialResult*>& faiss_dataset) {
|
||||
for (auto& prspr : faiss_dataset) {
|
||||
auto mrspr = new DynamicResultFragment(prspr->res->buffer_size);
|
||||
mrspr->wp = prspr->wp;
|
||||
mrspr->buffers.resize(prspr->buffers.size());
|
||||
for (auto i = 0; i < prspr->buffers.size(); ++i) {
|
||||
mrspr->buffers[i].ids = prspr->buffers[i].ids;
|
||||
mrspr->buffers[i].dis = prspr->buffers[i].dis;
|
||||
prspr->buffers[i].ids = nullptr;
|
||||
prspr->buffers[i].dis = nullptr;
|
||||
}
|
||||
delete prspr->res;
|
||||
milvus_dataset.push_back(mrspr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -0,0 +1,60 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "faiss/impl/AuxIndexStructures.h"
|
||||
#include "knowhere/common/Typedef.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
|
||||
/// Post-processing applied to a merged result set: leave as is, or sort by distance.
enum class ResultSetPostProcessType { None = 0, SortDesc, SortAsc };

using idx_t = int64_t;

/// Flat result of a range search: `count` (label, distance) pairs held in two parallel arrays.
struct DynamicResultSet {
    std::shared_ptr<idx_t[]> labels;     /// ids of the hits, labels[i] belongs to distances[i]
    std::shared_ptr<float[]> distances;  /// corresponding distances, not sorted unless do_sort() was called
    size_t count = 0;                    /// number of entries in labels / distances

    /// Allocates `labels` and `distances` for `count` entries; throws when `count` is 0.
    void
    do_alloction();

    /// Sorts the entries by distance in the requested direction; throws for any other type.
    void
    do_sort(ResultSetPostProcessType postProcessType = ResultSetPostProcessType::SortAsc);

 private:
    /// Quicksort over the half-open index range [lp, rp); `asc` selects the direction.
    template <bool asc>
    void
    quick_sort(size_t lp, size_t rp);
};
|
||||
|
||||
typedef faiss::BufferList DynamicResultFragment;
|
||||
typedef DynamicResultFragment* DynamicResultFragmentPtr;
|
||||
typedef std::vector<DynamicResultFragmentPtr> DynamicResultSegment;
|
||||
|
||||
struct DynamicResultCollector {
|
||||
public:
|
||||
DynamicResultSet
|
||||
Merge(size_t limit = 10000, ResultSetPostProcessType postProcessType = ResultSetPostProcessType::None);
|
||||
void
|
||||
Append(DynamicResultSegment&& seg_result);
|
||||
|
||||
private:
|
||||
std::vector<DynamicResultSegment> seg_results; /// unmerged results of every segments
|
||||
};
|
||||
|
||||
void
|
||||
ExchangeDataset(DynamicResultSegment& milvus_dataset, std::vector<faiss::RangeSearchPartialResult*>& faiss_dataset);
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
|
@ -28,6 +28,10 @@ constexpr const char* DEVICEID = "gpu_id";
|
|||
}; // namespace meta
|
||||
|
||||
namespace IndexParams {
|
||||
// Range Search Params
|
||||
constexpr const char* range_search_radius = "range_search_radius";
|
||||
constexpr const char* range_search_buffer_size = "range_search_buffer_size";
|
||||
|
||||
// IVF Params
|
||||
constexpr const char* nprobe = "nprobe";
|
||||
constexpr const char* nlist = "nlist";
|
||||
|
|
|
@ -126,7 +126,17 @@ void IndexBinaryFlat::range_search(idx_t n, const uint8_t *x, int radius,
|
|||
RangeSearchResult *result,
|
||||
const BitsetView& bitset) const
|
||||
{
|
||||
hamming_range_search (x, xb.data(), n, ntotal, radius, code_size, result);
|
||||
FAISS_THROW_MSG("This interface is abandoned yet.");
|
||||
}
|
||||
|
||||
void IndexBinaryFlat::range_search(faiss::IndexBinary::idx_t n,
|
||||
const uint8_t* x,
|
||||
int radius,
|
||||
std::vector<faiss::RangeSearchPartialResult*>& result,
|
||||
size_t buffer_size,
|
||||
const faiss::BitsetView& bitset)
|
||||
{
|
||||
hamming_range_search (x, xb.data(), n, ntotal, radius, code_size, result, buffer_size, bitset);
|
||||
}
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include <faiss/IndexBinary.h>
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
|
@ -45,6 +46,11 @@ struct IndexBinaryFlat : IndexBinary {
|
|||
RangeSearchResult *result,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
|
||||
void range_search(idx_t n, const uint8_t *x, int radius,
|
||||
std::vector<RangeSearchPartialResult*> &result,
|
||||
size_t buffer_size,
|
||||
const BitsetView& bitset = nullptr); // const override
|
||||
|
||||
void reconstruct(idx_t key, uint8_t *recons) const override;
|
||||
|
||||
/** Remove some ids. Note that because of the indexing structure,
|
||||
|
|
|
@ -93,20 +93,29 @@ void IndexFlat::range_search (idx_t n, const float *x, float radius,
|
|||
RangeSearchResult *result,
|
||||
const BitsetView& bitset) const
|
||||
{
|
||||
FAISS_THROW_MSG("This interface is abandoned yet.");
|
||||
}
|
||||
|
||||
void IndexFlat::range_search(faiss::Index::idx_t n,
|
||||
const float* x,
|
||||
float radius,
|
||||
std::vector<faiss::RangeSearchPartialResult*>& result,
|
||||
size_t buffer_size,
|
||||
const faiss::BitsetView& bitset) {
|
||||
|
||||
switch (metric_type) {
|
||||
case METRIC_INNER_PRODUCT:
|
||||
range_search_inner_product (x, xb.data(), d, n, ntotal,
|
||||
radius, result);
|
||||
radius, result, buffer_size, bitset);
|
||||
break;
|
||||
case METRIC_L2:
|
||||
range_search_L2sqr (x, xb.data(), d, n, ntotal, radius, result);
|
||||
range_search_L2sqr (x, xb.data(), d, n, ntotal, radius, result, buffer_size, bitset);
|
||||
break;
|
||||
default:
|
||||
FAISS_THROW_MSG("metric type not supported");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void IndexFlat::compute_distance_subset (
|
||||
idx_t n,
|
||||
const float *x,
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include <faiss/Index.h>
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
|
||||
|
||||
namespace faiss {
|
||||
|
@ -50,6 +51,14 @@ struct IndexFlat: Index {
|
|||
RangeSearchResult* result,
|
||||
const BitsetView& bitset = nullptr) const override;
|
||||
|
||||
void range_search(
|
||||
idx_t n,
|
||||
const float* x,
|
||||
float radius,
|
||||
std::vector<RangeSearchPartialResult*> &result,
|
||||
size_t buffer_size,
|
||||
const BitsetView& bitset = nullptr); // const override
|
||||
|
||||
void reconstruct(idx_t key, float* recons) const override;
|
||||
|
||||
/** compute distance with a subset of vectors
|
||||
|
|
|
@ -48,8 +48,11 @@ void RangeSearchResult::do_allocation () {
|
|||
}
|
||||
|
||||
RangeSearchResult::~RangeSearchResult () {
|
||||
if (labels)
|
||||
delete [] labels;
|
||||
if (distances)
|
||||
delete [] distances;
|
||||
if (lims)
|
||||
delete [] lims;
|
||||
}
|
||||
|
||||
|
@ -71,7 +74,9 @@ BufferList::BufferList (size_t buffer_size):
|
|||
BufferList::~BufferList ()
|
||||
{
|
||||
for (int i = 0; i < buffers.size(); i++) {
|
||||
if (buffers[i].ids)
|
||||
delete [] buffers[i].ids;
|
||||
if (buffers[i].dis)
|
||||
delete [] buffers[i].dis;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -807,13 +807,15 @@ void knn_L2sqr_by_idx (const float * x,
|
|||
/** Find the nearest neighbors for nx queries in a set of ny vectors
|
||||
* compute_l2 = compute pairwise squared L2 distance rather than inner prod
|
||||
*/
|
||||
template <bool compute_l2>
|
||||
template <bool compute_l2>
|
||||
static void range_search_blas (
|
||||
const float * x,
|
||||
const float * y,
|
||||
size_t d, size_t nx, size_t ny,
|
||||
float radius,
|
||||
RangeSearchResult *result)
|
||||
std::vector<RangeSearchPartialResult*> &res,
|
||||
size_t buffer_size,
|
||||
const BitsetView &bitset)
|
||||
{
|
||||
|
||||
// BLAS does not like empty matrices
|
||||
|
@ -837,13 +839,13 @@ static void range_search_blas (
|
|||
fvec_norms_L2sqr (y_norms, y, d, ny);
|
||||
}
|
||||
|
||||
std::vector <RangeSearchPartialResult *> partial_results;
|
||||
|
||||
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
||||
size_t j1 = j0 + bs_y;
|
||||
if (j1 > ny) j1 = ny;
|
||||
RangeSearchPartialResult * pres = new RangeSearchPartialResult (result);
|
||||
partial_results.push_back (pres);
|
||||
RangeSearchResult *tmp_res = new RangeSearchResult(nx);
|
||||
tmp_res->buffer_size = buffer_size;
|
||||
RangeSearchPartialResult * pres = new RangeSearchPartialResult (tmp_res);
|
||||
res.push_back (pres);
|
||||
|
||||
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
||||
size_t i1 = i0 + bs_x;
|
||||
|
@ -867,6 +869,7 @@ static void range_search_blas (
|
|||
|
||||
for (size_t j = j0; j < j1; j++) {
|
||||
float ip = *ip_line++;
|
||||
if (bitset.empty() || !bitset.test((faiss::ConcurrentBitset::id_type_t)(j))) {
|
||||
if (compute_l2) {
|
||||
float dis = x_norms[i] + y_norms[j] - 2 * ip;
|
||||
if (dis < radius) {
|
||||
|
@ -880,10 +883,11 @@ static void range_search_blas (
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
InterruptCallback::check ();
|
||||
}
|
||||
|
||||
RangeSearchPartialResult::merge (partial_results);
|
||||
// RangeSearchPartialResult::merge (partial_results);
|
||||
}
|
||||
|
||||
|
||||
|
@ -892,12 +896,16 @@ static void range_search_sse (const float * x,
|
|||
const float * y,
|
||||
size_t d, size_t nx, size_t ny,
|
||||
float radius,
|
||||
RangeSearchResult *res)
|
||||
std::vector<RangeSearchPartialResult*> &res,
|
||||
size_t buffer_size,
|
||||
const BitsetView &bitset)
|
||||
{
|
||||
|
||||
#pragma omp parallel
|
||||
{
|
||||
RangeSearchPartialResult pres (res);
|
||||
RangeSearchResult *tmp_res = new RangeSearchResult(nx);
|
||||
tmp_res->buffer_size = buffer_size;
|
||||
auto pres = new RangeSearchPartialResult(tmp_res);
|
||||
|
||||
#pragma omp for
|
||||
for (size_t i = 0; i < nx; i++) {
|
||||
|
@ -905,9 +913,10 @@ static void range_search_sse (const float * x,
|
|||
const float * y_ = y;
|
||||
size_t j;
|
||||
|
||||
RangeQueryResult & qres = pres.new_result (i);
|
||||
RangeQueryResult & qres = pres->new_result (i);
|
||||
|
||||
for (j = 0; j < ny; j++) {
|
||||
if (bitset.empty() || !bitset.test((faiss::ConcurrentBitset::id_type_t)(j))) {
|
||||
if (compute_l2) {
|
||||
float disij = fvec_L2sqr (x_, y_, d);
|
||||
if (disij < radius) {
|
||||
|
@ -919,11 +928,60 @@ static void range_search_sse (const float * x,
|
|||
qres.add (ip, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
y_ += d;
|
||||
}
|
||||
|
||||
}
|
||||
pres.finalize ();
|
||||
#pragma omp critical
|
||||
res.push_back(pres);
|
||||
}
|
||||
|
||||
// check just at the end because the use case is typically just
|
||||
// when the nb of queries is low.
|
||||
InterruptCallback::check();
|
||||
}
|
||||
|
||||
// range search by sse when nq = 1, namely single query situation
|
||||
template <bool compute_l2>
|
||||
static void range_search_sse_sq (const float * x,
|
||||
const float * y,
|
||||
size_t d, size_t nx, size_t ny,
|
||||
float radius,
|
||||
std::vector<RangeSearchPartialResult*> &res,
|
||||
size_t buffer_size,
|
||||
const BitsetView &bitset)
|
||||
{
|
||||
|
||||
#pragma omp parallel
|
||||
{
|
||||
RangeSearchResult *tmp_res = new RangeSearchResult(nx);
|
||||
tmp_res->buffer_size = buffer_size;
|
||||
auto pres = new RangeSearchPartialResult(tmp_res);
|
||||
|
||||
const float * x_ = x;
|
||||
size_t j;
|
||||
RangeQueryResult & qres = pres->new_result (0);
|
||||
|
||||
#pragma omp for
|
||||
for (j = 0; j < ny; j++) {
|
||||
const float * y_ = y + j * d;
|
||||
if (bitset.empty() || !bitset.test((faiss::ConcurrentBitset::id_type_t)(j))) {
|
||||
if (compute_l2) {
|
||||
float disij = fvec_L2sqr (x_, y_, d);
|
||||
if (disij < radius) {
|
||||
qres.add (disij, j);
|
||||
}
|
||||
} else {
|
||||
float ip = fvec_inner_product (x_, y_, d);
|
||||
if (ip > radius) {
|
||||
qres.add (ip, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#pragma omp critical
|
||||
res.push_back(pres);
|
||||
}
|
||||
|
||||
// check just at the end because the use case is typically just
|
||||
|
@ -932,21 +990,24 @@ static void range_search_sse (const float * x,
|
|||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
void range_search_L2sqr (
|
||||
const float * x,
|
||||
const float * y,
|
||||
size_t d, size_t nx, size_t ny,
|
||||
float radius,
|
||||
RangeSearchResult *res)
|
||||
std::vector<RangeSearchPartialResult*> &res,
|
||||
size_t buffer_size,
|
||||
const BitsetView &bitset)
|
||||
{
|
||||
|
||||
if (nx < distance_compute_blas_threshold) {
|
||||
range_search_sse<true> (x, y, d, nx, ny, radius, res);
|
||||
if (nx == 1) {
|
||||
range_search_sse_sq<true> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
|
||||
} else {
|
||||
range_search_blas<true> (x, y, d, nx, ny, radius, res);
|
||||
range_search_sse<true> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
|
||||
}
|
||||
} else {
|
||||
range_search_blas<true> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -955,17 +1016,21 @@ void range_search_inner_product (
|
|||
const float * y,
|
||||
size_t d, size_t nx, size_t ny,
|
||||
float radius,
|
||||
RangeSearchResult *res)
|
||||
std::vector<RangeSearchPartialResult*> &res,
|
||||
size_t buffer_size,
|
||||
const BitsetView &bitset)
|
||||
{
|
||||
|
||||
if (nx < distance_compute_blas_threshold) {
|
||||
range_search_sse<false> (x, y, d, nx, ny, radius, res);
|
||||
if (nx == 1)
|
||||
range_search_sse_sq<false> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
|
||||
else
|
||||
range_search_sse<false> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
|
||||
} else {
|
||||
range_search_blas<false> (x, y, d, nx, ny, radius, res);
|
||||
range_search_blas<false> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void pairwise_L2sqr (int64_t d,
|
||||
int64_t nq, const float *xq,
|
||||
int64_t nb, const float *xb,
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include <faiss/utils/Heap.h>
|
||||
#include <faiss/utils/ConcurrentBitset.h>
|
||||
#include <faiss/utils/BitsetView.h>
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
|
||||
|
||||
namespace faiss {
|
||||
|
@ -241,7 +242,9 @@ void range_search_L2sqr (
|
|||
const float * y,
|
||||
size_t d, size_t nx, size_t ny,
|
||||
float radius,
|
||||
RangeSearchResult *result);
|
||||
std::vector<RangeSearchPartialResult*> &result,
|
||||
size_t buffer_size,
|
||||
const BitsetView &bitset = nullptr);
|
||||
|
||||
/// same as range_search_L2sqr for the inner product similarity
|
||||
void range_search_inner_product (
|
||||
|
@ -249,7 +252,9 @@ void range_search_inner_product (
|
|||
const float * y,
|
||||
size_t d, size_t nx, size_t ny,
|
||||
float radius,
|
||||
RangeSearchResult *result);
|
||||
std::vector<RangeSearchPartialResult*> &result,
|
||||
size_t buffer_size,
|
||||
const BitsetView &bitset = nullptr);
|
||||
|
||||
|
||||
/***************************************************************************
|
||||
|
|
|
@ -744,6 +744,7 @@ void hammings_knn_mc(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class HammingComputer>
|
||||
static
|
||||
void hamming_range_search_template (
|
||||
|
@ -753,28 +754,33 @@ void hamming_range_search_template (
|
|||
size_t nb,
|
||||
int radius,
|
||||
size_t code_size,
|
||||
RangeSearchResult *res)
|
||||
std::vector<RangeSearchPartialResult*> &result,
|
||||
size_t buffer_size,
|
||||
const BitsetView &bitset)
|
||||
{
|
||||
|
||||
#pragma omp parallel
|
||||
{
|
||||
RangeSearchPartialResult pres (res);
|
||||
RangeSearchResult *tmp_res = new RangeSearchResult(na);
|
||||
tmp_res->buffer_size = buffer_size;
|
||||
auto pres = new RangeSearchPartialResult(tmp_res);
|
||||
|
||||
HammingComputer hc (a, code_size);
|
||||
const uint8_t * yi = b;
|
||||
RangeQueryResult & qres = pres->new_result (0);
|
||||
|
||||
#pragma omp for
|
||||
for (size_t i = 0; i < na; i++) {
|
||||
HammingComputer hc (a + i * code_size, code_size);
|
||||
const uint8_t * yi = b;
|
||||
RangeQueryResult & qres = pres.new_result (i);
|
||||
|
||||
for (size_t j = 0; j < nb; j++) {
|
||||
int dis = hc.hamming (yi);
|
||||
if (bitset.empty() || !bitset.test((ConcurrentBitset::id_type_t)j)) {
|
||||
int dis = hc.hamming (yi + j * code_size);
|
||||
if (dis < radius) {
|
||||
qres.add(dis, j);
|
||||
}
|
||||
yi += code_size;
|
||||
}
|
||||
// yi += code_size;
|
||||
}
|
||||
pres.finalize ();
|
||||
#pragma omp critical
|
||||
result.push_back(pres);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -785,10 +791,12 @@ void hamming_range_search (
|
|||
size_t nb,
|
||||
int radius,
|
||||
size_t code_size,
|
||||
RangeSearchResult *result)
|
||||
std::vector<faiss::RangeSearchPartialResult*>& result,
|
||||
size_t buffer_size,
|
||||
const BitsetView& bitset)
|
||||
{
|
||||
|
||||
#define HC(name) hamming_range_search_template<name> (a, b, na, nb, radius, code_size, result)
|
||||
#define HC(name) hamming_range_search_template<name> (a, b, na, nb, radius, code_size, result, buffer_size, bitset)
|
||||
|
||||
switch(code_size) {
|
||||
case 4: HC(HammingComputer4); break;
|
||||
|
@ -805,8 +813,6 @@ void hamming_range_search (
|
|||
#undef HC
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* Count number of matches given a max threshold */
|
||||
void hamming_count_thres (
|
||||
const uint8_t * bs1,
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include <faiss/utils/Heap.h>
|
||||
#include <faiss/utils/ConcurrentBitset.h>
|
||||
#include <faiss/utils/BitsetView.h>
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
|
||||
/* The Hamming distance type */
|
||||
typedef int32_t hamdis_t;
|
||||
|
@ -192,8 +193,9 @@ void hamming_range_search (
|
|||
size_t nb,
|
||||
int radius,
|
||||
size_t ncodes,
|
||||
RangeSearchResult *result);
|
||||
|
||||
std::vector<faiss::RangeSearchPartialResult*>& result,
|
||||
size_t buffer_size,
|
||||
const BitsetView& bitset);
|
||||
|
||||
/* Counting the number of matches or of cross-matches (without returning them)
|
||||
For use with function that assume pre-allocated memory */
|
||||
|
|
|
@ -33,6 +33,7 @@ set(util_srcs
|
|||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/DynamicResultSet.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/Statistics.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/IndexType.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Exception.cpp
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <src/index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.h>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"
|
||||
|
@ -39,6 +40,7 @@ class BinaryIDMAPTest : public DataGen, public TestWithParam<std::string> {
|
|||
INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIDMAPTest,
|
||||
Values(std::string("JACCARD"), std::string("TANIMOTO"), std::string("HAMMING")));
|
||||
|
||||
/*
|
||||
TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
|
||||
ASSERT_TRUE(!xb_bin.empty());
|
||||
|
||||
|
@ -157,3 +159,89 @@ TEST_P(BinaryIDMAPTest, binaryidmap_slice) {
|
|||
// PrintResult(result, nq, k);
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// Range search on a binary IDMAP index, one query vector at a time.
//
// A brute-force Hamming scan over the base vectors provides the reference: every id returned by
// QueryByDistance must be one the scan accepted, and the number of returned ids must equal the
// number of accepted ids. The check runs once on the freshly built index and once more after a
// Serialize/Load round trip.
//
// NOTE(review): the reference is a Hamming scan thresholded by the local `hamming_radius` for
// every metric parameter, while `conf` carries the fixture's `radius` -- confirm this matches
// what QueryByDistance applies for each metric.
TEST_P(BinaryIDMAPTest, binaryidmap_range_search) {
    std::string MetricType = GetParam();
    milvus::knowhere::Config conf{
        {milvus::knowhere::meta::DIM, dim},
        {milvus::knowhere::IndexParams::range_search_radius, radius},
        {milvus::knowhere::IndexParams::range_search_buffer_size, buffer_size},
        {milvus::knowhere::Metric::TYPE, MetricType},
    };

    int hamming_radius = 10;

    // Bytes per binary vector; the same stride the queries below use (i * dim / 8).
    const size_t code_size = dim / 8;

    // Byte-wise popcount: independent of the width of `long` and of the vector dimension,
    // and avoids reading the uint8_t buffers through a wider pointer type.
    auto hamming_dis = [code_size](const uint8_t* pa, const uint8_t* pb) -> int {
        int bits = 0;
        for (size_t k = 0; k < code_size; ++k) {
            bits += __builtin_popcount(static_cast<unsigned>(pa[k] ^ pb[k]));
        }
        return bits;
    };

    // idmap[q][b] is true when base vector b lies within hamming_radius of query q;
    // bf_cnt[q] is the number of such base vectors.
    std::vector<std::vector<bool>> idmap(nq, std::vector<bool>(nb, false));
    std::vector<size_t> bf_cnt(nq, 0);

    auto bruteforce = [&]() {
        for (auto i = 0; i < nq; ++i) {
            const uint8_t* pq = xq_bin.data() + i * code_size;
            for (auto j = 0; j < nb; ++j) {
                const uint8_t* pb = xb_bin.data() + j * code_size;
                if (hamming_dis(pq, pb) < hamming_radius) {
                    idmap[i][j] = true;
                    bf_cnt[i]++;
                }
            }
        }
    };

    bruteforce();

    // Walks every buffer of every result segment of each query and checks the ids against the
    // brute-force reference. Uses fatal assertions, so call it through ASSERT_NO_FATAL_FAILURE.
    auto compare_res = [&](std::vector<milvus::knowhere::DynamicResultSegment>& results) {
        for (auto i = 0; i < nq; ++i) {
            size_t correct_cnt = 0;
            for (auto& res_space : results[i]) {
                // All buffers but the last are full; the last one holds `wp` entries.
                auto qnr =
                    res_space->buffer_size * res_space->buffers.size() - res_space->buffer_size + res_space->wp;
                for (decltype(qnr) j = 0; j < qnr; ++j) {
                    auto bno = j / res_space->buffer_size;
                    auto pos = j % res_space->buffer_size;
                    const bool expected = idmap[i][res_space->buffers[bno].ids[pos]];
                    ASSERT_TRUE(expected);
                    ++correct_cnt;
                }
            }
            ASSERT_EQ(correct_cnt, bf_cnt[i]);
        }
    };

    {
        // build the index
        index_->Train(base_dataset, conf);
        index_->AddWithoutIds(base_dataset, milvus::knowhere::Config());
        EXPECT_EQ(index_->Count(), nb);
        EXPECT_EQ(index_->Dim(), dim);

        std::vector<milvus::knowhere::DynamicResultSegment> results;
        for (auto i = 0; i < nq; ++i) {
            auto qd = milvus::knowhere::GenDataset(1, dim, xq_bin.data() + i * dim / 8);
            results.push_back(index_->QueryByDistance(qd, conf, nullptr));
        }

        // A fatal failure inside the lambda only leaves the lambda; propagate it to the test.
        ASSERT_NO_FATAL_FAILURE(compare_res(results));

        // serialize the index and load it back
        auto binaryset = index_->Serialize(conf);
        index_->Load(binaryset);

        EXPECT_EQ(index_->Count(), nb);
        EXPECT_EQ(index_->Dim(), dim);
        {
            // query again and compare the result
            std::vector<milvus::knowhere::DynamicResultSegment> rresults;
            for (auto i = 0; i < nq; ++i) {
                auto qd = milvus::knowhere::GenDataset(1, dim, xq_bin.data() + i * dim / 8);
                rresults.push_back(index_->QueryByDistance(qd, conf, nullptr));
            }

            ASSERT_NO_FATAL_FAILURE(compare_res(rresults));
        }
    }
}
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include <fiu/fiu-local.h>
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
#include <src/index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.h>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/IndexType.h"
|
||||
|
@ -209,6 +210,240 @@ TEST_P(IDMAPTest, idmap_slice) {
|
|||
}
|
||||
}
|
||||
|
||||
// Range search with the L2 metric on an IDMAP index, one query vector at a time.
//
// A brute-force scan marks every base vector whose squared L2 distance to the query is below
// radius * radius. Every id returned by QueryByDistance must be marked, and the number of
// returned ids must equal the number of marked ids. The check runs on the freshly built index
// and again after a Serialize/Load round trip.
TEST_P(IDMAPTest, idmap_range_search_l2) {
    milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim},
                                  {milvus::knowhere::IndexParams::range_search_radius, radius},
                                  {milvus::knowhere::IndexParams::range_search_buffer_size, buffer_size},
                                  {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}};

    // Squared Euclidean distance between two dim-length vectors.
    auto l2dis = [](const float* pa, const float* pb, size_t dim) -> float {
        float ret = 0;
        for (size_t i = 0; i < dim; ++i) {
            auto dif = (pa[i] - pb[i]);
            ret += dif * dif;
        }
        return ret;
    };

    // idmap[q][b] is true when base vector b is accepted for query q;
    // bf_cnt[q] is the number of accepted base vectors.
    std::vector<std::vector<bool>> idmap(nq, std::vector<bool>(nb, false));
    std::vector<size_t> bf_cnt(nq, 0);

    auto bruteforce = [&]() {
        // l2dis returns the squared distance, so compare against the squared radius.
        auto rds = radius * radius;
        for (auto i = 0; i < nq; ++i) {
            const float* pq = xq.data() + i * dim;
            for (auto j = 0; j < nb; ++j) {
                const float* pb = xb.data() + j * dim;
                if (l2dis(pq, pb, dim) < rds) {
                    idmap[i][j] = true;
                    bf_cnt[i]++;
                }
            }
        }
    };

    bruteforce();

    // Walks every buffer of every result segment of each query and checks the ids against the
    // brute-force reference. Uses fatal assertions, so call it through ASSERT_NO_FATAL_FAILURE.
    auto compare_res = [&](std::vector<milvus::knowhere::DynamicResultSegment>& results) {
        for (auto i = 0; i < nq; ++i) {
            size_t correct_cnt = 0;
            for (auto& res_space : results[i]) {
                // All buffers but the last are full; the last one holds `wp` entries.
                auto qnr =
                    res_space->buffer_size * res_space->buffers.size() - res_space->buffer_size + res_space->wp;
                for (decltype(qnr) j = 0; j < qnr; ++j) {
                    auto bno = j / res_space->buffer_size;
                    auto pos = j % res_space->buffer_size;
                    const bool expected = idmap[i][res_space->buffers[bno].ids[pos]];
                    ASSERT_TRUE(expected);
                    ++correct_cnt;
                }
            }
            ASSERT_EQ(correct_cnt, bf_cnt[i]);
        }
    };

    {
        // build the index
        index_->Train(base_dataset, conf);
        index_->AddWithoutIds(base_dataset, milvus::knowhere::Config());

        std::vector<milvus::knowhere::DynamicResultSegment> results;
        for (auto i = 0; i < nq; ++i) {
            auto qd = milvus::knowhere::GenDataset(1, dim, xq.data() + i * dim);
            results.push_back(index_->QueryByDistance(qd, conf, nullptr));
        }

        // A fatal failure inside the lambda only leaves the lambda; propagate it to the test.
        ASSERT_NO_FATAL_FAILURE(compare_res(results));

        // serialize the index and load it back
        auto binaryset = index_->Serialize(conf);
        index_->Load(binaryset);

        EXPECT_EQ(index_->Count(), nb);
        EXPECT_EQ(index_->Dim(), dim);
        {
            // query again and compare the result
            std::vector<milvus::knowhere::DynamicResultSegment> rresults;
            for (auto i = 0; i < nq; ++i) {
                auto qd = milvus::knowhere::GenDataset(1, dim, xq.data() + i * dim);
                rresults.push_back(index_->QueryByDistance(qd, conf, nullptr));
            }

            ASSERT_NO_FATAL_FAILURE(compare_res(rresults));
        }
    }
}
|
||||
|
||||
// Range search with the inner-product metric on an IDMAP index, one query vector at a time.
//
// A brute-force scan marks every base vector whose inner product with the query is greater than
// `radius`. Every id returned by QueryByDistance must be marked, and the number of returned ids
// must equal the number of marked ids. The check runs on the freshly built index and again after
// a Serialize/Load round trip.
TEST_P(IDMAPTest, idmap_range_search_ip) {
    milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim},
                                  {milvus::knowhere::IndexParams::range_search_radius, radius},
                                  {milvus::knowhere::IndexParams::range_search_buffer_size, buffer_size},
                                  {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::IP}};

    // Inner product of two dim-length vectors.
    auto ipdis = [](const float* pa, const float* pb, size_t dim) -> float {
        float ret = 0;
        for (size_t i = 0; i < dim; ++i) {
            ret += pa[i] * pb[i];
        }
        return ret;
    };

    // idmap[q][b] is true when base vector b is accepted for query q;
    // bf_cnt[q] is the number of accepted base vectors.
    std::vector<std::vector<bool>> idmap(nq, std::vector<bool>(nb, false));
    std::vector<size_t> bf_cnt(nq, 0);

    auto bruteforce = [&]() {
        for (auto i = 0; i < nq; ++i) {
            const float* pq = xq.data() + i * dim;
            for (auto j = 0; j < nb; ++j) {
                const float* pb = xb.data() + j * dim;
                // For inner product a larger value is accepted, hence '>'.
                if (ipdis(pq, pb, dim) > radius) {
                    idmap[i][j] = true;
                    bf_cnt[i]++;
                }
            }
        }
    };

    bruteforce();

    // Walks every buffer of every result segment of each query and checks the ids against the
    // brute-force reference. Uses fatal assertions, so call it through ASSERT_NO_FATAL_FAILURE.
    auto compare_res = [&](std::vector<milvus::knowhere::DynamicResultSegment>& results) {
        for (auto i = 0; i < nq; ++i) {
            size_t correct_cnt = 0;
            for (auto& res_space : results[i]) {
                // All buffers but the last are full; the last one holds `wp` entries.
                auto qnr =
                    res_space->buffer_size * res_space->buffers.size() - res_space->buffer_size + res_space->wp;
                for (decltype(qnr) j = 0; j < qnr; ++j) {
                    auto bno = j / res_space->buffer_size;
                    auto pos = j % res_space->buffer_size;
                    const bool expected = idmap[i][res_space->buffers[bno].ids[pos]];
                    ASSERT_TRUE(expected);
                    ++correct_cnt;
                }
            }
            ASSERT_EQ(correct_cnt, bf_cnt[i]);
        }
    };

    {
        // build the index
        index_->Train(base_dataset, conf);
        index_->AddWithoutIds(base_dataset, milvus::knowhere::Config());

        std::vector<milvus::knowhere::DynamicResultSegment> results;
        for (auto i = 0; i < nq; ++i) {
            auto qd = milvus::knowhere::GenDataset(1, dim, xq.data() + i * dim);
            results.push_back(index_->QueryByDistance(qd, conf, nullptr));
        }

        // A fatal failure inside the lambda only leaves the lambda; propagate it to the test.
        ASSERT_NO_FATAL_FAILURE(compare_res(results));

        // serialize the index and load it back
        auto binaryset = index_->Serialize(conf);
        index_->Load(binaryset);

        EXPECT_EQ(index_->Count(), nb);
        EXPECT_EQ(index_->Dim(), dim);
        {
            // query again and compare the result
            std::vector<milvus::knowhere::DynamicResultSegment> rresults;
            for (auto i = 0; i < nq; ++i) {
                auto qd = milvus::knowhere::GenDataset(1, dim, xq.data() + i * dim);
                rresults.push_back(index_->QueryByDistance(qd, conf, nullptr));
            }

            ASSERT_NO_FATAL_FAILURE(compare_res(rresults));
        }
    }
}
|
||||
|
||||
// Exercises DynamicResultCollector: appends the range-search result of the first query vector
// three times, merges with a limit of 1000 and ascending post-processing, then checks that the
// merged set respects the limit and that its distances are ordered as requested.
TEST_P(IDMAPTest, idmap_dynamic_result_set) {
    milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim},
                                  {milvus::knowhere::IndexParams::range_search_radius, radius},
                                  {milvus::knowhere::IndexParams::range_search_buffer_size, buffer_size},
                                  {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}};

    // Verifies that neighbouring distances in `rst` follow the order implied by `rspt`.
    // Compares the distance VALUES (not their addresses) and performs no comparison when the
    // set holds fewer than two entries. Uses fatal assertions, so call it through
    // ASSERT_NO_FATAL_FAILURE.
    auto check_rst = [](milvus::knowhere::DynamicResultSet& rst, milvus::knowhere::ResultSetPostProcessType rspt) {
        const auto* dist = rst.distances.get();
        const auto total = static_cast<size_t>(rst.count);
        for (size_t i = 1; i < total; ++i) {
            if (rspt == milvus::knowhere::ResultSetPostProcessType::SortAsc) {
                ASSERT_LE(dist[i - 1], dist[i]);
            } else if (rspt == milvus::knowhere::ResultSetPostProcessType::SortDesc) {
                ASSERT_GE(dist[i - 1], dist[i]);
            }
        }
    };

    {
        milvus::knowhere::DynamicResultCollector collector;
        index_->Train(base_dataset, conf);
        index_->AddWithoutIds(base_dataset, milvus::knowhere::Config());

        // The same query (the first vector of xq) is issued on every round, so the collector
        // receives three result segments to merge.
        for (auto round = 0; round < 3; ++round) {
            auto qd = milvus::knowhere::GenDataset(1, dim, xq.data());
            collector.Append(index_->QueryByDistance(qd, conf, nullptr));
        }

        auto rst = collector.Merge(1000, milvus::knowhere::ResultSetPostProcessType::SortAsc);
        ASSERT_LE(rst.count, 1000);

        // A fatal failure inside the lambda only leaves the lambda; propagate it to the test.
        ASSERT_NO_FATAL_FAILURE(check_rst(rst, milvus::knowhere::ResultSetPostProcessType::SortAsc));
    }
}
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
TEST_P(IDMAPTest, idmap_copy) {
|
||||
ASSERT_TRUE(!xb.empty());
|
||||
|
|
|
@ -41,6 +41,8 @@ class DataGen {
|
|||
int nq = 10;
|
||||
int dim = 64;
|
||||
int k = 10;
|
||||
int buffer_size = 16384;
|
||||
float radius = 2.8;
|
||||
std::vector<float> xb;
|
||||
std::vector<float> xq;
|
||||
std::vector<uint8_t> xb_bin;
|
||||
|
|
Loading…
Reference in New Issue