enhance: reduce copy of bitset and id conversion of brute-force search (#37675)

issue: https://github.com/milvus-io/milvus/issues/37798

Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
pull/36865/head
cqy123456 2024-11-19 15:48:40 +08:00 committed by GitHub
parent b6612e02b4
commit 8216345b07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 127 additions and 140 deletions

View File

@ -79,8 +79,6 @@ struct VectorIterator {
heap_.pop();
if (iterators_[top->GetIteratorIdx()]->HasNext()) {
auto origin_pair = iterators_[top->GetIteratorIdx()]->Next();
origin_pair.first = convert_to_segment_offset(
origin_pair.first, top->GetIteratorIdx());
auto off_dis_pair = std::make_shared<OffsetDisPair>(
origin_pair, top->GetIteratorIdx());
heap_.push(off_dis_pair);
@ -108,8 +106,6 @@ struct VectorIterator {
for (auto& iter : iterators_) {
if (iter->HasNext()) {
auto origin_pair = iter->Next();
origin_pair.first =
convert_to_segment_offset(origin_pair.first, idx);
auto off_dis_pair =
std::make_shared<OffsetDisPair>(origin_pair, idx++);
heap_.push(off_dis_pair);

View File

@ -71,28 +71,49 @@ PrepareBFSearchParams(const SearchInfo& search_info,
return search_cfg;
}
// Wraps the query vectors and one block of raw base vectors into knowhere
// datasets for the brute-force search entry points.
//
// query_ds:  supplies num_queries, dim and the query_data pointer.
// raw_ds:    supplies num_raw_data, dim, the raw_data pointer and begin_id.
// data_type: vector type of the field; selects the extra preparation below.
//
// Returns the pair ordered as (query dataset, base dataset).
std::pair<knowhere::DataSetPtr, knowhere::DataSetPtr>
PrepareBFDataSet(const dataset::SearchDataset& query_ds,
                 const dataset::RawDataset& raw_ds,
                 DataType data_type) {
    auto base_dataset =
        knowhere::GenDataSet(raw_ds.num_raw_data, raw_ds.dim, raw_ds.raw_data);
    auto query_dataset = knowhere::GenDataSet(
        query_ds.num_queries, query_ds.dim, query_ds.query_data);

    switch (data_type) {
        case DataType::VECTOR_SPARSE_FLOAT:
            // Both sides must be flagged as sparse.
            base_dataset->SetIsSparse(true);
            query_dataset->SetIsSparse(true);
            break;
        case DataType::VECTOR_BFLOAT16:
            // TODO: remove the conversion once knowhere supports real
            // fp16/bf16 brute force.
            base_dataset =
                knowhere::ConvertFromDataTypeIfNeeded<bfloat16>(base_dataset);
            query_dataset =
                knowhere::ConvertFromDataTypeIfNeeded<bfloat16>(query_dataset);
            break;
        case DataType::VECTOR_FLOAT16:
            // TODO: remove the conversion once knowhere supports real
            // fp16/bf16 brute force.
            base_dataset =
                knowhere::ConvertFromDataTypeIfNeeded<float16>(base_dataset);
            query_dataset =
                knowhere::ConvertFromDataTypeIfNeeded<float16>(query_dataset);
            break;
        default:
            // Other vector types need no extra preparation here.
            break;
    }

    // Set the begin id last: the conversions above reassign base_dataset, so
    // it has to be applied to whichever object is finally returned.
    // NOTE(review): presumably this makes knowhere report ids offset by
    // begin_id, which would replace per-chunk offset fix-ups in callers;
    // confirm against knowhere.
    base_dataset->SetTensorBeginId(raw_ds.begin_id);

    // Move the handles into the pair instead of copying them.
    return std::make_pair(std::move(query_dataset), std::move(base_dataset));
}
SubSearchResult
BruteForceSearch(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
BruteForceSearch(const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type) {
SubSearchResult sub_result(dataset.num_queries,
dataset.topk,
dataset.metric_type,
dataset.round_decimal);
auto nq = dataset.num_queries;
auto dim = dataset.dim;
auto topk = dataset.topk;
auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw);
auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data);
if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
}
SubSearchResult sub_result(query_ds.num_queries,
query_ds.topk,
query_ds.metric_type,
query_ds.round_decimal);
auto topk = query_ds.topk;
auto nq = query_ds.num_queries;
auto [query_dataset, base_dataset] =
PrepareBFDataSet(query_ds, raw_ds, data_type);
auto search_cfg = PrepareBFSearchParams(search_info, index_info);
// `range_search_k` is only used as one of the conditions for iterator early termination.
// It is not guaranteed to return exactly `range_search_k` results; there may be more or fewer.
@ -112,10 +133,12 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
res = knowhere::BruteForce::RangeSearch<float>(
base_dataset, query_dataset, search_cfg, bitset);
} else if (data_type == DataType::VECTOR_FLOAT16) {
res = knowhere::BruteForce::RangeSearch<float16>(
//todo: if knowhere support real fp16/bf16 bf, change it
res = knowhere::BruteForce::RangeSearch<float>(
base_dataset, query_dataset, search_cfg, bitset);
} else if (data_type == DataType::VECTOR_BFLOAT16) {
res = knowhere::BruteForce::RangeSearch<bfloat16>(
//todo: if knowhere support real fp16/bf16 bf, change it
res = knowhere::BruteForce::RangeSearch<float>(
base_dataset, query_dataset, search_cfg, bitset);
} else if (data_type == DataType::VECTOR_BINARY) {
res = knowhere::BruteForce::RangeSearch<uint8_t>(
@ -138,7 +161,7 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
res.what());
}
auto result =
ReGenRangeSearchResult(res.value(), topk, nq, dataset.metric_type);
ReGenRangeSearchResult(res.value(), topk, nq, query_ds.metric_type);
milvus::tracer::AddEvent("ReGenRangeSearchResult");
std::copy_n(
GetDatasetIDs(result), nq * topk, sub_result.get_seg_offsets());
@ -155,7 +178,8 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
search_cfg,
bitset);
} else if (data_type == DataType::VECTOR_FLOAT16) {
stat = knowhere::BruteForce::SearchWithBuf<float16>(
//todo: if knowhere support real fp16/bf16 bf, change it
stat = knowhere::BruteForce::SearchWithBuf<float>(
base_dataset,
query_dataset,
sub_result.mutable_seg_offsets().data(),
@ -163,7 +187,8 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
search_cfg,
bitset);
} else if (data_type == DataType::VECTOR_BFLOAT16) {
stat = knowhere::BruteForce::SearchWithBuf<bfloat16>(
//todo: if knowhere support real fp16/bf16 bf, change it
stat = knowhere::BruteForce::SearchWithBuf<float>(
base_dataset,
query_dataset,
sub_result.mutable_seg_offsets().data(),
@ -202,21 +227,15 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
}
SubSearchResult
BruteForceSearchIterators(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
BruteForceSearchIterators(const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type) {
auto nq = dataset.num_queries;
auto dim = dataset.dim;
auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw);
auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data);
if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
}
auto nq = query_ds.num_queries;
auto [query_dataset, base_dataset] =
PrepareBFDataSet(query_ds, raw_ds, data_type);
auto search_cfg = PrepareBFSearchParams(search_info, index_info);
knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
@ -227,11 +246,13 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
base_dataset, query_dataset, search_cfg, bitset);
break;
case DataType::VECTOR_FLOAT16:
iterators_val = knowhere::BruteForce::AnnIterator<float16>(
//todo: if knowhere support real fp16/bf16 bf, change it
iterators_val = knowhere::BruteForce::AnnIterator<float>(
base_dataset, query_dataset, search_cfg, bitset);
break;
case DataType::VECTOR_BFLOAT16:
iterators_val = knowhere::BruteForce::AnnIterator<bfloat16>(
//todo: if knowhere support real fp16/bf16 bf, change it
iterators_val = knowhere::BruteForce::AnnIterator<float>(
base_dataset, query_dataset, search_cfg, bitset);
break;
case DataType::VECTOR_SPARSE_FLOAT:
@ -251,10 +272,10 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
"equal to nq:{} for single chunk",
iterators_val.value().size(),
nq);
SubSearchResult subSearchResult(dataset.num_queries,
dataset.topk,
dataset.metric_type,
dataset.round_decimal,
SubSearchResult subSearchResult(query_ds.num_queries,
query_ds.topk,
query_ds.metric_type,
query_ds.round_decimal,
iterators_val.value());
return std::move(subSearchResult);
} else {

View File

@ -24,18 +24,16 @@ CheckBruteForceSearchParam(const FieldMeta& field,
const SearchInfo& search_info);
SubSearchResult
BruteForceSearch(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
BruteForceSearch(const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type);
SubSearchResult
BruteForceSearchIterators(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
BruteForceSearchIterators(const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,

View File

@ -136,31 +136,23 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
std::min(active_count, (chunk_id + 1) * vec_size_per_chunk);
auto size_per_chunk = element_end - element_begin;
auto sub_view = bitset.subview(element_begin, size_per_chunk);
auto sub_data = query::dataset::RawDataset{
element_begin, dim, size_per_chunk, chunk_data};
if (info.group_by_field_id_.has_value()) {
auto sub_qr = BruteForceSearchIterators(search_dataset,
chunk_data,
size_per_chunk,
sub_data,
info,
index_info,
sub_view,
bitset,
data_type);
final_qr.merge(sub_qr);
} else {
auto sub_qr = BruteForceSearch(search_dataset,
chunk_data,
size_per_chunk,
sub_data,
info,
index_info,
sub_view,
bitset,
data_type);
// convert chunk uid to segment uid
for (auto& x : sub_qr.mutable_seg_offsets()) {
if (x != -1) {
x += chunk_id * vec_size_per_chunk;
}
}
final_qr.merge(sub_qr);
}
}

View File

@ -95,12 +95,12 @@ SearchOnSealed(const Schema& schema,
? 0
: field.get_dim();
query::dataset::SearchDataset dataset{search_info.metric_type_,
num_queries,
search_info.topk_,
search_info.round_decimal_,
dim,
query_data};
query::dataset::SearchDataset query_dataset{search_info.metric_type_,
num_queries,
search_info.topk_,
search_info.round_decimal_,
dim,
query_data};
auto data_type = field.get_data_type();
CheckBruteForceSearchParam(field, search_info);
@ -116,51 +116,27 @@ SearchOnSealed(const Schema& schema,
auto vec_data = column->Data(i);
auto chunk_size = column->chunk_row_nums(i);
const uint8_t* bitset_ptr = nullptr;
bool aligned = false;
if ((offset & 0x7) == 0) {
bitset_ptr = bitview.data() + (offset >> 3);
aligned = true;
} else {
char* bitset_data = new char[(chunk_size + 7) / 8];
std::fill(bitset_data, bitset_data + sizeof(bitset_data), 0);
bitset::detail::ElementWiseBitsetPolicy<char>::op_copy(
reinterpret_cast<const char*>(bitview.data()),
offset,
bitset_data,
0,
chunk_size);
bitset_ptr = reinterpret_cast<const uint8_t*>(bitset_data);
}
BitsetView bitset_view(bitset_ptr, chunk_size);
auto data_id = offset;
auto raw_dataset =
query::dataset::RawDataset{offset, dim, chunk_size, vec_data};
if (search_info.group_by_field_id_.has_value()) {
auto sub_qr = BruteForceSearchIterators(dataset,
vec_data,
chunk_size,
auto sub_qr = BruteForceSearchIterators(query_dataset,
raw_dataset,
search_info,
index_info,
bitset_view,
bitview,
data_type);
final_qr.merge(sub_qr);
} else {
auto sub_qr = BruteForceSearch(dataset,
vec_data,
chunk_size,
auto sub_qr = BruteForceSearch(query_dataset,
raw_dataset,
search_info,
index_info,
bitset_view,
bitview,
data_type);
for (auto& o : sub_qr.mutable_seg_offsets()) {
if (o != -1) {
o += offset;
}
}
final_qr.merge(sub_qr);
}
if (!aligned) {
delete[] bitset_ptr;
}
offset += chunk_size;
}
if (search_info.group_by_field_id_.has_value()) {
@ -172,8 +148,8 @@ SearchOnSealed(const Schema& schema,
result.distances_ = std::move(final_qr.mutable_distances());
result.seg_offsets_ = std::move(final_qr.mutable_seg_offsets());
}
result.unity_topK_ = dataset.topk;
result.total_nq_ = dataset.num_queries;
result.unity_topK_ = query_dataset.topk;
result.total_nq_ = query_dataset.num_queries;
}
void
@ -194,19 +170,19 @@ SearchOnSealed(const Schema& schema,
? 0
: field.get_dim();
query::dataset::SearchDataset dataset{search_info.metric_type_,
num_queries,
search_info.topk_,
search_info.round_decimal_,
dim,
query_data};
query::dataset::SearchDataset query_dataset{search_info.metric_type_,
num_queries,
search_info.topk_,
search_info.round_decimal_,
dim,
query_data};
auto data_type = field.get_data_type();
CheckBruteForceSearchParam(field, search_info);
auto raw_dataset = query::dataset::RawDataset{0, dim, row_count, vec_data};
if (search_info.group_by_field_id_.has_value()) {
auto sub_qr = BruteForceSearchIterators(dataset,
vec_data,
row_count,
auto sub_qr = BruteForceSearchIterators(query_dataset,
raw_dataset,
search_info,
index_info,
bitset,
@ -214,9 +190,8 @@ SearchOnSealed(const Schema& schema,
result.AssembleChunkVectorIterators(
num_queries, 1, {0}, sub_qr.chunk_iterators());
} else {
auto sub_qr = BruteForceSearch(dataset,
vec_data,
row_count,
auto sub_qr = BruteForceSearch(query_dataset,
raw_dataset,
search_info,
index_info,
bitset,
@ -224,8 +199,8 @@ SearchOnSealed(const Schema& schema,
result.distances_ = std::move(sub_qr.mutable_distances());
result.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());
}
result.unity_topK_ = dataset.topk;
result.total_nq_ = dataset.num_queries;
result.unity_topK_ = query_dataset.topk;
result.total_nq_ = query_dataset.num_queries;
}
} // namespace milvus::query

View File

@ -19,7 +19,12 @@
namespace milvus::query {
namespace dataset {
// Describes one block of raw base vectors handed to brute-force search.
// Field order is part of the interface: callers build this with positional
// brace initialization ({begin_id, dim, num_raw_data, raw_data}).
// Every member has a default so that a default-constructed instance is in a
// well-defined empty state instead of holding indeterminate values.
struct RawDataset {
    // Id assigned to the first vector in raw_data.
    // NOTE(review): presumably the segment offset of the block; confirm.
    int64_t begin_id = 0;
    // Dimension passed to the dataset builder.
    int64_t dim = 0;
    // Number of vectors available at raw_data.
    int64_t num_raw_data = 0;
    // Non-owning pointer to the vector data.
    const void* raw_data = nullptr;
};
struct SearchDataset {
knowhere::MetricType metric_type;
int64_t num_queries;

View File

@ -124,7 +124,7 @@ class TestFloatSearchBruteForce : public ::testing::Test {
auto query = GenFloatVecs(dim, nq, metric_type);
auto index_info = std::map<std::string, std::string>{};
dataset::SearchDataset dataset{
dataset::SearchDataset query_dataset{
metric_type, nq, topk, -1, dim, query.data()};
if (!is_supported_float_metric(metric_type)) {
// Memory leak in knowhere.
@ -134,9 +134,10 @@ class TestFloatSearchBruteForce : public ::testing::Test {
SearchInfo search_info;
search_info.topk_ = topk;
search_info.metric_type_ = metric_type;
auto result = BruteForceSearch(dataset,
base.data(),
nb,
auto raw_dataset = query::dataset::RawDataset{0, dim, nb, base.data()};
auto result = BruteForceSearch(query_dataset,
raw_dataset,
search_info,
index_info,
bitset_view,

View File

@ -103,21 +103,21 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
SearchInfo search_info;
search_info.topk_ = topk;
search_info.metric_type_ = metric_type;
dataset::SearchDataset dataset{
dataset::SearchDataset query_dataset{
metric_type, nq, topk, -1, kTestSparseDim, query.get()};
auto raw_dataset =
query::dataset::RawDataset{0, kTestSparseDim, nb, base.get()};
if (!is_supported_sparse_float_metric(metric_type)) {
ASSERT_ANY_THROW(BruteForceSearch(dataset,
base.get(),
nb,
ASSERT_ANY_THROW(BruteForceSearch(query_dataset,
raw_dataset,
search_info,
index_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT));
return;
}
auto result = BruteForceSearch(dataset,
base.get(),
nb,
auto result = BruteForceSearch(query_dataset,
raw_dataset,
search_info,
index_info,
bitset_view,
@ -130,9 +130,8 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
search_info.search_params_[RADIUS] = 0.1;
search_info.search_params_[RANGE_FILTER] = 0.5;
auto result2 = BruteForceSearch(dataset,
base.get(),
nb,
auto result2 = BruteForceSearch(query_dataset,
raw_dataset,
search_info,
index_info,
bitset_view,
@ -144,9 +143,8 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
AssertMatch(ref, ans);
}
auto result3 = BruteForceSearchIterators(dataset,
base.get(),
nb,
auto result3 = BruteForceSearchIterators(query_dataset,
raw_dataset,
search_info,
index_info,
bitset_view,

View File

@ -177,9 +177,10 @@ TEST(Indexing, BinaryBruteForce) {
search_info.topk_ = topk;
search_info.round_decimal_ = round_decimal;
search_info.metric_type_ = metric_type;
auto base_dataset = query::dataset::RawDataset{
int64_t(0), dim, N, (const void*)bin_vec.data()};
auto sub_result = query::BruteForceSearch(search_dataset,
bin_vec.data(),
N,
base_dataset,
search_info,
index_info,
nullptr,

View File

@ -1254,9 +1254,9 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
search_info.topk_ = topk;
search_info.round_decimal_ = round_decimal;
search_info.metric_type_ = metric_type;
auto raw_dataset = query::dataset::RawDataset{0, dim, N, vec_col.data()};
auto sub_result = BruteForceSearch(search_dataset,
vec_col.data(),
N,
raw_dataset,
search_info,
index_info,
nullptr,