mirror of https://github.com/milvus-io/milvus.git
enhance: make search groupby stop when reaching topk groups (#35814)
related: #33544
Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
pull/35891/head^2
parent
57422cb2ed
commit
4641fd9195
|
@ -27,6 +27,7 @@ namespace milvus {
|
|||
struct SearchInfo {
|
||||
int64_t topk_{0};
|
||||
int64_t group_size_{1};
|
||||
bool group_strict_size_{false};
|
||||
int64_t round_decimal_{0};
|
||||
FieldId field_id_;
|
||||
MetricType metric_type_;
|
||||
|
|
|
@ -212,6 +212,7 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
|
|||
search_info.group_size_ = query_info_proto.group_size() > 0
|
||||
? query_info_proto.group_size()
|
||||
: 1;
|
||||
search_info.group_strict_size_ = query_info_proto.group_strict_size();
|
||||
}
|
||||
|
||||
auto plan_node = [&]() -> std::unique_ptr<VectorPlanNode> {
|
||||
|
|
|
@ -44,6 +44,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||
GroupIteratorsByType<int8_t>(iterators,
|
||||
search_info.topk_,
|
||||
search_info.group_size_,
|
||||
search_info.group_strict_size_,
|
||||
*dataGetter,
|
||||
group_by_values,
|
||||
seg_offsets,
|
||||
|
@ -58,6 +59,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||
GroupIteratorsByType<int16_t>(iterators,
|
||||
search_info.topk_,
|
||||
search_info.group_size_,
|
||||
search_info.group_strict_size_,
|
||||
*dataGetter,
|
||||
group_by_values,
|
||||
seg_offsets,
|
||||
|
@ -72,6 +74,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||
GroupIteratorsByType<int32_t>(iterators,
|
||||
search_info.topk_,
|
||||
search_info.group_size_,
|
||||
search_info.group_strict_size_,
|
||||
*dataGetter,
|
||||
group_by_values,
|
||||
seg_offsets,
|
||||
|
@ -86,6 +89,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||
GroupIteratorsByType<int64_t>(iterators,
|
||||
search_info.topk_,
|
||||
search_info.group_size_,
|
||||
search_info.group_strict_size_,
|
||||
*dataGetter,
|
||||
group_by_values,
|
||||
seg_offsets,
|
||||
|
@ -99,6 +103,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||
GroupIteratorsByType<bool>(iterators,
|
||||
search_info.topk_,
|
||||
search_info.group_size_,
|
||||
search_info.group_strict_size_,
|
||||
*dataGetter,
|
||||
group_by_values,
|
||||
seg_offsets,
|
||||
|
@ -113,6 +118,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||
GroupIteratorsByType<std::string>(iterators,
|
||||
search_info.topk_,
|
||||
search_info.group_size_,
|
||||
search_info.group_strict_size_,
|
||||
*dataGetter,
|
||||
group_by_values,
|
||||
seg_offsets,
|
||||
|
@ -136,6 +142,7 @@ GroupIteratorsByType(
|
|||
const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
||||
int64_t topK,
|
||||
int64_t group_size,
|
||||
bool group_strict_size,
|
||||
const DataGetter<T>& data_getter,
|
||||
std::vector<GroupByValueType>& group_by_values,
|
||||
std::vector<int64_t>& seg_offsets,
|
||||
|
@ -147,6 +154,7 @@ GroupIteratorsByType(
|
|||
GroupIteratorResult<T>(iterator,
|
||||
topK,
|
||||
group_size,
|
||||
group_strict_size,
|
||||
data_getter,
|
||||
group_by_values,
|
||||
seg_offsets,
|
||||
|
@ -161,13 +169,14 @@ void
|
|||
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
|
||||
int64_t topK,
|
||||
int64_t group_size,
|
||||
bool group_strict_size,
|
||||
const DataGetter<T>& data_getter,
|
||||
std::vector<GroupByValueType>& group_by_values,
|
||||
std::vector<int64_t>& offsets,
|
||||
std::vector<float>& distances,
|
||||
const knowhere::MetricType& metrics_type) {
|
||||
//1.
|
||||
GroupByMap<T> groupMap(topK, group_size);
|
||||
GroupByMap<T> groupMap(topK, group_size, group_strict_size);
|
||||
|
||||
//2. do iteration until fill the whole map or run out of all data
|
||||
//note it may enumerate all data inside a segment and can block following
|
||||
|
@ -195,8 +204,8 @@ GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
|
|||
|
||||
//4. save groupBy results
|
||||
for (auto iter = res.cbegin(); iter != res.cend(); iter++) {
|
||||
offsets.push_back(std::get<0>(*iter));
|
||||
distances.push_back(std::get<1>(*iter));
|
||||
offsets.emplace_back(std::get<0>(*iter));
|
||||
distances.emplace_back(std::get<1>(*iter));
|
||||
group_by_values.emplace_back(std::move(std::get<2>(*iter)));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -182,6 +182,7 @@ GroupIteratorsByType(
|
|||
const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
||||
int64_t topK,
|
||||
int64_t group_size,
|
||||
bool group_strict_size,
|
||||
const DataGetter<T>& data_getter,
|
||||
std::vector<GroupByValueType>& group_by_values,
|
||||
std::vector<int64_t>& seg_offsets,
|
||||
|
@ -195,19 +196,31 @@ struct GroupByMap {
|
|||
std::unordered_map<T, int> group_map_{};
|
||||
int group_capacity_{0};
|
||||
int group_size_{0};
|
||||
int enough_group_count{0};
|
||||
int enough_group_count_{0};
|
||||
bool strict_group_size_{false};
|
||||
|
||||
public:
|
||||
GroupByMap(int group_capacity, int group_size)
|
||||
: group_capacity_(group_capacity), group_size_(group_size){};
|
||||
GroupByMap(int group_capacity,
|
||||
int group_size,
|
||||
bool strict_group_size = false)
|
||||
: group_capacity_(group_capacity),
|
||||
group_size_(group_size),
|
||||
strict_group_size_(strict_group_size){};
|
||||
bool
|
||||
IsGroupResEnough() {
|
||||
return group_map_.size() == group_capacity_ &&
|
||||
enough_group_count == group_capacity_;
|
||||
bool enough = false;
|
||||
if (strict_group_size_) {
|
||||
enough = group_map_.size() == group_capacity_ &&
|
||||
enough_group_count_ == group_capacity_;
|
||||
} else {
|
||||
enough = group_map_.size() == group_capacity_;
|
||||
}
|
||||
return enough;
|
||||
}
|
||||
bool
|
||||
Push(const T& t) {
|
||||
if (group_map_.size() >= group_capacity_ && group_map_[t] == 0) {
|
||||
if (group_map_.size() >= group_capacity_ &&
|
||||
group_map_.find(t) == group_map_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (group_map_[t] >= group_size_) {
|
||||
|
@ -218,7 +231,7 @@ struct GroupByMap {
|
|||
}
|
||||
group_map_[t] += 1;
|
||||
if (group_map_[t] >= group_size_) {
|
||||
enough_group_count += 1;
|
||||
enough_group_count_ += 1;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -229,6 +242,7 @@ void
|
|||
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
|
||||
int64_t topK,
|
||||
int64_t group_size,
|
||||
bool group_strict_size,
|
||||
const DataGetter<T>& data_getter,
|
||||
std::vector<GroupByValueType>& group_by_values,
|
||||
std::vector<int64_t>& offsets,
|
||||
|
|
|
@ -474,6 +474,7 @@ TEST(GroupBY, SealedData) {
|
|||
search_params: "{\"ef\": 10}"
|
||||
group_by_field_id: 101,
|
||||
group_size: 5,
|
||||
group_strict_size: true,
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
|
||||
|
@ -796,6 +797,7 @@ TEST(GroupBY, GrowingIndex) {
|
|||
search_params: "{\"ef\": 10}"
|
||||
group_by_field_id: 101
|
||||
group_size: 3
|
||||
group_strict_size: true
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
|
||||
|
|
|
@ -62,6 +62,7 @@ message QueryInfo {
|
|||
int64 group_by_field_id = 6;
|
||||
bool materialized_view_involved = 7;
|
||||
int64 group_size = 8;
|
||||
bool group_strict_size = 9;
|
||||
}
|
||||
|
||||
message ColumnInfo {
|
||||
|
|
|
@ -129,6 +129,17 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
}
|
||||
}
|
||||
|
||||
var groupStrictSize bool
|
||||
groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair)
|
||||
if err != nil {
|
||||
groupStrictSize = false
|
||||
} else {
|
||||
groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr)
|
||||
if err != nil {
|
||||
groupStrictSize = false
|
||||
}
|
||||
}
|
||||
|
||||
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
|
||||
if isIterator == "True" && groupByFieldId > 0 {
|
||||
return nil, 0, merr.WrapErrParameterInvalid("", "",
|
||||
|
@ -140,12 +151,13 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
}
|
||||
|
||||
return &planpb.QueryInfo{
|
||||
Topk: queryTopK,
|
||||
MetricType: metricType,
|
||||
SearchParams: searchParamStr,
|
||||
RoundDecimal: roundDecimal,
|
||||
GroupByFieldId: groupByFieldId,
|
||||
GroupSize: groupSize,
|
||||
Topk: queryTopK,
|
||||
MetricType: metricType,
|
||||
SearchParams: searchParamStr,
|
||||
RoundDecimal: roundDecimal,
|
||||
GroupByFieldId: groupByFieldId,
|
||||
GroupSize: groupSize,
|
||||
GroupStrictSize: groupStrictSize,
|
||||
}, offset, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -48,6 +48,7 @@ const (
|
|||
IteratorField = "iterator"
|
||||
GroupByFieldKey = "group_by_field"
|
||||
GroupSizeKey = "group_size"
|
||||
GroupStrictSize = "group_strict_size"
|
||||
AnnsFieldKey = "anns_field"
|
||||
TopKKey = "topk"
|
||||
NQKey = "nq"
|
||||
|
|
Loading…
Reference in New Issue