diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index 55fee86693..d499872b19 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -21,9 +21,12 @@ import ( ) type rankParams struct { - limit int64 - offset int64 - roundDecimal int64 + limit int64 + offset int64 + roundDecimal int64 + groupByFieldId int64 + groupSize int64 + groupStrictSize bool } func (r *rankParams) GetLimit() int64 { @@ -47,6 +50,27 @@ func (r *rankParams) GetRoundDecimal() int64 { return 0 } +func (r *rankParams) GetGroupByFieldId() int64 { + if r != nil { + return r.groupByFieldId + } + return -1 +} + +func (r *rankParams) GetGroupSize() int64 { + if r != nil { + return r.groupSize + } + return 1 +} + +func (r *rankParams) GetGroupStrictSize() bool { + if r != nil { + return r.groupStrictSize + } + return false +} + func (r *rankParams) String() string { return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", r.GetLimit(), r.GetOffset(), r.GetRoundDecimal()) } @@ -137,51 +161,16 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb } // 5. parse group by field and group by size - groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair) - if err != nil { - groupByFieldName = "" - } - var groupByFieldId int64 = -1 - if groupByFieldName != "" { - fields := schema.GetFields() - for _, field := range fields { - if field.GetNullable() { - return nil, 0, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("groupBy field(%s) not support nullable == true", groupByFieldName)) - } - if field.Name == groupByFieldName { - groupByFieldId = field.FieldID - break - } - } - if groupByFieldId == -1 { - return nil, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema") - } - } - - var groupSize int64 - groupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupSizeKey, searchParamsPair) - if err != nil { - groupSize = 1 - } else { - groupSize, err = strconv.ParseInt(groupSizeStr, 0, 64) - if err != nil || groupSize <= 0 { - groupSize = 1 - } - } - if groupSize > Params.QuotaConfig.MaxGroupSize.GetAsInt64() { - return nil, 0, merr.WrapErrParameterInvalidMsg( - fmt.Sprintf("input group size:%d exceeds configured max group size:%d", groupSize, Params.QuotaConfig.MaxGroupSize.GetAsInt64())) - } - + var groupByFieldId, groupSize int64 var groupStrictSize bool - groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair) - if err != nil { - groupStrictSize = false + if isAdvanced { + groupByFieldId, groupSize, groupStrictSize = rankParams.GetGroupByFieldId(), rankParams.GetGroupSize(), rankParams.GetGroupStrictSize() } else { - groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr) - if err != nil { - groupStrictSize = false + groupByInfo := parseGroupByInfo(searchParamsPair, schema) + if groupByInfo.err != nil { + return nil, 0, groupByInfo.err } + groupByFieldId, groupSize, groupStrictSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetGroupStrictSize() } // 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search @@ -298,8 +287,104 @@ func getPartitionIDs(ctx context.Context, dbName string, collectionName string, return partitionsSet.Collect(), nil } +type groupByInfo struct { + groupByFieldId int64 + groupSize int64 + groupStrictSize bool + err error +} + +func (g *groupByInfo) GetGroupByFieldId() int64 { + if g != nil { + return g.groupByFieldId + } + return 0 +} + +func (g *groupByInfo) GetGroupSize() int64 { + if g != nil { + return g.groupSize + } + return 0 +} + +func (g *groupByInfo) GetGroupStrictSize() bool { + if g != nil { + return g.groupStrictSize + } + return false +} + +func (g *groupByInfo) GetError() error { + if g != nil { + return g.err + } + return nil +} + +func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) *groupByInfo { + ret := &groupByInfo{} + + // 1. parse group_by_field + groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair) + if err != nil { + groupByFieldName = "" + } + var groupByFieldId int64 = -1 + if groupByFieldName != "" { + fields := schema.GetFields() + for _, field := range fields { + if field.GetNullable() { + ret.err = merr.WrapErrParameterInvalidMsg(fmt.Sprintf("groupBy field(%s) not support nullable == true", groupByFieldName)) + return ret + } + if field.Name == groupByFieldName { + groupByFieldId = field.FieldID + break + } + } + if groupByFieldId == -1 { + ret.err = merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema") + return ret + } + } + ret.groupByFieldId = groupByFieldId + + // 2. parse group size + var groupSize int64 + groupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupSizeKey, searchParamsPair) + if err != nil { + groupSize = 1 + } else { + groupSize, err = strconv.ParseInt(groupSizeStr, 0, 64) + if err != nil || groupSize <= 0 { + groupSize = 1 + } + } + if groupSize > Params.QuotaConfig.MaxGroupSize.GetAsInt64() { + ret.err = merr.WrapErrParameterInvalidMsg( + fmt.Sprintf("input group size:%d exceeds configured max group size:%d", groupSize, Params.QuotaConfig.MaxGroupSize.GetAsInt64())) + return ret + } + ret.groupSize = groupSize + + // 3. parse group strict size + var groupStrictSize bool + groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair) + if err != nil { + groupStrictSize = false + } else { + groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr) + if err != nil { + groupStrictSize = false + } + } + ret.groupStrictSize = groupStrictSize + return ret +} + // parseRankParams get limit and offset from rankParams, both are optional. -func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, error) { +func parseRankParams(rankParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) (*rankParams, error) { var ( limit int64 offset int64 @@ -343,23 +428,30 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr) } + // parse group_by parameters from main request body for hybrid search + groupByInfo := parseGroupByInfo(rankParamsPair, schema) + if groupByInfo.err != nil { + return nil, groupByInfo.err + } + return &rankParams{ - limit: limit, - offset: offset, - roundDecimal: roundDecimal, + limit: limit, + offset: offset, + roundDecimal: roundDecimal, + groupByFieldId: groupByInfo.GetGroupByFieldId(), + groupSize: groupByInfo.GetGroupSize(), + groupStrictSize: groupByInfo.GetGroupStrictSize(), }, nil } func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest { - searchParams := make([]*commonpb.KeyValuePair, len(req.GetRankParams())) - copy(searchParams, req.GetRankParams()) ret := &milvuspb.SearchRequest{ Base: req.GetBase(), DbName: req.GetDbName(), CollectionName: req.GetCollectionName(), PartitionNames: req.GetPartitionNames(), OutputFields: req.GetOutputFields(), - SearchParams: searchParams, + SearchParams: req.GetRankParams(), TravelTimestamp: req.GetTravelTimestamp(), GuaranteeTimestamp: req.GetGuaranteeTimestamp(), Nq: 0, diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 08bc373d00..ccd4e591ca 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -170,7 +170,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error { } } if t.SearchRequest.GetIsAdvanced() { - t.rankParams, err = parseRankParams(t.request.GetSearchParams()) + t.rankParams, err = parseRankParams(t.request.GetSearchParams(), t.schema.CollectionSchema) if err != nil { log.Info("parseRankParams failed", zap.Error(err)) return err @@ -366,8 +366,8 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { Topk: queryInfo.GetTopk(), Offset: offset, MetricType: queryInfo.GetMetricType(), - GroupByFieldId: queryInfo.GetGroupByFieldId(), - GroupSize: queryInfo.GetGroupSize(), + GroupByFieldId: t.rankParams.GetGroupByFieldId(), + GroupSize: t.rankParams.GetGroupSize(), } // set PartitionIDs for sub search @@ -407,10 +407,9 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()), zap.Stringer("plan", plan)) // may be very large if large term passed. } - if len(t.queryInfos) > 0 { - t.SearchRequest.GroupByFieldId = t.queryInfos[0].GetGroupByFieldId() - t.SearchRequest.GroupSize = t.queryInfos[0].GetGroupSize() - } + + t.SearchRequest.GroupByFieldId = t.rankParams.GetGroupByFieldId() + t.SearchRequest.GroupSize = t.rankParams.GetGroupSize() // used for requery if t.partitionKeyMode { diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 216ada2677..70799c9fde 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -2263,6 +2263,63 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { assert.Equal(t, int64(0), offset) }) + t.Run("parseSearchInfo groupBy info for hybrid search", func(t *testing.T) { + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: 101, Name: "c1"}, + {FieldID: 102, Name: "c2"}, + {FieldID: 103, Name: "c3"}, + }, + } + // 1. first parse rank params + // outer params require to group by field 101 and groupSize=3 and groupStrictSize=false + testRankParamsPairs := getValidSearchParams() + testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{ + Key: GroupByFieldKey, + Value: "c1", + }) + testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{ + Key: GroupSizeKey, + Value: strconv.FormatInt(3, 10), + }) + testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{ + Key: GroupStrictSize, + Value: "false", + }) + testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{ + Key: LimitKey, + Value: "100", + }) + testRankParams, err := parseRankParams(testRankParamsPairs, schema) + assert.NoError(t, err) + + // 2. parse search params for sub request in hybridsearch + params := getValidSearchParams() + // inner params require to group by field 103 and groupSize=10 and groupStrictSize=true + params = append(params, &commonpb.KeyValuePair{ + Key: GroupByFieldKey, + Value: "c3", + }) + params = append(params, &commonpb.KeyValuePair{ + Key: GroupSizeKey, + Value: strconv.FormatInt(10, 10), + }) + params = append(params, &commonpb.KeyValuePair{ + Key: GroupStrictSize, + Value: "true", + }) + + info, _, err := parseSearchInfo(params, schema, testRankParams) + assert.NoError(t, err) + assert.NotNil(t, info) + + // all group_by related parameters should be aligned to parameters + // set by main request rather than inner sub request + assert.Equal(t, int64(101), info.GetGroupByFieldId()) + assert.Equal(t, int64(3), info.GetGroupSize()) + assert.False(t, info.GetGroupStrictSize()) + }) + t.Run("parseSearchInfo error", func(t *testing.T) { spNoTopk := []*commonpb.KeyValuePair{{ Key: AnnsFieldKey,