From 858599d831dc9ecf9502fcab5980169912f65719 Mon Sep 17 00:00:00 2001 From: Buqian Zheng Date: Mon, 29 Apr 2024 14:35:26 +0800 Subject: [PATCH] enhance: sparse float vector to support brute force iterator and range search (#32635) issue: #29419 Signed-off-by: Buqian Zheng --- internal/core/src/query/SearchBruteForce.cpp | 57 ++++++++++------ internal/core/unittest/test_bf_sparse.cpp | 72 ++++++++++++++++++-- 2 files changed, 101 insertions(+), 28 deletions(-) diff --git a/internal/core/src/query/SearchBruteForce.cpp b/internal/core/src/query/SearchBruteForce.cpp index 186846bd6f..d6478acc82 100644 --- a/internal/core/src/query/SearchBruteForce.cpp +++ b/internal/core/src/query/SearchBruteForce.cpp @@ -79,30 +79,16 @@ BruteForceSearch(const dataset::SearchDataset& dataset, auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw); auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data); + if (data_type == DataType::VECTOR_SPARSE_FLOAT) { + base_dataset->SetIsSparse(true); + query_dataset->SetIsSparse(true); + } auto search_cfg = PrepareBFSearchParams(search_info); sub_result.mutable_seg_offsets().resize(nq * topk); sub_result.mutable_distances().resize(nq * topk); - if (data_type == DataType::VECTOR_SPARSE_FLOAT) { - // TODO(SPARSE): support sparse brute force range search - AssertInfo( - !search_cfg.contains(RADIUS) && !search_cfg.contains(RANGE_FILTER), - "sparse vector not support range search"); - base_dataset->SetIsSparse(true); - query_dataset->SetIsSparse(true); - auto stat = knowhere::BruteForce::SearchSparseWithBuf( - base_dataset, - query_dataset, - sub_result.mutable_seg_offsets().data(), - sub_result.mutable_distances().data(), - search_cfg, - bitset); - milvus::tracer::AddEvent("knowhere_finish_BruteForce_SearchWithBuf"); - if (stat != knowhere::Status::success) { - throw SegcoreError(KnowhereError, KnowhereStatusString(stat)); - } - } else if (search_cfg.contains(RADIUS)) { + if (search_cfg.contains(RADIUS)) { if (search_cfg.contains(RANGE_FILTER)) { CheckRangeSearchParam(search_cfg[RADIUS], search_cfg[RANGE_FILTER], @@ -121,6 +107,15 @@ BruteForceSearch(const dataset::SearchDataset& dataset, } else if (data_type == DataType::VECTOR_BINARY) { res = knowhere::BruteForce::RangeSearch( base_dataset, query_dataset, search_cfg, bitset); + } else if (data_type == DataType::VECTOR_SPARSE_FLOAT) { + res = knowhere::BruteForce::RangeSearch< + knowhere::sparse::SparseRow>( + base_dataset, query_dataset, search_cfg, bitset); + } else { + PanicInfo( + ErrorCode::Unsupported, + "Unsupported dataType for chunk brute force range search:{}", + data_type); } milvus::tracer::AddEvent("knowhere_finish_BruteForce_RangeSearch"); if (!res.has_value()) { @@ -170,6 +165,18 @@ BruteForceSearch(const dataset::SearchDataset& dataset, sub_result.mutable_distances().data(), search_cfg, bitset); + } else if (data_type == DataType::VECTOR_SPARSE_FLOAT) { + stat = knowhere::BruteForce::SearchSparseWithBuf( + base_dataset, + query_dataset, + sub_result.mutable_seg_offsets().data(), + sub_result.mutable_distances().data(), + search_cfg, + bitset); + } else { + PanicInfo(ErrorCode::Unsupported, + "Unsupported dataType for chunk brute force search:{}", + data_type); } milvus::tracer::AddEvent("knowhere_finish_BruteForce_SearchWithBuf"); if (stat != knowhere::Status::success) { @@ -193,6 +200,10 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset, auto dim = dataset.dim; auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw); auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data); + if (data_type == DataType::VECTOR_SPARSE_FLOAT) { + base_dataset->SetIsSparse(true); + query_dataset->SetIsSparse(true); + } auto search_cfg = PrepareBFSearchParams(search_info); knowhere::expected< @@ -211,8 +222,12 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset, iterators_val = knowhere::BruteForce::AnnIterator( base_dataset, query_dataset, search_cfg, bitset); break; + case DataType::VECTOR_SPARSE_FLOAT: + iterators_val = knowhere::BruteForce::AnnIterator< + knowhere::sparse::SparseRow>( + base_dataset, query_dataset, search_cfg, bitset); + break; default: - // TODO(SPARSE): support sparse brute force iterator PanicInfo(ErrorCode::Unsupported, "Unsupported dataType for chunk brute force iterator:{}", data_type); @@ -240,4 +255,4 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset, } } -} // namespace milvus::query \ No newline at end of file +} // namespace milvus::query diff --git a/internal/core/unittest/test_bf_sparse.cpp b/internal/core/unittest/test_bf_sparse.cpp index 0e970c48e2..7c9e466208 100644 --- a/internal/core/unittest/test_bf_sparse.cpp +++ b/internal/core/unittest/test_bf_sparse.cpp @@ -26,11 +26,10 @@ using namespace milvus::query; namespace { std::vector -Ref(const knowhere::sparse::SparseRow* base, - const knowhere::sparse::SparseRow& query, - int nb, - int topk, - const knowhere::MetricType& metric) { +SearchRef(const knowhere::sparse::SparseRow* base, + const knowhere::sparse::SparseRow& query, + int nb, + int topk) { std::vector> res; for (int i = 0; i < nb; i++) { auto& row = base[i]; @@ -50,6 +49,31 @@ Ref(const knowhere::sparse::SparseRow* base, return offsets; } +std::vector +RangeSearchRef(const knowhere::sparse::SparseRow* base, + const knowhere::sparse::SparseRow& query, + int nb, + float radius, + float range_filter, + int topk) { + std::vector offsets; + for (int i = 0; i < nb; i++) { + auto& row = base[i]; + auto distance = row.dot(query); + if (distance <= range_filter && distance > radius) { + offsets.push_back(i); + } + } + // select and sort top k on the range filter side + std::sort(offsets.begin(), offsets.end(), [&](int a, int b) { + return base[a].dot(query) > base[b].dot(query); + }); + if (offsets.size() > topk) { + offsets.resize(topk); + } + return offsets; +} + void AssertMatch(const std::vector& expected, const int64_t* actual) { for (int i = 0; i < expected.size(); i++) { @@ -95,11 +119,45 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test { bitset_view, DataType::VECTOR_SPARSE_FLOAT); for (int i = 0; i < nq; i++) { - auto ref = - Ref(base.get(), *(query.get() + i), nb, topk, metric_type); + auto ref = SearchRef(base.get(), *(query.get() + i), nb, topk); auto ans = result.get_seg_offsets() + i * topk; AssertMatch(ref, ans); } + + search_info.search_params_[RADIUS] = 0.1; + search_info.search_params_[RANGE_FILTER] = 0.5; + auto result2 = BruteForceSearch(dataset, + base.get(), + nb, + search_info, + bitset_view, + DataType::VECTOR_SPARSE_FLOAT); + for (int i = 0; i < nq; i++) { + auto ref = RangeSearchRef( + base.get(), *(query.get() + i), nb, 0.1, 0.5, topk); + auto ans = result2.get_seg_offsets() + i * topk; + AssertMatch(ref, ans); + } + + auto result3 = BruteForceSearchIterators(dataset, + base.get(), + nb, + search_info, + bitset_view, + DataType::VECTOR_SPARSE_FLOAT); + auto iterators = result3.chunk_iterators(); + for (int i = 0; i < nq; i++) { + auto it = iterators[i]; + auto q = *(query.get() + i); + auto last_dis = std::numeric_limits::max(); + // we should see strict decreasing distances for brute force iterator. + while (it->HasNext()) { + auto [offset, dis] = it->Next(); + ASSERT_LE(dis, last_dis); + last_dis = dis; + ASSERT_FLOAT_EQ(dis, base[offset].dot(q)); + } + } } };