enhance: make search groupby stop when reaching topk groups (#35814)

related: #33544

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
pull/35891/head^2
Chun Han 2024-09-02 18:25:03 +08:00 committed by GitHub
parent 57422cb2ed
commit 4641fd9195
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 57 additions and 16 deletions

View File

@ -27,6 +27,7 @@ namespace milvus {
struct SearchInfo {
int64_t topk_{0};
int64_t group_size_{1};
bool group_strict_size_{false};
int64_t round_decimal_{0};
FieldId field_id_;
MetricType metric_type_;

View File

@ -212,6 +212,7 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
search_info.group_size_ = query_info_proto.group_size() > 0
? query_info_proto.group_size()
: 1;
search_info.group_strict_size_ = query_info_proto.group_strict_size();
}
auto plan_node = [&]() -> std::unique_ptr<VectorPlanNode> {

View File

@ -44,6 +44,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<int8_t>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
*dataGetter,
group_by_values,
seg_offsets,
@ -58,6 +59,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<int16_t>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
*dataGetter,
group_by_values,
seg_offsets,
@ -72,6 +74,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<int32_t>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
*dataGetter,
group_by_values,
seg_offsets,
@ -86,6 +89,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<int64_t>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
*dataGetter,
group_by_values,
seg_offsets,
@ -99,6 +103,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<bool>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
*dataGetter,
group_by_values,
seg_offsets,
@ -113,6 +118,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<std::string>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
*dataGetter,
group_by_values,
seg_offsets,
@ -136,6 +142,7 @@ GroupIteratorsByType(
const std::vector<std::shared_ptr<VectorIterator>>& iterators,
int64_t topK,
int64_t group_size,
bool group_strict_size,
const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& seg_offsets,
@ -147,6 +154,7 @@ GroupIteratorsByType(
GroupIteratorResult<T>(iterator,
topK,
group_size,
group_strict_size,
data_getter,
group_by_values,
seg_offsets,
@ -161,13 +169,14 @@ void
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
int64_t topK,
int64_t group_size,
bool group_strict_size,
const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& offsets,
std::vector<float>& distances,
const knowhere::MetricType& metrics_type) {
//1.
GroupByMap<T> groupMap(topK, group_size);
GroupByMap<T> groupMap(topK, group_size, group_strict_size);
//2. do iteration until fill the whole map or run out of all data
//note it may enumerate all data inside a segment and can block following
@ -195,8 +204,8 @@ GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
//4. save groupBy results
for (auto iter = res.cbegin(); iter != res.cend(); iter++) {
offsets.push_back(std::get<0>(*iter));
distances.push_back(std::get<1>(*iter));
offsets.emplace_back(std::get<0>(*iter));
distances.emplace_back(std::get<1>(*iter));
group_by_values.emplace_back(std::move(std::get<2>(*iter)));
}
}

View File

@ -182,6 +182,7 @@ GroupIteratorsByType(
const std::vector<std::shared_ptr<VectorIterator>>& iterators,
int64_t topK,
int64_t group_size,
bool group_strict_size,
const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& seg_offsets,
@ -195,19 +196,31 @@ struct GroupByMap {
std::unordered_map<T, int> group_map_{};
int group_capacity_{0};
int group_size_{0};
int enough_group_count{0};
int enough_group_count_{0};
bool strict_group_size_{false};
public:
GroupByMap(int group_capacity, int group_size)
: group_capacity_(group_capacity), group_size_(group_size){};
GroupByMap(int group_capacity,
int group_size,
bool strict_group_size = false)
: group_capacity_(group_capacity),
group_size_(group_size),
strict_group_size_(strict_group_size){};
bool
IsGroupResEnough() {
return group_map_.size() == group_capacity_ &&
enough_group_count == group_capacity_;
bool enough = false;
if (strict_group_size_) {
enough = group_map_.size() == group_capacity_ &&
enough_group_count_ == group_capacity_;
} else {
enough = group_map_.size() == group_capacity_;
}
return enough;
}
bool
Push(const T& t) {
if (group_map_.size() >= group_capacity_ && group_map_[t] == 0) {
if (group_map_.size() >= group_capacity_ &&
group_map_.find(t) == group_map_.end()) {
return false;
}
if (group_map_[t] >= group_size_) {
@ -218,7 +231,7 @@ struct GroupByMap {
}
group_map_[t] += 1;
if (group_map_[t] >= group_size_) {
enough_group_count += 1;
enough_group_count_ += 1;
}
return true;
}
@ -229,6 +242,7 @@ void
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
int64_t topK,
int64_t group_size,
bool group_strict_size,
const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& offsets,

View File

@ -474,6 +474,7 @@ TEST(GroupBY, SealedData) {
search_params: "{\"ef\": 10}"
group_by_field_id: 101,
group_size: 5,
group_strict_size: true,
>
placeholder_tag: "$0"
@ -796,6 +797,7 @@ TEST(GroupBY, GrowingIndex) {
search_params: "{\"ef\": 10}"
group_by_field_id: 101
group_size: 3
group_strict_size: true
>
placeholder_tag: "$0"

View File

@ -62,6 +62,7 @@ message QueryInfo {
int64 group_by_field_id = 6;
bool materialized_view_involved = 7;
int64 group_size = 8;
bool group_strict_size = 9;
}
message ColumnInfo {

View File

@ -129,6 +129,17 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
}
}
var groupStrictSize bool
groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair)
if err != nil {
groupStrictSize = false
} else {
groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr)
if err != nil {
groupStrictSize = false
}
}
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
if isIterator == "True" && groupByFieldId > 0 {
return nil, 0, merr.WrapErrParameterInvalid("", "",
@ -140,12 +151,13 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
}
return &planpb.QueryInfo{
Topk: queryTopK,
MetricType: metricType,
SearchParams: searchParamStr,
RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId,
GroupSize: groupSize,
Topk: queryTopK,
MetricType: metricType,
SearchParams: searchParamStr,
RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId,
GroupSize: groupSize,
GroupStrictSize: groupStrictSize,
}, offset, nil
}

View File

@ -48,6 +48,7 @@ const (
IteratorField = "iterator"
GroupByFieldKey = "group_by_field"
GroupSizeKey = "group_size"
GroupStrictSize = "group_strict_size"
AnnsFieldKey = "anns_field"
TopKKey = "topk"
NQKey = "nq"