mirror of https://github.com/milvus-io/milvus.git
enhance: ban range-search iteration for search-group-by (#30824)
related: #30033 Signed-off-by: MrPresent-Han <chun.han@zilliz.com>pull/31052/head
parent
c8efed6562
commit
3574bdf858
|
@ -203,7 +203,7 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
|
|||
search_info.search_params_ =
|
||||
nlohmann::json::parse(query_info_proto.search_params());
|
||||
|
||||
if (query_info_proto.group_by_field_id() != 0) {
|
||||
if (query_info_proto.group_by_field_id() > 0) {
|
||||
auto group_by_field_id = FieldId(query_info_proto.group_by_field_id());
|
||||
search_info.group_by_field_id_ = group_by_field_id;
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@ import (
|
|||
const (
|
||||
IgnoreGrowingKey = "ignore_growing"
|
||||
ReduceStopForBestKey = "reduce_stop_for_best"
|
||||
IteratorField = "iterator"
|
||||
GroupByFieldKey = "group_by_field"
|
||||
AnnsFieldKey = "anns_field"
|
||||
TopKKey = "topk"
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
@ -192,9 +193,8 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
if err != nil {
|
||||
groupByFieldName = ""
|
||||
}
|
||||
var groupByFieldId int64
|
||||
var groupByFieldId int64 = -1
|
||||
if groupByFieldName != "" {
|
||||
groupByFieldId = -1
|
||||
fields := schema.GetFields()
|
||||
for _, field := range fields {
|
||||
if field.Name == groupByFieldName {
|
||||
|
@ -207,6 +207,17 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
}
|
||||
}
|
||||
|
||||
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
|
||||
isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
|
||||
if isIterator == "True" && groupByFieldId > 0 {
|
||||
return nil, 0, merr.WrapErrParameterInvalid("", "",
|
||||
"Not allowed to do groupBy when doing iteration")
|
||||
}
|
||||
if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 {
|
||||
return nil, 0, merr.WrapErrParameterInvalid("", "",
|
||||
"Not allowed to do range-search when doing search-group-by")
|
||||
}
|
||||
|
||||
return &planpb.QueryInfo{
|
||||
Topk: queryTopK,
|
||||
MetricType: metricType,
|
||||
|
|
|
@ -192,6 +192,14 @@ func getValidSearchParams() []*commonpb.KeyValuePair {
|
|||
}
|
||||
}
|
||||
|
||||
func resetSearchParamsValue(kvs []*commonpb.KeyValuePair, keyName string, newVal string) {
|
||||
for _, kv := range kvs {
|
||||
if kv.GetKey() == keyName {
|
||||
kv.Value = newVal
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getInvalidSearchParams(invalidName string) []*commonpb.KeyValuePair {
|
||||
kvs := getValidSearchParams()
|
||||
for _, kv := range kvs {
|
||||
|
@ -2173,6 +2181,47 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
|||
})
|
||||
}
|
||||
})
|
||||
t.Run("check iterator and groupBy", func(t *testing.T) {
|
||||
normalParam := getValidSearchParams()
|
||||
normalParam = append(normalParam, &commonpb.KeyValuePair{
|
||||
Key: IteratorField,
|
||||
Value: "True",
|
||||
})
|
||||
normalParam = append(normalParam, &commonpb.KeyValuePair{
|
||||
Key: GroupByFieldKey,
|
||||
Value: "string_field",
|
||||
})
|
||||
fields := make([]*schemapb.FieldSchema, 0)
|
||||
fields = append(fields, &schemapb.FieldSchema{
|
||||
FieldID: int64(101),
|
||||
Name: "string_field",
|
||||
})
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: fields,
|
||||
}
|
||||
info, _, err := parseSearchInfo(normalParam, schema)
|
||||
assert.Nil(t, info)
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
})
|
||||
t.Run("check range-search and groupBy", func(t *testing.T) {
|
||||
normalParam := getValidSearchParams()
|
||||
resetSearchParamsValue(normalParam, SearchParamsKey, `{"nprobe": 10, "radius":0.2}`)
|
||||
normalParam = append(normalParam, &commonpb.KeyValuePair{
|
||||
Key: GroupByFieldKey,
|
||||
Value: "string_field",
|
||||
})
|
||||
fields := make([]*schemapb.FieldSchema, 0)
|
||||
fields = append(fields, &schemapb.FieldSchema{
|
||||
FieldID: int64(101),
|
||||
Name: "string_field",
|
||||
})
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: fields,
|
||||
}
|
||||
info, _, err := parseSearchInfo(normalParam, schema)
|
||||
assert.Nil(t, info)
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
})
|
||||
}
|
||||
|
||||
func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
|
||||
|
|
Loading…
Reference in New Issue