enhance: sparse float vector to support brute force iterator and range search (#32635)

issue: #29419

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
pull/32704/head
Buqian Zheng 2024-04-29 14:35:26 +08:00 committed by GitHub
parent 083bd38c77
commit 858599d831
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 101 additions and 28 deletions

View File

@ -79,30 +79,16 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw);
auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data);
if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
}
auto search_cfg = PrepareBFSearchParams(search_info);
sub_result.mutable_seg_offsets().resize(nq * topk);
sub_result.mutable_distances().resize(nq * topk);
if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
// TODO(SPARSE): support sparse brute force range search
AssertInfo(
!search_cfg.contains(RADIUS) && !search_cfg.contains(RANGE_FILTER),
"sparse vector not support range search");
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
auto stat = knowhere::BruteForce::SearchSparseWithBuf(
base_dataset,
query_dataset,
sub_result.mutable_seg_offsets().data(),
sub_result.mutable_distances().data(),
search_cfg,
bitset);
milvus::tracer::AddEvent("knowhere_finish_BruteForce_SearchWithBuf");
if (stat != knowhere::Status::success) {
throw SegcoreError(KnowhereError, KnowhereStatusString(stat));
}
} else if (search_cfg.contains(RADIUS)) {
if (search_cfg.contains(RADIUS)) {
if (search_cfg.contains(RANGE_FILTER)) {
CheckRangeSearchParam(search_cfg[RADIUS],
search_cfg[RANGE_FILTER],
@ -121,6 +107,15 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
} else if (data_type == DataType::VECTOR_BINARY) {
res = knowhere::BruteForce::RangeSearch<uint8_t>(
base_dataset, query_dataset, search_cfg, bitset);
} else if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
res = knowhere::BruteForce::RangeSearch<
knowhere::sparse::SparseRow<float>>(
base_dataset, query_dataset, search_cfg, bitset);
} else {
PanicInfo(
ErrorCode::Unsupported,
"Unsupported dataType for chunk brute force range search:{}",
data_type);
}
milvus::tracer::AddEvent("knowhere_finish_BruteForce_RangeSearch");
if (!res.has_value()) {
@ -170,6 +165,18 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
sub_result.mutable_distances().data(),
search_cfg,
bitset);
} else if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
stat = knowhere::BruteForce::SearchSparseWithBuf(
base_dataset,
query_dataset,
sub_result.mutable_seg_offsets().data(),
sub_result.mutable_distances().data(),
search_cfg,
bitset);
} else {
PanicInfo(ErrorCode::Unsupported,
"Unsupported dataType for chunk brute force search:{}",
data_type);
}
milvus::tracer::AddEvent("knowhere_finish_BruteForce_SearchWithBuf");
if (stat != knowhere::Status::success) {
@ -193,6 +200,10 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
auto dim = dataset.dim;
auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw);
auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data);
if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
}
auto search_cfg = PrepareBFSearchParams(search_info);
knowhere::expected<
@ -211,8 +222,12 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
iterators_val = knowhere::BruteForce::AnnIterator<bfloat16>(
base_dataset, query_dataset, search_cfg, bitset);
break;
case DataType::VECTOR_SPARSE_FLOAT:
iterators_val = knowhere::BruteForce::AnnIterator<
knowhere::sparse::SparseRow<float>>(
base_dataset, query_dataset, search_cfg, bitset);
break;
default:
// TODO(SPARSE): support sparse brute force iterator
PanicInfo(ErrorCode::Unsupported,
"Unsupported dataType for chunk brute force iterator:{}",
data_type);
@ -240,4 +255,4 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
}
}
} // namespace milvus::query
} // namespace milvus::query

View File

@ -26,11 +26,10 @@ using namespace milvus::query;
namespace {
std::vector<int>
Ref(const knowhere::sparse::SparseRow<float>* base,
const knowhere::sparse::SparseRow<float>& query,
int nb,
int topk,
const knowhere::MetricType& metric) {
SearchRef(const knowhere::sparse::SparseRow<float>* base,
const knowhere::sparse::SparseRow<float>& query,
int nb,
int topk) {
std::vector<std::tuple<float, int>> res;
for (int i = 0; i < nb; i++) {
auto& row = base[i];
@ -50,6 +49,31 @@ Ref(const knowhere::sparse::SparseRow<float>* base,
return offsets;
}
std::vector<int>
RangeSearchRef(const knowhere::sparse::SparseRow<float>* base,
const knowhere::sparse::SparseRow<float>& query,
int nb,
float radius,
float range_filter,
int topk) {
std::vector<int> offsets;
for (int i = 0; i < nb; i++) {
auto& row = base[i];
auto distance = row.dot(query);
if (distance <= range_filter && distance > radius) {
offsets.push_back(i);
}
}
// select and sort top k on the range filter side
std::sort(offsets.begin(), offsets.end(), [&](int a, int b) {
return base[a].dot(query) > base[b].dot(query);
});
if (offsets.size() > topk) {
offsets.resize(topk);
}
return offsets;
}
void
AssertMatch(const std::vector<int>& expected, const int64_t* actual) {
for (int i = 0; i < expected.size(); i++) {
@ -95,11 +119,45 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
for (int i = 0; i < nq; i++) {
auto ref =
Ref(base.get(), *(query.get() + i), nb, topk, metric_type);
auto ref = SearchRef(base.get(), *(query.get() + i), nb, topk);
auto ans = result.get_seg_offsets() + i * topk;
AssertMatch(ref, ans);
}
search_info.search_params_[RADIUS] = 0.1;
search_info.search_params_[RANGE_FILTER] = 0.5;
auto result2 = BruteForceSearch(dataset,
base.get(),
nb,
search_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
for (int i = 0; i < nq; i++) {
auto ref = RangeSearchRef(
base.get(), *(query.get() + i), nb, 0.1, 0.5, topk);
auto ans = result2.get_seg_offsets() + i * topk;
AssertMatch(ref, ans);
}
auto result3 = BruteForceSearchIterators(dataset,
base.get(),
nb,
search_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
auto iterators = result3.chunk_iterators();
for (int i = 0; i < nq; i++) {
auto it = iterators[i];
auto q = *(query.get() + i);
auto last_dis = std::numeric_limits<float>::max();
// we should see strict decreasing distances for brute force iterator.
while (it->HasNext()) {
auto [offset, dis] = it->Next();
ASSERT_LE(dis, last_dis);
last_dis = dis;
ASSERT_FLOAT_EQ(dis, base[offset].dot(q));
}
}
}
};