diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index d782ec1609..84fb6e389f 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -203,7 +203,7 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { search_info.search_params_ = nlohmann::json::parse(query_info_proto.search_params()); - if (query_info_proto.group_by_field_id() != 0) { + if (query_info_proto.group_by_field_id() > 0) { auto group_by_field_id = FieldId(query_info_proto.group_by_field_id()); search_info.group_by_field_id_ = group_by_field_id; } diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 850f271234..ee6feec8e5 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -44,6 +44,7 @@ import ( const ( IgnoreGrowingKey = "ignore_growing" ReduceStopForBestKey = "reduce_stop_for_best" + IteratorField = "iterator" GroupByFieldKey = "group_by_field" AnnsFieldKey = "anns_field" TopKKey = "topk" diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 909e87d1e3..e3a52f0d5b 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -7,6 +7,7 @@ import ( "math" "regexp" "strconv" + "strings" "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" @@ -192,9 +193,8 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb if err != nil { groupByFieldName = "" } - var groupByFieldId int64 + var groupByFieldId int64 = -1 if groupByFieldName != "" { - groupByFieldId = -1 fields := schema.GetFields() for _, field := range fields { if field.Name == groupByFieldName { @@ -207,6 +207,17 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb } } + // 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search + isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair) + if isIterator == "True" && groupByFieldId > 0 { + return nil, 0, merr.WrapErrParameterInvalid("", "", + "Not allowed to do groupBy when doing iteration") + } + if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 { + return nil, 0, merr.WrapErrParameterInvalid("", "", + "Not allowed to do range-search when doing search-group-by") + } + return &planpb.QueryInfo{ Topk: queryTopK, MetricType: metricType, diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 8bd6dd19b8..73e71cb695 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -192,6 +192,14 @@ func getValidSearchParams() []*commonpb.KeyValuePair { } } +func resetSearchParamsValue(kvs []*commonpb.KeyValuePair, keyName string, newVal string) { + for _, kv := range kvs { + if kv.GetKey() == keyName { + kv.Value = newVal + } + } +} + func getInvalidSearchParams(invalidName string) []*commonpb.KeyValuePair { kvs := getValidSearchParams() for _, kv := range kvs { @@ -2173,6 +2181,47 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) { }) } }) + t.Run("check iterator and groupBy", func(t *testing.T) { + normalParam := getValidSearchParams() + normalParam = append(normalParam, &commonpb.KeyValuePair{ + Key: IteratorField, + Value: "True", + }) + normalParam = append(normalParam, &commonpb.KeyValuePair{ + Key: GroupByFieldKey, + Value: "string_field", + }) + fields := make([]*schemapb.FieldSchema, 0) + fields = append(fields, &schemapb.FieldSchema{ + FieldID: int64(101), + Name: "string_field", + }) + schema := &schemapb.CollectionSchema{ + Fields: fields, + } + info, _, err := parseSearchInfo(normalParam, schema) + assert.Nil(t, info) + assert.ErrorIs(t, err, merr.ErrParameterInvalid) + }) + t.Run("check range-search and groupBy", func(t *testing.T) { + normalParam := getValidSearchParams() + resetSearchParamsValue(normalParam, SearchParamsKey, `{"nprobe": 10, "radius":0.2}`) + normalParam = append(normalParam, &commonpb.KeyValuePair{ + Key: GroupByFieldKey, + Value: "string_field", + }) + fields := make([]*schemapb.FieldSchema, 0) + fields = append(fields, &schemapb.FieldSchema{ + FieldID: int64(101), + Name: "string_field", + }) + schema := &schemapb.CollectionSchema{ + Fields: fields, + } + info, _, err := parseSearchInfo(normalParam, schema) + assert.Nil(t, info) + assert.ErrorIs(t, err, merr.ErrParameterInvalid) + }) } func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {