mirror of https://github.com/milvus-io/milvus.git
enhance: sparse float vector to support brute force iterator and range search (#32635)
issue: #29419 Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>pull/32704/head
parent
083bd38c77
commit
858599d831
|
@ -79,30 +79,16 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
|
|||
|
||||
auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw);
|
||||
auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data);
|
||||
if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
|
||||
base_dataset->SetIsSparse(true);
|
||||
query_dataset->SetIsSparse(true);
|
||||
}
|
||||
auto search_cfg = PrepareBFSearchParams(search_info);
|
||||
|
||||
sub_result.mutable_seg_offsets().resize(nq * topk);
|
||||
sub_result.mutable_distances().resize(nq * topk);
|
||||
|
||||
if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
|
||||
// TODO(SPARSE): support sparse brute force range search
|
||||
AssertInfo(
|
||||
!search_cfg.contains(RADIUS) && !search_cfg.contains(RANGE_FILTER),
|
||||
"sparse vector not support range search");
|
||||
base_dataset->SetIsSparse(true);
|
||||
query_dataset->SetIsSparse(true);
|
||||
auto stat = knowhere::BruteForce::SearchSparseWithBuf(
|
||||
base_dataset,
|
||||
query_dataset,
|
||||
sub_result.mutable_seg_offsets().data(),
|
||||
sub_result.mutable_distances().data(),
|
||||
search_cfg,
|
||||
bitset);
|
||||
milvus::tracer::AddEvent("knowhere_finish_BruteForce_SearchWithBuf");
|
||||
if (stat != knowhere::Status::success) {
|
||||
throw SegcoreError(KnowhereError, KnowhereStatusString(stat));
|
||||
}
|
||||
} else if (search_cfg.contains(RADIUS)) {
|
||||
if (search_cfg.contains(RADIUS)) {
|
||||
if (search_cfg.contains(RANGE_FILTER)) {
|
||||
CheckRangeSearchParam(search_cfg[RADIUS],
|
||||
search_cfg[RANGE_FILTER],
|
||||
|
@ -121,6 +107,15 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
|
|||
} else if (data_type == DataType::VECTOR_BINARY) {
|
||||
res = knowhere::BruteForce::RangeSearch<uint8_t>(
|
||||
base_dataset, query_dataset, search_cfg, bitset);
|
||||
} else if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
|
||||
res = knowhere::BruteForce::RangeSearch<
|
||||
knowhere::sparse::SparseRow<float>>(
|
||||
base_dataset, query_dataset, search_cfg, bitset);
|
||||
} else {
|
||||
PanicInfo(
|
||||
ErrorCode::Unsupported,
|
||||
"Unsupported dataType for chunk brute force range search:{}",
|
||||
data_type);
|
||||
}
|
||||
milvus::tracer::AddEvent("knowhere_finish_BruteForce_RangeSearch");
|
||||
if (!res.has_value()) {
|
||||
|
@ -170,6 +165,18 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
|
|||
sub_result.mutable_distances().data(),
|
||||
search_cfg,
|
||||
bitset);
|
||||
} else if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
|
||||
stat = knowhere::BruteForce::SearchSparseWithBuf(
|
||||
base_dataset,
|
||||
query_dataset,
|
||||
sub_result.mutable_seg_offsets().data(),
|
||||
sub_result.mutable_distances().data(),
|
||||
search_cfg,
|
||||
bitset);
|
||||
} else {
|
||||
PanicInfo(ErrorCode::Unsupported,
|
||||
"Unsupported dataType for chunk brute force search:{}",
|
||||
data_type);
|
||||
}
|
||||
milvus::tracer::AddEvent("knowhere_finish_BruteForce_SearchWithBuf");
|
||||
if (stat != knowhere::Status::success) {
|
||||
|
@ -193,6 +200,10 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
|
|||
auto dim = dataset.dim;
|
||||
auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw);
|
||||
auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data);
|
||||
if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
|
||||
base_dataset->SetIsSparse(true);
|
||||
query_dataset->SetIsSparse(true);
|
||||
}
|
||||
auto search_cfg = PrepareBFSearchParams(search_info);
|
||||
|
||||
knowhere::expected<
|
||||
|
@ -211,8 +222,12 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
|
|||
iterators_val = knowhere::BruteForce::AnnIterator<bfloat16>(
|
||||
base_dataset, query_dataset, search_cfg, bitset);
|
||||
break;
|
||||
case DataType::VECTOR_SPARSE_FLOAT:
|
||||
iterators_val = knowhere::BruteForce::AnnIterator<
|
||||
knowhere::sparse::SparseRow<float>>(
|
||||
base_dataset, query_dataset, search_cfg, bitset);
|
||||
break;
|
||||
default:
|
||||
// TODO(SPARSE): support sparse brute force iterator
|
||||
PanicInfo(ErrorCode::Unsupported,
|
||||
"Unsupported dataType for chunk brute force iterator:{}",
|
||||
data_type);
|
||||
|
@ -240,4 +255,4 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
|
|||
}
|
||||
}
|
||||
|
||||
} // namespace milvus::query
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -26,11 +26,10 @@ using namespace milvus::query;
|
|||
namespace {
|
||||
|
||||
std::vector<int>
|
||||
Ref(const knowhere::sparse::SparseRow<float>* base,
|
||||
const knowhere::sparse::SparseRow<float>& query,
|
||||
int nb,
|
||||
int topk,
|
||||
const knowhere::MetricType& metric) {
|
||||
SearchRef(const knowhere::sparse::SparseRow<float>* base,
|
||||
const knowhere::sparse::SparseRow<float>& query,
|
||||
int nb,
|
||||
int topk) {
|
||||
std::vector<std::tuple<float, int>> res;
|
||||
for (int i = 0; i < nb; i++) {
|
||||
auto& row = base[i];
|
||||
|
@ -50,6 +49,31 @@ Ref(const knowhere::sparse::SparseRow<float>* base,
|
|||
return offsets;
|
||||
}
|
||||
|
||||
std::vector<int>
|
||||
RangeSearchRef(const knowhere::sparse::SparseRow<float>* base,
|
||||
const knowhere::sparse::SparseRow<float>& query,
|
||||
int nb,
|
||||
float radius,
|
||||
float range_filter,
|
||||
int topk) {
|
||||
std::vector<int> offsets;
|
||||
for (int i = 0; i < nb; i++) {
|
||||
auto& row = base[i];
|
||||
auto distance = row.dot(query);
|
||||
if (distance <= range_filter && distance > radius) {
|
||||
offsets.push_back(i);
|
||||
}
|
||||
}
|
||||
// select and sort top k on the range filter side
|
||||
std::sort(offsets.begin(), offsets.end(), [&](int a, int b) {
|
||||
return base[a].dot(query) > base[b].dot(query);
|
||||
});
|
||||
if (offsets.size() > topk) {
|
||||
offsets.resize(topk);
|
||||
}
|
||||
return offsets;
|
||||
}
|
||||
|
||||
void
|
||||
AssertMatch(const std::vector<int>& expected, const int64_t* actual) {
|
||||
for (int i = 0; i < expected.size(); i++) {
|
||||
|
@ -95,11 +119,45 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
|
|||
bitset_view,
|
||||
DataType::VECTOR_SPARSE_FLOAT);
|
||||
for (int i = 0; i < nq; i++) {
|
||||
auto ref =
|
||||
Ref(base.get(), *(query.get() + i), nb, topk, metric_type);
|
||||
auto ref = SearchRef(base.get(), *(query.get() + i), nb, topk);
|
||||
auto ans = result.get_seg_offsets() + i * topk;
|
||||
AssertMatch(ref, ans);
|
||||
}
|
||||
|
||||
search_info.search_params_[RADIUS] = 0.1;
|
||||
search_info.search_params_[RANGE_FILTER] = 0.5;
|
||||
auto result2 = BruteForceSearch(dataset,
|
||||
base.get(),
|
||||
nb,
|
||||
search_info,
|
||||
bitset_view,
|
||||
DataType::VECTOR_SPARSE_FLOAT);
|
||||
for (int i = 0; i < nq; i++) {
|
||||
auto ref = RangeSearchRef(
|
||||
base.get(), *(query.get() + i), nb, 0.1, 0.5, topk);
|
||||
auto ans = result2.get_seg_offsets() + i * topk;
|
||||
AssertMatch(ref, ans);
|
||||
}
|
||||
|
||||
auto result3 = BruteForceSearchIterators(dataset,
|
||||
base.get(),
|
||||
nb,
|
||||
search_info,
|
||||
bitset_view,
|
||||
DataType::VECTOR_SPARSE_FLOAT);
|
||||
auto iterators = result3.chunk_iterators();
|
||||
for (int i = 0; i < nq; i++) {
|
||||
auto it = iterators[i];
|
||||
auto q = *(query.get() + i);
|
||||
auto last_dis = std::numeric_limits<float>::max();
|
||||
// we should see strict decreasing distances for brute force iterator.
|
||||
while (it->HasNext()) {
|
||||
auto [offset, dis] = it->Next();
|
||||
ASSERT_LE(dis, last_dis);
|
||||
last_dis = dis;
|
||||
ASSERT_FLOAT_EQ(dis, base[offset].dot(q));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue