Single query by distance (#4648)

* IndexFlat::QueryByDistance unittest passed

Signed-off-by: cmli <chengming.li@zilliz.com>

* change the datastructure of DynamicResultSet

Signed-off-by: cmli <chengming.li@zilliz.com>

* unittest of IndexFlat::QueryByDistance and IndexBinaryFalt::QueryByDistance has passed

Signed-off-by: cmli <chengming.li@zilliz.com>

* update the datastructure of DynamicResultSet...

Signed-off-by: cmli <chengming.li@zilliz.com>

* update the interface and the data structure of dynamic result set, compile passed

Signed-off-by: cmli <chengming.li@zilliz.com>

* fix unittest, add test for new data structure, to be debug...

Signed-off-by: cmli <chengming.li@zilliz.com>

* unittest passed

Signed-off-by: cmli <chengming.li@zilliz.com>

* add MapUids

Signed-off-by: cmli <chengming.li@zilliz.com>

* update code by the advise from review

Signed-off-by: cmli <chengming.li@zilliz.com>

Co-authored-by: cmli <chengming.li@zilliz.com>
pull/4659/head
op-hunter 2021-01-26 09:44:52 +08:00 committed by GitHub
parent e644680dd6
commit d64b1dd165
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 866 additions and 71 deletions

View File

@ -57,6 +57,7 @@ set(vector_index_srcs
knowhere/index/vector_index/adapter/VectorAdapter.cpp
knowhere/index/vector_index/helpers/FaissIO.cpp
knowhere/index/vector_index/helpers/IndexParameter.cpp
knowhere/index/vector_index/helpers/DynamicResultSet.cpp
knowhere/index/vector_index/impl/nsg/Distance.cpp
knowhere/index/vector_index/impl/nsg/NSG.cpp
knowhere/index/vector_index/impl/nsg/NSGHelper.cpp

View File

@ -65,6 +65,38 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const fa
return ret_ds;
}
DynamicResultSegment
BinaryIDMAP::QueryByDistance(const milvus::knowhere::DatasetPtr& dataset, const milvus::knowhere::Config& config,
const faiss::BitsetView& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
GET_TENSOR_DATA(dataset)
if (rows != 1) {
KNOWHERE_THROW_MSG("QueryByDistance only accept nq = 1!");
}
auto default_type = index_->metric_type;
if (config.contains(Metric::TYPE)) {
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
}
std::vector<faiss::RangeSearchPartialResult*> res;
DynamicResultSegment result;
auto radius = config[IndexParams::range_search_radius].get<int>();
auto buffer_size = config.contains(IndexParams::range_search_buffer_size)
? config[IndexParams::range_search_buffer_size].get<size_t>()
: 16384;
auto real_idx = dynamic_cast<faiss::IndexBinaryFlat*>(index_.get());
if (real_idx == nullptr) {
KNOWHERE_THROW_MSG("Cannot dynamic_cast the index to faiss::IndexBinaryFlat type!");
}
real_idx->range_search(rows, reinterpret_cast<const uint8_t*>(p_data), radius, res, buffer_size, bitset);
ExchangeDataset(result, res);
MapUids(result);
index_->metric_type = default_type;
return result;
}
int64_t
BinaryIDMAP::Count() {
if (!index_) {

View File

@ -17,6 +17,7 @@
#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h"
#include "knowhere/index/vector_index/VecIndex.h"
#include "knowhere/index/vector_index/helpers/DynamicResultSet.h"
namespace milvus {
namespace knowhere {
@ -46,6 +47,9 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::BitsetView& bitset) override;
DynamicResultSegment
QueryByDistance(const DatasetPtr& dataset, const Config& config, const faiss::BitsetView& bitset);
int64_t
Count() override;

View File

@ -98,6 +98,41 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::B
return ret_ds;
}
DynamicResultSegment
IDMAP::QueryByDistance(const milvus::knowhere::DatasetPtr& dataset, const milvus::knowhere::Config& config,
const faiss::BitsetView& bitset) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
GET_TENSOR_DATA(dataset)
if (rows != 1) {
KNOWHERE_THROW_MSG("QueryByDistance only accept nq = 1!");
}
auto default_type = index_->metric_type;
if (config.contains(Metric::TYPE)) {
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
}
std::vector<faiss::RangeSearchPartialResult*> res;
DynamicResultSegment result;
auto radius = config[IndexParams::range_search_radius].get<float>();
auto buffer_size = config.contains(IndexParams::range_search_buffer_size)
? config[IndexParams::range_search_buffer_size].get<size_t>()
: 16384;
auto real_idx = dynamic_cast<faiss::IndexFlat*>(index_.get());
if (real_idx == nullptr) {
KNOWHERE_THROW_MSG("Cannot dynamic_cast the index to faiss::IndexFlat type!");
}
if (index_->metric_type == faiss::MetricType::METRIC_L2) {
radius *= radius;
}
real_idx->range_search(rows, reinterpret_cast<const float*>(p_data), radius, res, buffer_size, bitset);
ExchangeDataset(result, res);
MapUids(result);
index_->metric_type = default_type;
return result;
}
int64_t
IDMAP::Count() {
if (!index_) {

View File

@ -16,6 +16,7 @@
#include "knowhere/index/vector_index/FaissBaseIndex.h"
#include "knowhere/index/vector_index/VecIndex.h"
#include "knowhere/index/vector_index/helpers/DynamicResultSet.h"
namespace milvus {
namespace knowhere {
@ -45,6 +46,9 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::BitsetView&) override;
DynamicResultSegment
QueryByDistance(const DatasetPtr& dataset, const Config& config, const faiss::BitsetView& bitset);
int64_t
Count() override;

View File

@ -24,6 +24,7 @@
#include "knowhere/index/Index.h"
#include "knowhere/index/IndexType.h"
#include "knowhere/index/vector_index/Statistics.h"
#include "knowhere/index/vector_index/helpers/DynamicResultSet.h"
namespace milvus {
namespace knowhere {
@ -94,6 +95,23 @@ class VecIndex : public Index {
}
}
void
MapUids(DynamicResultSegment& milvus_dataset) {
if (uids_) {
for (auto& mrspr : milvus_dataset) {
for (auto j = 0; j < mrspr->buffers.size(); ++j) {
auto buf = mrspr->buffers[j];
auto len = j + 1 == mrspr->buffers.size() ? mrspr->wp : mrspr->buffer_size;
for (auto i = 0; i < len; ++i) {
if (buf.ids[i] >= 0) {
buf.ids[i] = uids_->at(buf.ids[i]);
}
}
}
}
}
}
size_t
UidsSize() {
return uids_ ? uids_->size() * sizeof(IDType) : 0;

View File

@ -0,0 +1,194 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <src/index/knowhere/knowhere/common/Exception.h>
#include <src/index/knowhere/knowhere/index/vector_index/helpers/DynamicResultSet.h>
#include <src/index/thirdparty/faiss/impl/AuxIndexStructures.h>
#include <cstring>
#include <iostream>
#include <string>
#include <utility>
namespace milvus {
namespace knowhere {
/***********************************************************************
* DynamicResultSet
***********************************************************************/
void
DynamicResultSet::do_alloction() {
if (count <= 0) {
KNOWHERE_THROW_MSG("DynamicResultSet::do_alloction failed because of count <= 0");
}
labels = std::shared_ptr<idx_t[]>(new idx_t[count], std::default_delete<idx_t[]>());
distances = std::shared_ptr<float[]>(new float[count], std::default_delete<float[]>());
// labels = std::make_shared<idx_t []>(new idx_t[count], std::default_delete<idx_t[]>());
// distances = std::make_shared<float []>(new float[count], std::default_delete<float[]>());
}
void
DynamicResultSet::do_sort(ResultSetPostProcessType postProcessType) {
if (postProcessType == ResultSetPostProcessType::SortAsc) {
quick_sort<true>(0, count);
} else if (postProcessType == ResultSetPostProcessType::SortDesc) {
quick_sort<false>(0, count);
} else {
KNOWHERE_THROW_MSG("invalid sort type!");
}
}
template <bool asc>
void
DynamicResultSet::quick_sort(size_t lp, size_t rp) {
auto len = rp - lp;
if (len <= 1) {
return;
}
auto pvot = lp + (len >> 1);
size_t low = lp;
size_t high = rp - 1;
auto pids = labels.get();
auto pdis = distances.get();
std::swap(pdis[pvot], pdis[high]);
std::swap(pids[pvot], pids[high]);
if (asc) {
while (low < high) {
while (low < high && pdis[low] <= pdis[high]) {
low++;
}
if (low == high) {
break;
}
std::swap(pdis[low], pdis[high]);
std::swap(pids[low], pids[high]);
high--;
while (low < high && pdis[high] >= pdis[low]) {
high--;
}
if (low == high) {
break;
}
std::swap(pdis[low], pdis[high]);
std::swap(pids[low], pids[high]);
low++;
}
} else {
while (low < high) {
while (low < high && pdis[low] >= pdis[high]) {
low++;
}
if (low == high) {
break;
}
std::swap(pdis[low], pdis[high]);
std::swap(pids[low], pids[high]);
high--;
while (low < high && pdis[high] <= pdis[low]) {
high--;
}
if (low == high) {
break;
}
std::swap(pdis[low], pdis[high]);
std::swap(pids[low], pids[high]);
low++;
}
}
quick_sort<asc>(lp, low);
quick_sort<asc>(low, rp);
}
DynamicResultSet
DynamicResultCollector::Merge(size_t limit, ResultSetPostProcessType postProcessType) {
if (limit <= 0) {
KNOWHERE_THROW_MSG("limit must > 0!");
}
DynamicResultSet ret;
auto seg_num = seg_results.size();
std::vector<size_t> boundaries(seg_num + 1, 0);
#pragma omp parallel for
for (auto i = 0; i < seg_num; ++i) {
for (auto& pseg : seg_results[i]) {
boundaries[i] += (pseg->buffer_size * pseg->buffers.size() - pseg->buffer_size + pseg->wp);
}
}
for (size_t i = 0, ofs = 0; i <= seg_num; ++i) {
auto bn = boundaries[i];
boundaries[i] = ofs;
ofs += bn;
// boundaries[i] += boundaries[i - 1];
}
ret.count = boundaries[seg_num] <= limit ? boundaries[seg_num] : limit;
ret.do_alloction();
// abandon redundancy answers randomly
// abandon strategy: keep the top limit sequentially
int32_t pos = 1;
for (int i = 1; i < boundaries.size(); ++i) {
if (boundaries[i] >= ret.count) {
pos = i;
break;
}
}
pos--; // last segment id
// full copy
#pragma omp parallel for
for (auto i = 0; i < pos - 1; ++i) {
for (auto& pseg : seg_results[i]) {
auto len = pseg->buffers.size() * pseg->buffer_size - pseg->buffer_size + pseg->wp;
pseg->copy_range(0, len, ret.labels.get() + boundaries[i], ret.distances.get() + boundaries[i]);
boundaries[i] += len;
}
}
// partial copy
auto last_len = ret.count - boundaries[pos];
for (auto& pseg : seg_results[pos]) {
auto len = pseg->buffers.size() * pseg->buffer_size - pseg->buffer_size + pseg->wp;
auto ncopy = last_len > len ? len : last_len;
pseg->copy_range(0, ncopy, ret.labels.get() + boundaries[pos], ret.distances.get() + boundaries[pos]);
boundaries[pos] += ncopy;
last_len -= ncopy;
if (last_len <= 0) {
break;
}
}
if (postProcessType != ResultSetPostProcessType::None) {
ret.do_sort(postProcessType);
}
return ret;
}
void
DynamicResultCollector::Append(milvus::knowhere::DynamicResultSegment&& seg_result) {
seg_results.push_back(std::move(seg_result));
}
void
ExchangeDataset(DynamicResultSegment& milvus_dataset, std::vector<faiss::RangeSearchPartialResult*>& faiss_dataset) {
for (auto& prspr : faiss_dataset) {
auto mrspr = new DynamicResultFragment(prspr->res->buffer_size);
mrspr->wp = prspr->wp;
mrspr->buffers.resize(prspr->buffers.size());
for (auto i = 0; i < prspr->buffers.size(); ++i) {
mrspr->buffers[i].ids = prspr->buffers[i].ids;
mrspr->buffers[i].dis = prspr->buffers[i].dis;
prspr->buffers[i].ids = nullptr;
prspr->buffers[i].dis = nullptr;
}
delete prspr->res;
milvus_dataset.push_back(mrspr);
}
}
} // namespace knowhere
} // namespace milvus

View File

@ -0,0 +1,60 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "faiss/impl/AuxIndexStructures.h"
#include "knowhere/common/Typedef.h"
namespace milvus {
namespace knowhere {
enum class ResultSetPostProcessType { None = 0, SortDesc, SortAsc };
using idx_t = int64_t;
struct DynamicResultSet {
std::shared_ptr<idx_t[]> labels; /// result for query i is labels[lims[i]:lims[i + 1]]
std::shared_ptr<float[]> distances; /// corresponding distances, not sorted
size_t count; /// size of the result buffer's size, when reaches this size, auto start a new buffer
void
do_alloction();
void
do_sort(ResultSetPostProcessType postProcessType = ResultSetPostProcessType::SortAsc);
private:
template <bool asc>
void
quick_sort(size_t lp, size_t rp);
};
typedef faiss::BufferList DynamicResultFragment;
typedef DynamicResultFragment* DynamicResultFragmentPtr;
typedef std::vector<DynamicResultFragmentPtr> DynamicResultSegment;
struct DynamicResultCollector {
public:
DynamicResultSet
Merge(size_t limit = 10000, ResultSetPostProcessType postProcessType = ResultSetPostProcessType::None);
void
Append(DynamicResultSegment&& seg_result);
private:
std::vector<DynamicResultSegment> seg_results; /// unmerged results of every segments
};
void
ExchangeDataset(DynamicResultSegment& milvus_dataset, std::vector<faiss::RangeSearchPartialResult*>& faiss_dataset);
} // namespace knowhere
} // namespace milvus

View File

@ -28,6 +28,10 @@ constexpr const char* DEVICEID = "gpu_id";
}; // namespace meta
namespace IndexParams {
// Range Search Params
constexpr const char* range_search_radius = "range_search_radius";
constexpr const char* range_search_buffer_size = "range_search_buffer_size";
// IVF Params
constexpr const char* nprobe = "nprobe";
constexpr const char* nlist = "nlist";

View File

@ -126,7 +126,17 @@ void IndexBinaryFlat::range_search(idx_t n, const uint8_t *x, int radius,
RangeSearchResult *result,
const BitsetView& bitset) const
{
hamming_range_search (x, xb.data(), n, ntotal, radius, code_size, result);
FAISS_THROW_MSG("This interface is abandoned yet.");
}
void IndexBinaryFlat::range_search(faiss::IndexBinary::idx_t n,
const uint8_t* x,
int radius,
std::vector<faiss::RangeSearchPartialResult*>& result,
size_t buffer_size,
const faiss::BitsetView& bitset)
{
hamming_range_search (x, xb.data(), n, ntotal, radius, code_size, result, buffer_size, bitset);
}
} // namespace faiss

View File

@ -13,6 +13,7 @@
#include <vector>
#include <faiss/IndexBinary.h>
#include <faiss/impl/AuxIndexStructures.h>
namespace faiss {
@ -45,6 +46,11 @@ struct IndexBinaryFlat : IndexBinary {
RangeSearchResult *result,
const BitsetView& bitset = nullptr) const override;
void range_search(idx_t n, const uint8_t *x, int radius,
std::vector<RangeSearchPartialResult*> &result,
size_t buffer_size,
const BitsetView& bitset = nullptr); // const override
void reconstruct(idx_t key, uint8_t *recons) const override;
/** Remove some ids. Note that because of the indexing structure,

View File

@ -93,19 +93,28 @@ void IndexFlat::range_search (idx_t n, const float *x, float radius,
RangeSearchResult *result,
const BitsetView& bitset) const
{
switch (metric_type) {
case METRIC_INNER_PRODUCT:
range_search_inner_product (x, xb.data(), d, n, ntotal,
radius, result);
break;
case METRIC_L2:
range_search_L2sqr (x, xb.data(), d, n, ntotal, radius, result);
break;
default:
FAISS_THROW_MSG("metric type not supported");
}
FAISS_THROW_MSG("This interface is abandoned yet.");
}
void IndexFlat::range_search(faiss::Index::idx_t n,
const float* x,
float radius,
std::vector<faiss::RangeSearchPartialResult*>& result,
size_t buffer_size,
const faiss::BitsetView& bitset) {
switch (metric_type) {
case METRIC_INNER_PRODUCT:
range_search_inner_product (x, xb.data(), d, n, ntotal,
radius, result, buffer_size, bitset);
break;
case METRIC_L2:
range_search_L2sqr (x, xb.data(), d, n, ntotal, radius, result, buffer_size, bitset);
break;
default:
FAISS_THROW_MSG("metric type not supported");
}
}
void IndexFlat::compute_distance_subset (
idx_t n,

View File

@ -13,6 +13,7 @@
#include <vector>
#include <faiss/Index.h>
#include <faiss/impl/AuxIndexStructures.h>
namespace faiss {
@ -50,6 +51,14 @@ struct IndexFlat: Index {
RangeSearchResult* result,
const BitsetView& bitset = nullptr) const override;
void range_search(
idx_t n,
const float* x,
float radius,
std::vector<RangeSearchPartialResult*> &result,
size_t buffer_size,
const BitsetView& bitset = nullptr); // const override
void reconstruct(idx_t key, float* recons) const override;
/** compute distance with a subset of vectors

View File

@ -48,9 +48,12 @@ void RangeSearchResult::do_allocation () {
}
RangeSearchResult::~RangeSearchResult () {
delete [] labels;
delete [] distances;
delete [] lims;
if (labels)
delete [] labels;
if (distances)
delete [] distances;
if (lims)
delete [] lims;
}
@ -71,8 +74,10 @@ BufferList::BufferList (size_t buffer_size):
BufferList::~BufferList ()
{
for (int i = 0; i < buffers.size(); i++) {
delete [] buffers[i].ids;
delete [] buffers[i].dis;
if (buffers[i].ids)
delete [] buffers[i].ids;
if (buffers[i].dis)
delete [] buffers[i].dis;
}
}

View File

@ -807,13 +807,15 @@ void knn_L2sqr_by_idx (const float * x,
/** Find the nearest neighbors for nx queries in a set of ny vectors
* compute_l2 = compute pairwise squared L2 distance rather than inner prod
*/
template <bool compute_l2>
template <bool compute_l2>
static void range_search_blas (
const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
float radius,
RangeSearchResult *result)
std::vector<RangeSearchPartialResult*> &res,
size_t buffer_size,
const BitsetView &bitset)
{
// BLAS does not like empty matrices
@ -837,13 +839,13 @@ static void range_search_blas (
fvec_norms_L2sqr (y_norms, y, d, ny);
}
std::vector <RangeSearchPartialResult *> partial_results;
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
size_t j1 = j0 + bs_y;
if (j1 > ny) j1 = ny;
RangeSearchPartialResult * pres = new RangeSearchPartialResult (result);
partial_results.push_back (pres);
RangeSearchResult *tmp_res = new RangeSearchResult(nx);
tmp_res->buffer_size = buffer_size;
RangeSearchPartialResult * pres = new RangeSearchPartialResult (tmp_res);
res.push_back (pres);
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
size_t i1 = i0 + bs_x;
@ -867,14 +869,16 @@ static void range_search_blas (
for (size_t j = j0; j < j1; j++) {
float ip = *ip_line++;
if (compute_l2) {
float dis = x_norms[i] + y_norms[j] - 2 * ip;
if (dis < radius) {
qres.add (dis, j);
}
} else {
if (ip > radius) {
qres.add (ip, j);
if (bitset.empty() || !bitset.test((faiss::ConcurrentBitset::id_type_t)(j))) {
if (compute_l2) {
float dis = x_norms[i] + y_norms[j] - 2 * ip;
if (dis < radius) {
qres.add (dis, j);
}
} else {
if (ip > radius) {
qres.add (ip, j);
}
}
}
}
@ -883,7 +887,7 @@ static void range_search_blas (
InterruptCallback::check ();
}
RangeSearchPartialResult::merge (partial_results);
// RangeSearchPartialResult::merge (partial_results);
}
@ -892,12 +896,16 @@ static void range_search_sse (const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
float radius,
RangeSearchResult *res)
std::vector<RangeSearchPartialResult*> &res,
size_t buffer_size,
const BitsetView &bitset)
{
#pragma omp parallel
{
RangeSearchPartialResult pres (res);
RangeSearchResult *tmp_res = new RangeSearchResult(nx);
tmp_res->buffer_size = buffer_size;
auto pres = new RangeSearchPartialResult(tmp_res);
#pragma omp for
for (size_t i = 0; i < nx; i++) {
@ -905,9 +913,60 @@ static void range_search_sse (const float * x,
const float * y_ = y;
size_t j;
RangeQueryResult & qres = pres.new_result (i);
RangeQueryResult & qres = pres->new_result (i);
for (j = 0; j < ny; j++) {
if (bitset.empty() || !bitset.test((faiss::ConcurrentBitset::id_type_t)(j))) {
if (compute_l2) {
float disij = fvec_L2sqr (x_, y_, d);
if (disij < radius) {
qres.add (disij, j);
}
} else {
float ip = fvec_inner_product (x_, y_, d);
if (ip > radius) {
qres.add (ip, j);
}
}
}
y_ += d;
}
}
#pragma omp critical
res.push_back(pres);
}
// check just at the end because the use case is typically just
// when the nb of queries is low.
InterruptCallback::check();
}
// range search by sse when nq = 1, namely single query situation
template <bool compute_l2>
static void range_search_sse_sq (const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
float radius,
std::vector<RangeSearchPartialResult*> &res,
size_t buffer_size,
const BitsetView &bitset)
{
#pragma omp parallel
{
RangeSearchResult *tmp_res = new RangeSearchResult(nx);
tmp_res->buffer_size = buffer_size;
auto pres = new RangeSearchPartialResult(tmp_res);
const float * x_ = x;
size_t j;
RangeQueryResult & qres = pres->new_result (0);
#pragma omp for
for (j = 0; j < ny; j++) {
const float * y_ = y + j * d;
if (bitset.empty() || !bitset.test((faiss::ConcurrentBitset::id_type_t)(j))) {
if (compute_l2) {
float disij = fvec_L2sqr (x_, y_, d);
if (disij < radius) {
@ -919,11 +978,10 @@ static void range_search_sse (const float * x,
qres.add (ip, j);
}
}
y_ += d;
}
}
pres.finalize ();
#pragma omp critical
res.push_back(pres);
}
// check just at the end because the use case is typically just
@ -932,21 +990,24 @@ static void range_search_sse (const float * x,
}
void range_search_L2sqr (
const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
float radius,
RangeSearchResult *res)
std::vector<RangeSearchPartialResult*> &res,
size_t buffer_size,
const BitsetView &bitset)
{
if (nx < distance_compute_blas_threshold) {
range_search_sse<true> (x, y, d, nx, ny, radius, res);
if (nx == 1) {
range_search_sse_sq<true> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
} else {
range_search_sse<true> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
}
} else {
range_search_blas<true> (x, y, d, nx, ny, radius, res);
range_search_blas<true> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
}
}
@ -955,17 +1016,21 @@ void range_search_inner_product (
const float * y,
size_t d, size_t nx, size_t ny,
float radius,
RangeSearchResult *res)
std::vector<RangeSearchPartialResult*> &res,
size_t buffer_size,
const BitsetView &bitset)
{
if (nx < distance_compute_blas_threshold) {
range_search_sse<false> (x, y, d, nx, ny, radius, res);
if (nx == 1)
range_search_sse_sq<false> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
else
range_search_sse<false> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
} else {
range_search_blas<false> (x, y, d, nx, ny, radius, res);
range_search_blas<false> (x, y, d, nx, ny, radius, res, buffer_size, bitset);
}
}
void pairwise_L2sqr (int64_t d,
int64_t nq, const float *xq,
int64_t nb, const float *xb,

View File

@ -17,6 +17,7 @@
#include <faiss/utils/Heap.h>
#include <faiss/utils/ConcurrentBitset.h>
#include <faiss/utils/BitsetView.h>
#include <faiss/impl/AuxIndexStructures.h>
namespace faiss {
@ -241,15 +242,19 @@ void range_search_L2sqr (
const float * y,
size_t d, size_t nx, size_t ny,
float radius,
RangeSearchResult *result);
std::vector<RangeSearchPartialResult*> &result,
size_t buffer_size,
const BitsetView &bitset = nullptr);
/// same as range_search_L2sqr for the inner product similarity
void range_search_inner_product (
const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
float radius,
RangeSearchResult *result);
const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
float radius,
std::vector<RangeSearchPartialResult*> &result,
size_t buffer_size,
const BitsetView &bitset = nullptr);
/***************************************************************************

View File

@ -744,6 +744,7 @@ void hammings_knn_mc(
}
}
}
template <class HammingComputer>
static
void hamming_range_search_template (
@ -753,28 +754,33 @@ void hamming_range_search_template (
size_t nb,
int radius,
size_t code_size,
RangeSearchResult *res)
std::vector<RangeSearchPartialResult*> &result,
size_t buffer_size,
const BitsetView &bitset)
{
#pragma omp parallel
{
RangeSearchPartialResult pres (res);
RangeSearchResult *tmp_res = new RangeSearchResult(na);
tmp_res->buffer_size = buffer_size;
auto pres = new RangeSearchPartialResult(tmp_res);
HammingComputer hc (a, code_size);
const uint8_t * yi = b;
RangeQueryResult & qres = pres->new_result (0);
#pragma omp for
for (size_t i = 0; i < na; i++) {
HammingComputer hc (a + i * code_size, code_size);
const uint8_t * yi = b;
RangeQueryResult & qres = pres.new_result (i);
for (size_t j = 0; j < nb; j++) {
int dis = hc.hamming (yi);
for (size_t j = 0; j < nb; j++) {
if (bitset.empty() || !bitset.test((ConcurrentBitset::id_type_t)j)) {
int dis = hc.hamming (yi + j * code_size);
if (dis < radius) {
qres.add(dis, j);
}
yi += code_size;
}
// yi += code_size;
}
pres.finalize ();
#pragma omp critical
result.push_back(pres);
}
}
@ -785,10 +791,12 @@ void hamming_range_search (
size_t nb,
int radius,
size_t code_size,
RangeSearchResult *result)
std::vector<faiss::RangeSearchPartialResult*>& result,
size_t buffer_size,
const BitsetView& bitset)
{
#define HC(name) hamming_range_search_template<name> (a, b, na, nb, radius, code_size, result)
#define HC(name) hamming_range_search_template<name> (a, b, na, nb, radius, code_size, result, buffer_size, bitset)
switch(code_size) {
case 4: HC(HammingComputer4); break;
@ -805,8 +813,6 @@ void hamming_range_search (
#undef HC
}
/* Count number of matches given a max threshold */
void hamming_count_thres (
const uint8_t * bs1,

View File

@ -30,6 +30,7 @@
#include <faiss/utils/Heap.h>
#include <faiss/utils/ConcurrentBitset.h>
#include <faiss/utils/BitsetView.h>
#include <faiss/impl/AuxIndexStructures.h>
/* The Hamming distance type */
typedef int32_t hamdis_t;
@ -192,8 +193,9 @@ void hamming_range_search (
size_t nb,
int radius,
size_t ncodes,
RangeSearchResult *result);
std::vector<faiss::RangeSearchPartialResult*>& result,
size_t buffer_size,
const BitsetView& bitset);
/* Counting the number of matches or of cross-matches (without returning them)
For use with function that assume pre-allocated memory */

View File

@ -33,6 +33,7 @@ set(util_srcs
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/DynamicResultSet.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/Statistics.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/IndexType.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Exception.cpp

View File

@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <gtest/gtest.h>
#include <src/index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.h>
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"
@ -39,6 +40,7 @@ class BinaryIDMAPTest : public DataGen, public TestWithParam<std::string> {
INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIDMAPTest,
Values(std::string("JACCARD"), std::string("TANIMOTO"), std::string("HAMMING")));
/*
TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
ASSERT_TRUE(!xb_bin.empty());
@ -157,3 +159,89 @@ TEST_P(BinaryIDMAPTest, binaryidmap_slice) {
// PrintResult(result, nq, k);
}
}
*/
TEST_P(BinaryIDMAPTest, binaryidmap_range_search) {
std::string MetricType = GetParam();
milvus::knowhere::Config conf{
{milvus::knowhere::meta::DIM, dim},
{milvus::knowhere::IndexParams::range_search_radius, radius},
{milvus::knowhere::IndexParams::range_search_buffer_size, buffer_size},
{milvus::knowhere::Metric::TYPE, MetricType},
};
int hamming_radius = 10;
auto hamming_dis = [] (const int64_t *pa, const int64_t *pb) -> int {
return __builtin_popcountl((*pa) ^ (*pb));
};
std::vector<std::vector<bool>> idmap(nq, std::vector<bool>(nb, false));
std::vector<size_t> bf_cnt(nq, 0);
auto bruteforce = [&] () {
for (auto i = 0; i < nq; ++ i) {
const int64_t *pq = reinterpret_cast<int64_t*>(xq_bin.data()) + i;
const int64_t *pb = reinterpret_cast<int64_t*>(xb_bin.data());
for (auto j = 0; j < nb; ++ j) {
auto dist = hamming_dis(pq, pb + j);
if (dist < hamming_radius) {
idmap[i][j] = true;
bf_cnt[i] ++;
}
}
}
};
bruteforce();
auto compare_res = [&] (std::vector<milvus::knowhere::DynamicResultSegment> &results) {
for (auto i = 0; i < nq; ++ i) {
int query_i_cnt = 0;
int correct_cnt = 0;
for (auto &res_space: results[i]) {
auto qnr = res_space->buffer_size * res_space->buffers.size() - res_space->buffer_size + res_space->wp;
for (auto j = 0; j < qnr; ++ j) {
auto bno = j / res_space->buffer_size;
auto pos = j % res_space->buffer_size;
ASSERT_EQ(idmap[i][res_space->buffers[bno].ids[pos]], true);
if (idmap[i][res_space->buffers[bno].ids[pos]]) {
correct_cnt ++;
}
}
}
ASSERT_EQ(correct_cnt, bf_cnt[i]);
}
};
{
// serialize index
index_->Train(base_dataset, conf);
index_->AddWithoutIds(base_dataset, milvus::knowhere::Config());
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
std::vector<milvus::knowhere::DynamicResultSegment> results;
for (auto i = 0; i < nq; ++ i) {
auto qd = milvus::knowhere::GenDataset(1, dim, xq_bin.data() + i * dim / 8);
results.push_back(index_->QueryByDistance(qd, conf, nullptr));
}
compare_res(results);
//
auto binaryset = index_->Serialize(conf);
index_->Load(binaryset);
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
{
std::vector<milvus::knowhere::DynamicResultSegment> rresults;
for (auto i = 0; i < nq; ++ i) {
auto qd = milvus::knowhere::GenDataset(1, dim, xq_bin.data() + i * dim / 8);
rresults.push_back(index_->QueryByDistance(qd, conf, nullptr));
}
compare_res(rresults);
}
}
}

View File

@ -15,6 +15,7 @@
#include <fiu/fiu-local.h>
#include <iostream>
#include <thread>
#include <src/index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.h>
#include "knowhere/common/Exception.h"
#include "knowhere/index/IndexType.h"
@ -209,6 +210,240 @@ TEST_P(IDMAPTest, idmap_slice) {
}
}
TEST_P(IDMAPTest, idmap_range_search_l2) {
milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim},
{milvus::knowhere::IndexParams::range_search_radius, radius},
{milvus::knowhere::IndexParams::range_search_buffer_size, buffer_size},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}};
auto l2dis = [](const float *pa, const float *pb, size_t dim) -> float {
float ret = 0;
for (auto i = 0; i < dim; ++ i) {
auto dif = (pa[i] - pb[i]);
ret += dif * dif;
}
return ret;
};
std::vector<std::vector<bool>> idmap(nq, std::vector<bool>(nb, false));
std::vector<size_t> bf_cnt(nq, 0);
auto bruteforce = [&] () {
auto rds = radius * radius;
for (auto i = 0; i < nq; ++ i) {
const float *pq = xq.data() + i * dim;
for (auto j = 0; j < nb; ++ j) {
const float *pb = xb.data() + j * dim;
auto dist = l2dis(pq, pb, dim);
if (dist < rds) {
idmap[i][j] = true;
bf_cnt[i] ++;
}
}
}
};
bruteforce();
auto compare_res = [&] (std::vector<milvus::knowhere::DynamicResultSegment> &results) {
{ // compare the result
// std::cout << "size of result: " << results.size() << std::endl;
for (auto i = 0; i < nq; ++ i) {
int correct_cnt = 0;
// std::cout << "query id = " << i << ", result[i].size = " << results[i].size() << std::endl;
for (auto &res_space: results[i]) {
// std::cout << "buffer size = " << res_space->buffer_size << ", wp = " << res_space->wp << ", size of buffers = " << res_space->buffers.size() << std::endl;
auto qnr = res_space->buffer_size * res_space->buffers.size() - res_space->buffer_size + res_space->wp;
// std::cout << "qnr = " << qnr << std::endl;
for (auto j = 0; j < qnr; ++ j) {
auto bno = j / res_space->buffer_size;
auto pos = j % res_space->buffer_size;
ASSERT_EQ(idmap[i][res_space->buffers[bno].ids[pos]], true);
if (idmap[i][res_space->buffers[bno].ids[pos]]) {
correct_cnt ++;
}
}
}
ASSERT_EQ(correct_cnt, bf_cnt[i]);
}
}
};
{
index_->Train(base_dataset, conf);
index_->AddWithoutIds(base_dataset, milvus::knowhere::Config());
std::vector<milvus::knowhere::DynamicResultSegment> results;
for (auto i = 0; i < nq; ++ i) {
auto qd = milvus::knowhere::GenDataset(1, dim, xq.data() + i * dim);
results.push_back(index_->QueryByDistance(qd, conf, nullptr));
}
compare_res(results);
auto binaryset = index_->Serialize(conf);
index_->Load(binaryset);
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
{ // query again and compare the result
std::vector<milvus::knowhere::DynamicResultSegment> rresults;
for (auto i = 0; i < nq; ++ i) {
auto qd = milvus::knowhere::GenDataset(1, dim, xq.data() + i * dim);
rresults.push_back(index_->QueryByDistance(qd, conf, nullptr));
}
compare_res(rresults);
}
}
}
TEST_P(IDMAPTest, idmap_range_search_ip) {
milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim},
{milvus::knowhere::IndexParams::range_search_radius, radius},
{milvus::knowhere::IndexParams::range_search_buffer_size, buffer_size},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::IP}};
auto ipdis = [](const float *pa, const float *pb, size_t dim) -> float {
float ret = 0;
for (auto i = 0; i < dim; ++ i) {
ret += pa[i] * pb[i];
}
return ret;
};
std::vector<std::vector<bool>> idmap(nq, std::vector<bool>(nb, false));
std::vector<size_t> bf_cnt(nq, 0);
auto bruteforce = [&] () {
for (auto i = 0; i < nq; ++ i) {
const float *pq = xq.data() + i * dim;
for (auto j = 0; j < nb; ++ j) {
const float *pb = xb.data() + j * dim;
auto dist = ipdis(pq, pb, dim);
if (dist > radius) {
idmap[i][j] = true;
bf_cnt[i] ++;
}
}
}
};
bruteforce();
auto compare_res = [&] (std::vector<milvus::knowhere::DynamicResultSegment> &results) {
{ // compare the result
for (auto i = 0; i < nq; ++ i) {
int correct_cnt = 0;
for (auto &res_space: results[i]) {
auto qnr = res_space->buffer_size * res_space->buffers.size() - res_space->buffer_size + res_space->wp;
for (auto j = 0; j < qnr; ++ j) {
auto bno = j / res_space->buffer_size;
auto pos = j % res_space->buffer_size;
ASSERT_EQ(idmap[i][res_space->buffers[bno].ids[pos]], true);
if (idmap[i][res_space->buffers[bno].ids[pos]]) {
correct_cnt ++;
}
}
}
ASSERT_EQ(correct_cnt, bf_cnt[i]);
}
}
};
{
index_->Train(base_dataset, conf);
index_->AddWithoutIds(base_dataset, milvus::knowhere::Config());
std::vector<milvus::knowhere::DynamicResultSegment> results;
for (auto i = 0; i < nq; ++ i) {
auto qd = milvus::knowhere::GenDataset(1, dim, xq.data() + i * dim);
results.push_back(index_->QueryByDistance(qd, conf, nullptr));
}
compare_res(results);
auto binaryset = index_->Serialize(conf);
index_->Load(binaryset);
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim);
{ // query again and compare the result
std::vector<milvus::knowhere::DynamicResultSegment> rresults;
for (auto i = 0; i < nq; ++ i) {
auto qd = milvus::knowhere::GenDataset(1, dim, xq.data() + i * dim);
rresults.push_back(index_->QueryByDistance(qd, conf, nullptr));
}
compare_res(rresults);
}
}
}
TEST_P(IDMAPTest, idmap_dynamic_result_set) {
milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim},
{milvus::knowhere::IndexParams::range_search_radius, radius},
{milvus::knowhere::IndexParams::range_search_buffer_size, buffer_size},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}};
auto l2dis = [](const float *pa, const float *pb, size_t dim) -> float {
float ret = 0;
for (auto i = 0; i < dim; ++ i) {
auto dif = (pa[i] - pb[i]);
ret += dif * dif;
}
return ret;
};
std::vector<std::vector<bool>> idmap(nq, std::vector<bool>(nb, false));
std::vector<size_t> bf_cnt(nq, 0);
auto bruteforce = [&] () {
auto rds = radius * radius;
for (auto i = 0; i < nq; ++ i) {
const float *pq = xq.data() + i * dim;
for (auto j = 0; j < nb; ++ j) {
const float *pb = xb.data() + j * dim;
auto dist = l2dis(pq, pb, dim);
if (dist < rds) {
idmap[i][j] = true;
bf_cnt[i] ++;
}
}
}
};
bruteforce();
auto check_rst = [&] (milvus::knowhere::DynamicResultSet &rst, milvus::knowhere::ResultSetPostProcessType rspt) {
{ // compare the result
for (auto i = 0; i < rst.count - 1; ++ i) {
if (rspt == milvus::knowhere::ResultSetPostProcessType::SortAsc)
ASSERT_LE(rst.distances.get() + i, rst.distances.get() + i + 1);
else if (rspt == milvus::knowhere::ResultSetPostProcessType::SortDesc)
ASSERT_GE(rst.distances.get() + i, rst.distances.get() + i + 1);
}
}
};
{
milvus::knowhere::DynamicResultCollector collector;
index_->Train(base_dataset, conf);
index_->AddWithoutIds(base_dataset, milvus::knowhere::Config());
for (auto i = 0; i < 3; ++ i) {
auto qd = milvus::knowhere::GenDataset(1, dim, xq.data());
collector.Append(index_->QueryByDistance(qd, conf, nullptr));
}
auto rst = collector.Merge(1000, milvus::knowhere::ResultSetPostProcessType::SortAsc);
ASSERT_LE(rst.count, 1000);
check_rst(rst, milvus::knowhere::ResultSetPostProcessType::SortAsc);
}
}
#ifdef MILVUS_GPU_VERSION
TEST_P(IDMAPTest, idmap_copy) {
ASSERT_TRUE(!xb.empty());

View File

@ -41,6 +41,8 @@ class DataGen {
int nq = 10;
int dim = 64;
int k = 10;
int buffer_size = 16384;
float radius = 2.8;
std::vector<float> xb;
std::vector<float> xq;
std::vector<uint8_t> xb_bin;