From bfd9d86fe917dd277331eb7da398f69f59a2920f Mon Sep 17 00:00:00 2001 From: Chun Han <116052805+MrPresent-Han@users.noreply.github.com> Date: Tue, 27 Aug 2024 14:21:00 +0800 Subject: [PATCH] feat: support groupby size on go-layer(#33544) (#33845) related: #33544 Signed-off-by: MrPresent-Han Co-authored-by: MrPresent-Han --- internal/proto/internal.proto | 6 + internal/proxy/search_reduce_util.go | 113 +++++--- internal/proxy/search_util.go | 16 +- internal/proxy/task.go | 1 + internal/proxy/task_search.go | 2 +- internal/proxy/task_search_test.go | 78 ++++++ internal/querynodev2/delegator/delegator.go | 9 +- internal/querynodev2/handlers.go | 6 +- internal/querynodev2/segments/result.go | 106 +------ internal/querynodev2/segments/result_test.go | 171 ------------ .../querynodev2/segments/search_reduce.go | 227 +++++++++++++++ .../segments/search_reduce_test.go | 258 ++++++++++++++++++ internal/querynodev2/services.go | 6 +- 13 files changed, 683 insertions(+), 316 deletions(-) create mode 100644 internal/querynodev2/segments/search_reduce.go create mode 100644 internal/querynodev2/segments/search_reduce_test.go diff --git a/internal/proto/internal.proto b/internal/proto/internal.proto index 980cf35769..f01c5256e4 100644 --- a/internal/proto/internal.proto +++ b/internal/proto/internal.proto @@ -96,6 +96,11 @@ message SubSearchRequest { string metricType = 9; } +message ExtraSearchParam { + int64 group_by_field_id = 1; + int64 group_size = 2; +} + message SearchRequest { common.MsgBase base = 1; int64 reqID = 2; @@ -120,6 +125,7 @@ message SearchRequest { bool is_advanced = 20; int64 offset = 21; common.ConsistencyLevel consistency_level = 22; + ExtraSearchParam extra_search_param = 23; } message SubSearchResults { diff --git a/internal/proxy/search_reduce_util.go b/internal/proxy/search_reduce_util.go index ecc77e39d5..488f0ef01d 100644 --- a/internal/proxy/search_reduce_util.go +++ b/internal/proxy/search_reduce_util.go @@ -58,7 +58,8 @@ func 
reduceSearchResult(ctx context.Context, reduceInfo *reduceSearchResultInfo) reduceInfo.topK, reduceInfo.metricType, reduceInfo.pkType, - reduceInfo.offset) + reduceInfo.offset, + reduceInfo.queryInfo.GroupSize) } return reduceSearchResultDataNoGroupBy(ctx, reduceInfo.subSearchResultData, @@ -69,7 +70,21 @@ func reduceSearchResult(ctx context.Context, reduceInfo *reduceSearchResultInfo) reduceInfo.offset) } -func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64) (*milvuspb.SearchResults, error) { +type MilvusPKType interface{} + +type groupReduceInfo struct { + subSearchIdx int + resultIdx int64 + score float32 + id MilvusPKType +} + +func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, + nq int64, topk int64, metricType string, + pkType schemapb.DataType, + offset int64, + groupSize int64, +) (*milvuspb.SearchResults, error) { tr := timerecord.NewTimeRecorder("reduceSearchResultData") defer func() { tr.CtxElapse(ctx, "done") @@ -130,13 +145,15 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData var ( subSearchNum = len(subSearchResultData) // for results of each subSearchResultData, storing the start offset of each query of nq queries - subSearchNqOffset = make([][]int64, subSearchNum) + subSearchNqOffset = make([][]int64, subSearchNum) + totalResCount int64 = 0 ) for i := 0; i < subSearchNum; i++ { subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries()) for j := int64(1); j < nq; j++ { subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1] } + totalResCount += subSearchNqOffset[i][nq-1] } var ( @@ -146,6 +163,7 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData var retSize int64 maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() + 
groupBound := groupSize * limit // reducing nq * topk results for i := int64(0); i < nq; i++ { @@ -154,16 +172,15 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData // sum(cursors) == j cursors = make([]int64, subSearchNum) - j int64 - idSet = make(map[interface{}]struct{}) - groupByValSet = make(map[interface{}]struct{}) + j int64 + pkSet = make(map[interface{}]struct{}) + groupByValMap = make(map[interface{}][]*groupReduceInfo) + skipOffsetMap = make(map[interface{}]bool) + groupByValList = make([]interface{}, limit) + groupByValIdx = 0 ) - // keep limit results - for j = 0; j < limit; { - // From all the sub-query result sets of the i-th query vector, - // find the sub-query result set index of the score j-th data, - // and the index of the data in schemapb.SearchResultData + for j = 0; j < groupBound; { subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i) if subSearchIdx == -1 { break @@ -171,44 +188,63 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData subSearchRes := subSearchResultData[subSearchIdx] id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx) - score := subSearchRes.Scores[resultDataIdx] + score := subSearchRes.GetScores()[resultDataIdx] groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx)) if groupByVal == nil { return nil, errors.New("get nil groupByVal from subSearchRes, wrong states, as milvus doesn't support nil value," + "there must be sth wrong on queryNode side") } - // remove duplicates - if _, ok := idSet[id]; !ok { - _, groupByValExist := groupByValSet[groupByVal] - if !groupByValExist { - groupByValSet[groupByVal] = struct{}{} - if int64(len(groupByValSet)) <= offset { - continue - // skip offset groups + if _, ok := pkSet[id]; !ok { + if int64(len(skipOffsetMap)) < offset || skipOffsetMap[groupByVal] { + skipOffsetMap[groupByVal] = true + // the first offset's group will be ignored 
+ skipDupCnt++ + } else if len(groupByValMap[groupByVal]) == 0 && int64(len(groupByValMap)) >= limit { + // skip when groupbyMap has been full and found new groupByVal + skipDupCnt++ + } else if int64(len(groupByValMap[groupByVal])) >= groupSize { + // skip when target group has been full + skipDupCnt++ + } else { + if len(groupByValMap[groupByVal]) == 0 { + groupByValList[groupByValIdx] = groupByVal + groupByValIdx++ } - retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx) - typeutil.AppendPKs(ret.Results.Ids, id) - ret.Results.Scores = append(ret.Results.Scores, score) - idSet[id] = struct{}{} - if err := typeutil.AppendGroupByValue(ret.Results, groupByVal, subSearchRes.GetGroupByFieldValue().GetType()); err != nil { + groupByValMap[groupByVal] = append(groupByValMap[groupByVal], &groupReduceInfo{ + subSearchIdx: subSearchIdx, + resultIdx: resultDataIdx, id: id, score: score, + }) + pkSet[id] = struct{}{} + j++ + } + } else { + skipDupCnt++ + } + + cursors[subSearchIdx]++ + } + + // assemble all eligible values in group + // values in groupByValList is sorted by the highest score in each group + for _, groupVal := range groupByValList { + if groupVal != nil { + groupEntities := groupByValMap[groupVal] + for _, groupEntity := range groupEntities { + subResData := subSearchResultData[groupEntity.subSearchIdx] + retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx) + typeutil.AppendPKs(ret.Results.Ids, groupEntity.id) + ret.Results.Scores = append(ret.Results.Scores, groupEntity.score) + if err := typeutil.AppendGroupByValue(ret.Results, groupVal, subResData.GetGroupByFieldValue().GetType()); err != nil { log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err)) return ret, err } - j++ - } else { - // skip entity with same groupby - skipDupCnt++ } - } else { - // skip entity with same id - skipDupCnt++ } - cursors[subSearchIdx]++ } + if 
realTopK != -1 && realTopK != j { log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different"))) - // return nil, errors.New("the length (topk) between all result of query is different") } realTopK = j ret.Results.Topks = append(ret.Results.Topks, realTopK) @@ -218,10 +254,13 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize) } } - log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt)) + log.Ctx(ctx).Debug("skip duplicated search result when doing group by", zap.Int64("count", skipDupCnt)) - if skipDupCnt > 0 { - log.Ctx(ctx).Info("skip duplicated search result", zap.Int64("count", skipDupCnt)) + if float64(skipDupCnt) >= float64(totalResCount)*0.3 { + log.Warn("GroupBy reduce skipped too many results, "+ + "this may influence the final result seriously", + zap.Int64("skipDupCnt", skipDupCnt), + zap.Int64("groupBound", groupBound)) } ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index 0ea8a4e6d2..06f2ff4a0a 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -99,7 +99,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb searchParamStr = "" } - // 5. parse group by field + // 5. parse group by field and group by size groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair) if err != nil { groupByFieldName = "" @@ -118,7 +118,18 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb } } - // 6. 
disable groupBy for iterator and range search + var groupSize int64 + groupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupSizeKey, searchParamsPair) + if err != nil { + groupSize = 1 + } else { + groupSize, err = strconv.ParseInt(groupSizeStr, 0, 64) + if err != nil || groupSize <= 0 { + groupSize = 1 + } + } + + // 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search if isIterator == "True" && groupByFieldId > 0 { return nil, 0, merr.WrapErrParameterInvalid("", "", "Not allowed to do groupBy when doing iteration") @@ -134,6 +145,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb SearchParams: searchParamStr, RoundDecimal: roundDecimal, GroupByFieldId: groupByFieldId, + GroupSize: groupSize, }, offset, nil } diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 197e38e3e4..94ace67712 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -49,6 +49,7 @@ const ( ReduceStopForBestKey = "reduce_stop_for_best" IteratorField = "iterator" GroupByFieldKey = "group_by_field" + GroupSizeKey = "group_size" AnnsFieldKey = "anns_field" TopKKey = "topk" NQKey = "nq" diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index e32603e4b2..aff3d06690 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -456,12 +456,12 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error { if err != nil { return err } - t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup t.SearchRequest.Topk = queryInfo.GetTopk() t.SearchRequest.MetricType = queryInfo.GetMetricType() t.queryInfos = append(t.queryInfos, queryInfo) t.SearchRequest.DslType = commonpb.DslType_BoolExprV1 + t.SearchRequest.ExtraSearchParam = &internalpb.ExtraSearchParam{GroupByFieldId: queryInfo.GroupByFieldId, GroupSize: queryInfo.GroupSize} log.Debug("proxy init search request", zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()), 
zap.Stringer("plan", plan)) // may be very large if large term passed. diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index b0f9d769b2..86f8d9edcc 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -1706,6 +1706,7 @@ func TestTaskSearch_reduceGroupBySearchResultData(t *testing.T) { } queryInfo := &planpb.QueryInfo{ GroupByFieldId: 1, + GroupSize: 1, } reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topK, metric.L2, schemapb.DataType_Int64, 0, queryInfo)) @@ -1765,6 +1766,7 @@ func TestTaskSearch_reduceGroupBySearchResultDataWithOffset(t *testing.T) { queryInfo := &planpb.QueryInfo{ GroupByFieldId: 1, + GroupSize: 1, } reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, limit+offset, metric.L2, schemapb.DataType_Int64, offset, queryInfo)) @@ -1777,6 +1779,82 @@ func TestTaskSearch_reduceGroupBySearchResultDataWithOffset(t *testing.T) { assert.NoError(t, err) } +func TestTaskSearch_reduceGroupBySearchWithGroupSizeMoreThanOne(t *testing.T) { + var ( + nq int64 = 2 + topK int64 = 5 + ) + ids := [][]int64{ + {1, 3, 5, 7, 9, 1, 3, 5, 7, 9}, + {2, 4, 6, 8, 10, 2, 4, 6, 8, 10}, + } + scores := [][]float32{ + {10, 8, 6, 4, 2, 10, 8, 6, 4, 2}, + {9, 7, 5, 3, 1, 9, 7, 5, 3, 1}, + } + + groupByValuesArr := [][][]int64{ + { + {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, + {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, + }, + { + {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, + {6, 8, 3, 4, 5, 6, 8, 3, 4, 5}, + }, + } + expectedIDs := [][]int64{ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, + } + expectedScores := [][]float32{ + {-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1}, + {-10, -9, -8, -7, -6, -5, -10, -9, -8, -7, -6, -5}, + } + expectedGroupByValues := [][]int64{ + {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5}, + {1, 6, 2, 8, 3, 3, 1, 6, 2, 8, 3, 
3}, + } + + for i, groupByValues := range groupByValuesArr { + t.Run("Group By correctness", func(t *testing.T) { + var results []*schemapb.SearchResultData + for j := range ids { + result := getSearchResultData(nq, topK) + result.Ids.IdField = &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: ids[j]}} + result.Scores = scores[j] + result.Topks = []int64{topK, topK} + result.GroupByFieldValue = &schemapb.FieldData{ + Type: schemapb.DataType_Int64, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: groupByValues[j], + }, + }, + }, + }, + } + results = append(results, result) + } + queryInfo := &planpb.QueryInfo{ + GroupByFieldId: 1, + GroupSize: 2, + } + reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topK, metric.L2, + schemapb.DataType_Int64, 0, queryInfo)) + resultIDs := reduced.GetResults().GetIds().GetIntId().Data + resultScores := reduced.GetResults().GetScores() + resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData() + assert.EqualValues(t, expectedIDs[i], resultIDs) + assert.EqualValues(t, expectedScores[i], resultScores) + assert.EqualValues(t, expectedGroupByValues[i], resultGroupByValues) + assert.NoError(t, err) + }) + } +} + func TestSearchTask_ErrExecute(t *testing.T) { var ( err error diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index 1d11b4c5ad..3ae41a2feb 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -351,9 +351,12 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest return segments.ReduceSearchResults(ctx, results, - searchReq.Req.GetNq(), - searchReq.Req.GetTopk(), - searchReq.Req.GetMetricType()) + segments.NewReduceInfo(searchReq.Req.GetNq(), + searchReq.Req.GetTopk(), + 
searchReq.Req.GetExtraSearchParam().GetGroupByFieldId(), + searchReq.Req.GetExtraSearchParam().GetGroupSize(), + searchReq.Req.GetMetricType()), + ) }) futures[index] = future } diff --git a/internal/querynodev2/handlers.go b/internal/querynodev2/handlers.go index 170af4c39e..bf71bf74d6 100644 --- a/internal/querynodev2/handlers.go +++ b/internal/querynodev2/handlers.go @@ -388,7 +388,11 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq if req.GetReq().GetIsAdvanced() { resp, err = segments.ReduceAdvancedSearchResults(ctx, results, req.Req.GetNq()) } else { - resp, err = segments.ReduceSearchResults(ctx, results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType()) + resp, err = segments.ReduceSearchResults(ctx, results, segments.NewReduceInfo(req.Req.GetNq(), + req.Req.GetTopk(), + req.Req.GetExtraSearchParam().GetGroupByFieldId(), + req.Req.GetExtraSearchParam().GetGroupSize(), + req.Req.GetMetricType())) } if err != nil { return nil, err diff --git a/internal/querynodev2/segments/result.go b/internal/querynodev2/segments/result.go index 38fc5b9427..fc43edb5c5 100644 --- a/internal/querynodev2/segments/result.go +++ b/internal/querynodev2/segments/result.go @@ -42,7 +42,7 @@ var _ typeutil.ResultWithID = &internalpb.RetrieveResults{} var _ typeutil.ResultWithID = &segcorepb.RetrieveResults{} -func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResults, nq int64, topk int64, metricType string) (*internalpb.SearchResults, error) { +func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResults, info *ReduceInfo) (*internalpb.SearchResults, error) { results = lo.Filter(results, func(result *internalpb.SearchResults, _ int) bool { return result != nil && result.GetSlicedBlob() != nil }) @@ -60,8 +60,8 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult channelsMvcc[ch] = ts } // shouldn't let new SearchResults.MetricType to be empty, though the 
req.MetricType is empty - if metricType == "" { - metricType = r.MetricType + if info.metricType == "" { + info.metricType = r.MetricType } } log := log.Ctx(ctx) @@ -80,12 +80,13 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult zap.Int64("topk", sData.TopK)) } - reducedResultData, err := ReduceSearchResultData(ctx, searchResultData, nq, topk) + searchReduce := InitSearchReducer(info) + reducedResultData, err := searchReduce.ReduceSearchResultData(ctx, searchResultData, info) if err != nil { log.Warn("shard leader reduce errors", zap.Error(err)) return nil, err } - searchResults, err := EncodeSearchResultData(ctx, reducedResultData, nq, topk, metricType) + searchResults, err := EncodeSearchResultData(ctx, reducedResultData, info.nq, info.topK, info.metricType) if err != nil { log.Warn("shard leader encode search result errors", zap.Error(err)) return nil, err @@ -207,101 +208,6 @@ func MergeToAdvancedResults(ctx context.Context, results []*internalpb.SearchRes return searchResults, nil } -func ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, nq int64, topk int64) (*schemapb.SearchResultData, error) { - ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData") - defer sp.End() - log := log.Ctx(ctx) - - if len(searchResultData) == 0 { - return &schemapb.SearchResultData{ - NumQueries: nq, - TopK: topk, - FieldsData: make([]*schemapb.FieldData, 0), - Scores: make([]float32, 0), - Ids: &schemapb.IDs{}, - Topks: make([]int64, 0), - }, nil - } - ret := &schemapb.SearchResultData{ - NumQueries: nq, - TopK: topk, - FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)), - Scores: make([]float32, 0), - Ids: &schemapb.IDs{}, - Topks: make([]int64, 0), - } - - resultOffsets := make([][]int64, len(searchResultData)) - for i := 0; i < len(searchResultData); i++ { - resultOffsets[i] = make([]int64, len(searchResultData[i].Topks)) - for j := int64(1); j < 
nq; j++ { - resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1] - } - ret.AllSearchCount += searchResultData[i].GetAllSearchCount() - } - - var skipDupCnt int64 - var retSize int64 - maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() - for i := int64(0); i < nq; i++ { - offsets := make([]int64, len(searchResultData)) - - idSet := make(map[interface{}]struct{}) - groupByValueSet := make(map[interface{}]struct{}) - var j int64 - for j = 0; j < topk; { - sel := SelectSearchResultData(searchResultData, resultOffsets, offsets, i) - if sel == -1 { - break - } - idx := resultOffsets[sel][i] + offsets[sel] - - id := typeutil.GetPK(searchResultData[sel].GetIds(), idx) - groupByVal := typeutil.GetData(searchResultData[sel].GetGroupByFieldValue(), int(idx)) - score := searchResultData[sel].Scores[idx] - - // remove duplicates - if _, ok := idSet[id]; !ok { - groupByValExist := false - if groupByVal != nil { - _, groupByValExist = groupByValueSet[groupByVal] - } - if !groupByValExist { - retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx) - typeutil.AppendPKs(ret.Ids, id) - ret.Scores = append(ret.Scores, score) - if groupByVal != nil { - groupByValueSet[groupByVal] = struct{}{} - if err := typeutil.AppendGroupByValue(ret, groupByVal, searchResultData[sel].GetGroupByFieldValue().GetType()); err != nil { - log.Error("Failed to append groupByValues", zap.Error(err)) - return ret, err - } - } - idSet[id] = struct{}{} - j++ - } - } else { - // skip entity with same id - skipDupCnt++ - } - offsets[sel]++ - } - - // if realTopK != -1 && realTopK != j { - // log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different"))) - // // return nil, errors.New("the length (topk) between all result of query is different") - // } - ret.Topks = append(ret.Topks, j) - - // limit search result to avoid oom - if retSize > maxOutputSize { - return nil, 
fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize) - } - } - log.Debug("skip duplicated search result", zap.Int64("count", skipDupCnt)) - return ret, nil -} - func SelectSearchResultData(dataArray []*schemapb.SearchResultData, resultOffsets [][]int64, offsets []int64, qi int64) int { var ( sel = -1 diff --git a/internal/querynodev2/segments/result_test.go b/internal/querynodev2/segments/result_test.go index 794321ce12..4f4d4fead7 100644 --- a/internal/querynodev2/segments/result_test.go +++ b/internal/querynodev2/segments/result_test.go @@ -702,177 +702,6 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { }) } -func (suite *ResultSuite) TestResult_ReduceSearchResultData() { - const ( - nq = 1 - topk = 4 - metricType = "L2" - ) - suite.Run("case1", func() { - ids := []int64{1, 2, 3, 4} - scores := []float32{-1.0, -2.0, -3.0, -4.0} - topks := []int64{int64(len(ids))} - data1 := genSearchResultData(nq, topk, ids, scores, topks) - data2 := genSearchResultData(nq, topk, ids, scores, topks) - dataArray := make([]*schemapb.SearchResultData, 0) - dataArray = append(dataArray, data1) - dataArray = append(dataArray, data2) - res, err := ReduceSearchResultData(context.TODO(), dataArray, nq, topk) - suite.Nil(err) - suite.Equal(ids, res.Ids.GetIntId().Data) - suite.Equal(scores, res.Scores) - }) - suite.Run("case2", func() { - ids1 := []int64{1, 2, 3, 4} - scores1 := []float32{-1.0, -2.0, -3.0, -4.0} - topks1 := []int64{int64(len(ids1))} - ids2 := []int64{5, 1, 3, 4} - scores2 := []float32{-1.0, -1.0, -3.0, -4.0} - topks2 := []int64{int64(len(ids2))} - data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) - data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) - dataArray := make([]*schemapb.SearchResultData, 0) - dataArray = append(dataArray, data1) - dataArray = append(dataArray, data2) - res, err := ReduceSearchResultData(context.TODO(), dataArray, nq, topk) - suite.Nil(err) - suite.ElementsMatch([]int64{1, 5, 2, 
3}, res.Ids.GetIntId().Data) - }) -} - -func (suite *ResultSuite) TestResult_SearchGroupByResult() { - const ( - nq = 1 - topk = 4 - ) - suite.Run("reduce_group_by_int", func() { - ids1 := []int64{1, 2, 3, 4} - scores1 := []float32{-1.0, -2.0, -3.0, -4.0} - topks1 := []int64{int64(len(ids1))} - ids2 := []int64{5, 1, 3, 4} - scores2 := []float32{-1.0, -1.0, -3.0, -4.0} - topks2 := []int64{int64(len(ids2))} - data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) - data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) - data1.GroupByFieldValue = &schemapb.FieldData{ - Type: schemapb.DataType_Int8, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_IntData{ - IntData: &schemapb.IntArray{ - Data: []int32{2, 3, 4, 5}, - }, - }, - }, - }, - } - data2.GroupByFieldValue = &schemapb.FieldData{ - Type: schemapb.DataType_Int8, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_IntData{ - IntData: &schemapb.IntArray{ - Data: []int32{2, 3, 4, 5}, - }, - }, - }, - }, - } - dataArray := make([]*schemapb.SearchResultData, 0) - dataArray = append(dataArray, data1) - dataArray = append(dataArray, data2) - res, err := ReduceSearchResultData(context.TODO(), dataArray, nq, topk) - suite.Nil(err) - suite.ElementsMatch([]int64{1, 2, 3, 4}, res.Ids.GetIntId().Data) - suite.ElementsMatch([]float32{-1.0, -2.0, -3.0, -4.0}, res.Scores) - suite.ElementsMatch([]int32{2, 3, 4, 5}, res.GroupByFieldValue.GetScalars().GetIntData().Data) - }) - suite.Run("reduce_group_by_bool", func() { - ids1 := []int64{1, 2} - scores1 := []float32{-1.0, -2.0} - topks1 := []int64{int64(len(ids1))} - ids2 := []int64{3, 4} - scores2 := []float32{-1.0, -1.0} - topks2 := []int64{int64(len(ids2))} - data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) - data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) - data1.GroupByFieldValue = &schemapb.FieldData{ - Type: schemapb.DataType_Bool, - 
Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_BoolData{ - BoolData: &schemapb.BoolArray{ - Data: []bool{true, false}, - }, - }, - }, - }, - } - data2.GroupByFieldValue = &schemapb.FieldData{ - Type: schemapb.DataType_Bool, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_BoolData{ - BoolData: &schemapb.BoolArray{ - Data: []bool{true, false}, - }, - }, - }, - }, - } - dataArray := make([]*schemapb.SearchResultData, 0) - dataArray = append(dataArray, data1) - dataArray = append(dataArray, data2) - res, err := ReduceSearchResultData(context.TODO(), dataArray, nq, topk) - suite.Nil(err) - suite.ElementsMatch([]int64{1, 4}, res.Ids.GetIntId().Data) - suite.ElementsMatch([]float32{-1.0, -1.0}, res.Scores) - suite.ElementsMatch([]bool{true, false}, res.GroupByFieldValue.GetScalars().GetBoolData().Data) - }) - suite.Run("reduce_group_by_string", func() { - ids1 := []int64{1, 2, 3, 4} - scores1 := []float32{-1.0, -2.0, -3.0, -4.0} - topks1 := []int64{int64(len(ids1))} - ids2 := []int64{5, 1, 3, 4} - scores2 := []float32{-1.0, -1.0, -3.0, -4.0} - topks2 := []int64{int64(len(ids2))} - data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) - data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) - data1.GroupByFieldValue = &schemapb.FieldData{ - Type: schemapb.DataType_VarChar, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_StringData{ - StringData: &schemapb.StringArray{ - Data: []string{"1", "2", "3", "4"}, - }, - }, - }, - }, - } - data2.GroupByFieldValue = &schemapb.FieldData{ - Type: schemapb.DataType_VarChar, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_StringData{ - StringData: &schemapb.StringArray{ - Data: []string{"1", "2", "3", "4"}, - }, - }, - }, - }, - } - dataArray := make([]*schemapb.SearchResultData, 0) - dataArray = 
append(dataArray, data1) - dataArray = append(dataArray, data2) - res, err := ReduceSearchResultData(context.TODO(), dataArray, nq, topk) - suite.Nil(err) - suite.ElementsMatch([]int64{1, 2, 3, 4}, res.Ids.GetIntId().Data) - suite.ElementsMatch([]float32{-1.0, -2.0, -3.0, -4.0}, res.Scores) - suite.ElementsMatch([]string{"1", "2", "3", "4"}, res.GroupByFieldValue.GetScalars().GetStringData().Data) - }) -} - func (suite *ResultSuite) TestResult_SelectSearchResultData_int() { type args struct { dataArray []*schemapb.SearchResultData diff --git a/internal/querynodev2/segments/search_reduce.go b/internal/querynodev2/segments/search_reduce.go new file mode 100644 index 0000000000..14dff5fc7c --- /dev/null +++ b/internal/querynodev2/segments/search_reduce.go @@ -0,0 +1,227 @@ +package segments + +import ( + "context" + "fmt" + + "go.opentelemetry.io/otel" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type ReduceInfo struct { + nq int64 + topK int64 + groupByFieldID int64 + groupSize int64 + metricType string +} + +func NewReduceInfo(nq int64, topK int64, groupByFieldID int64, groupSize int64, metric string) *ReduceInfo { + return &ReduceInfo{nq, topK, groupByFieldID, groupSize, metric} +} + +type SearchReduce interface { + ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error) +} + +type SearchCommonReduce struct{} + +func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error) { + ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData") + defer sp.End() + log := log.Ctx(ctx) + + if len(searchResultData) == 0 { + return 
&schemapb.SearchResultData{ + NumQueries: info.nq, + TopK: info.topK, + FieldsData: make([]*schemapb.FieldData, 0), + Scores: make([]float32, 0), + Ids: &schemapb.IDs{}, + Topks: make([]int64, 0), + }, nil + } + ret := &schemapb.SearchResultData{ + NumQueries: info.nq, + TopK: info.topK, + FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)), + Scores: make([]float32, 0), + Ids: &schemapb.IDs{}, + Topks: make([]int64, 0), + } + + resultOffsets := make([][]int64, len(searchResultData)) + for i := 0; i < len(searchResultData); i++ { + resultOffsets[i] = make([]int64, len(searchResultData[i].Topks)) + for j := int64(1); j < info.nq; j++ { + resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1] + } + ret.AllSearchCount += searchResultData[i].GetAllSearchCount() + } + + var skipDupCnt int64 + var retSize int64 + maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() + for i := int64(0); i < info.nq; i++ { + offsets := make([]int64, len(searchResultData)) + idSet := make(map[interface{}]struct{}) + var j int64 + for j = 0; j < info.topK; { + sel := SelectSearchResultData(searchResultData, resultOffsets, offsets, i) + if sel == -1 { + break + } + idx := resultOffsets[sel][i] + offsets[sel] + + id := typeutil.GetPK(searchResultData[sel].GetIds(), idx) + score := searchResultData[sel].Scores[idx] + + // remove duplicates + if _, ok := idSet[id]; !ok { + retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx) + typeutil.AppendPKs(ret.Ids, id) + ret.Scores = append(ret.Scores, score) + idSet[id] = struct{}{} + j++ + } else { + // skip entity with same id + skipDupCnt++ + } + offsets[sel]++ + } + + // if realTopK != -1 && realTopK != j { + // log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different"))) + // // return nil, errors.New("the length (topk) between all result of query is different") + // } + ret.Topks = 
append(ret.Topks, j) + + // limit search result to avoid oom + if retSize > maxOutputSize { + return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize) + } + } + log.Debug("skip duplicated search result", zap.Int64("count", skipDupCnt)) + return ret, nil +} + +type SearchGroupByReduce struct{} + +func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error) { + ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData") + defer sp.End() + log := log.Ctx(ctx) + + if len(searchResultData) == 0 { + return &schemapb.SearchResultData{ + NumQueries: info.nq, + TopK: info.topK, + FieldsData: make([]*schemapb.FieldData, 0), + Scores: make([]float32, 0), + Ids: &schemapb.IDs{}, + Topks: make([]int64, 0), + }, nil + } + ret := &schemapb.SearchResultData{ + NumQueries: info.nq, + TopK: info.topK, + FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)), + Scores: make([]float32, 0), + Ids: &schemapb.IDs{}, + Topks: make([]int64, 0), + } + + resultOffsets := make([][]int64, len(searchResultData)) + for i := 0; i < len(searchResultData); i++ { + resultOffsets[i] = make([]int64, len(searchResultData[i].Topks)) + for j := int64(1); j < info.nq; j++ { + resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1] + } + ret.AllSearchCount += searchResultData[i].GetAllSearchCount() + } + + var filteredCount int64 + var retSize int64 + maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() + groupSize := info.groupSize + if groupSize <= 0 { + groupSize = 1 + } + groupBound := info.topK * groupSize + + for i := int64(0); i < info.nq; i++ { + offsets := make([]int64, len(searchResultData)) + + idSet := make(map[interface{}]struct{}) + groupByValueMap := make(map[interface{}]int64) + + var j int64 + for j = 0; j < groupBound; { + sel := 
SelectSearchResultData(searchResultData, resultOffsets, offsets, i) + if sel == -1 { + break + } + idx := resultOffsets[sel][i] + offsets[sel] + + id := typeutil.GetPK(searchResultData[sel].GetIds(), idx) + groupByVal := typeutil.GetData(searchResultData[sel].GetGroupByFieldValue(), int(idx)) + score := searchResultData[sel].Scores[idx] + // a primary key already emitted for this query is dropped before any group bookkeeping + if _, ok := idSet[id]; !ok { + if groupByVal == nil { + return ret, merr.WrapErrParameterMissing("GroupByVal returned from segment cannot be null") + } + + groupCount := groupByValueMap[groupByVal] + if groupCount == 0 && int64(len(groupByValueMap)) >= info.topK { + // exceed the limit for group count, filter this entity + filteredCount++ + } else if groupCount >= groupSize { + // exceed the limit for each group, filter this entity + filteredCount++ + } else { + retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx) + typeutil.AppendPKs(ret.Ids, id) + ret.Scores = append(ret.Scores, score) + if err := typeutil.AppendGroupByValue(ret, groupByVal, searchResultData[sel].GetGroupByFieldValue().GetType()); err != nil { + log.Error("Failed to append groupByValues", zap.Error(err)) + return ret, err + } + groupByValueMap[groupByVal] += 1 + idSet[id] = struct{}{} + j++ + } + } else { + // skip entity with same pk + filteredCount++ + } + offsets[sel]++ + } + ret.Topks = append(ret.Topks, j) + + // limit search result to avoid oom + if retSize > maxOutputSize { + return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize) + } + } + // heuristic alarm: when at least 30 percent of the bound was filtered away the final result may come up short + if float64(filteredCount) >= 0.3*float64(groupBound) { + log.Warn("GroupBy reduce filtered too many results, "+ + "this may influence the final result seriously", + zap.Int64("filteredCount", filteredCount), + zap.Int64("groupBound", groupBound)) + } + // NOTE(review): filteredCount also covers group-capacity drops, not only duplicated pks + log.Debug("skip duplicated search result", zap.Int64("count", filteredCount)) + return ret, nil +} + + // InitSearchReducer picks the reduce strategy: group-by aware when a group_by field id is set, plain topK otherwise. +func InitSearchReducer(info *ReduceInfo) SearchReduce { + if info.groupByFieldID > 0 { + return
&SearchGroupByReduce{} + } + return &SearchCommonReduce{} +} diff --git a/internal/querynodev2/segments/search_reduce_test.go b/internal/querynodev2/segments/search_reduce_test.go new file mode 100644 index 0000000000..05eb4a058f --- /dev/null +++ b/internal/querynodev2/segments/search_reduce_test.go @@ -0,0 +1,258 @@ +package segments + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + + // SearchReduceSuite covers the querynode-side search reduce paths (common topK and group-by). +type SearchReduceSuite struct { + suite.Suite +} + + // TestResult_ReduceSearchResultData exercises the common reducer: case1 merges two identical sub-results into a single copy, case2 deduplicates primary keys that overlap across sub-results. +func (suite *SearchReduceSuite) TestResult_ReduceSearchResultData() { + const ( + nq = 1 + topk = 4 + ) + suite.Run("case1", func() { + ids := []int64{1, 2, 3, 4} + scores := []float32{-1.0, -2.0, -3.0, -4.0} + topks := []int64{int64(len(ids))} + data1 := genSearchResultData(nq, topk, ids, scores, topks) + data2 := genSearchResultData(nq, topk, ids, scores, topks) + dataArray := make([]*schemapb.SearchResultData, 0) + dataArray = append(dataArray, data1) + dataArray = append(dataArray, data2) + reduceInfo := &ReduceInfo{nq: nq, topK: topk} + searchReduce := InitSearchReducer(reduceInfo) + res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo) + suite.Nil(err) + suite.Equal(ids, res.Ids.GetIntId().Data) + suite.Equal(scores, res.Scores) + }) + suite.Run("case2", func() { + ids1 := []int64{1, 2, 3, 4} + scores1 := []float32{-1.0, -2.0, -3.0, -4.0} + topks1 := []int64{int64(len(ids1))} + ids2 := []int64{5, 1, 3, 4} + scores2 := []float32{-1.0, -1.0, -3.0, -4.0} + topks2 := []int64{int64(len(ids2))} + data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) + data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) + dataArray := make([]*schemapb.SearchResultData, 0) + dataArray = append(dataArray, data1) + dataArray = append(dataArray, data2) + reduceInfo := &ReduceInfo{nq: nq, topK: topk} + searchReduce := InitSearchReducer(reduceInfo) + res, err :=
searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo) + suite.Nil(err) + suite.ElementsMatch([]int64{1, 5, 2, 3}, res.Ids.GetIntId().Data) + }) +} + + // TestResult_SearchGroupByResult exercises the group-by reducer over int8, bool and varchar group fields, a custom group_size, and empty input. +func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() { + const ( + nq = 1 + topk = 4 + ) + suite.Run("reduce_group_by_int", func() { + ids1 := []int64{1, 2, 3, 4} + scores1 := []float32{-1.0, -2.0, -3.0, -4.0} + topks1 := []int64{int64(len(ids1))} + ids2 := []int64{5, 1, 3, 4} + scores2 := []float32{-1.0, -1.0, -3.0, -4.0} + topks2 := []int64{int64(len(ids2))} + data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) + data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) + data1.GroupByFieldValue = &schemapb.FieldData{ + Type: schemapb.DataType_Int8, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{2, 3, 4, 5}, + }, + }, + }, + }, + } + data2.GroupByFieldValue = &schemapb.FieldData{ + Type: schemapb.DataType_Int8, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{2, 3, 4, 5}, + }, + }, + }, + }, + } + dataArray := make([]*schemapb.SearchResultData, 0) + dataArray = append(dataArray, data1) + dataArray = append(dataArray, data2) + reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101} + searchReduce := InitSearchReducer(reduceInfo) + res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo) + suite.Nil(err) + suite.ElementsMatch([]int64{1, 2, 3, 4}, res.Ids.GetIntId().Data) + suite.ElementsMatch([]float32{-1.0, -2.0, -3.0, -4.0}, res.Scores) + suite.ElementsMatch([]int32{2, 3, 4, 5}, res.GroupByFieldValue.GetScalars().GetIntData().Data) + }) + suite.Run("reduce_group_by_bool", func() { + ids1 := []int64{1, 2} + scores1 := []float32{-1.0, -2.0} + topks1 := []int64{int64(len(ids1))} + ids2 := []int64{3, 4} + scores2 :=
[]float32{-1.0, -1.0} + topks2 := []int64{int64(len(ids2))} + data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) + data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) + data1.GroupByFieldValue = &schemapb.FieldData{ + Type: schemapb.DataType_Bool, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: []bool{true, false}, + }, + }, + }, + }, + } + data2.GroupByFieldValue = &schemapb.FieldData{ + Type: schemapb.DataType_Bool, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: []bool{true, false}, + }, + }, + }, + }, + } + dataArray := make([]*schemapb.SearchResultData, 0) + dataArray = append(dataArray, data1) + dataArray = append(dataArray, data2) + reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101} + searchReduce := InitSearchReducer(reduceInfo) + res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo) + suite.Nil(err) + // the group size defaults to 1, so only one entity per bool value survives + suite.ElementsMatch([]int64{1, 4}, res.Ids.GetIntId().Data) + suite.ElementsMatch([]float32{-1.0, -1.0}, res.Scores) + suite.ElementsMatch([]bool{true, false}, res.GroupByFieldValue.GetScalars().GetBoolData().Data) + }) + suite.Run("reduce_group_by_string", func() { + ids1 := []int64{1, 2, 3, 4} + scores1 := []float32{-1.0, -2.0, -3.0, -4.0} + topks1 := []int64{int64(len(ids1))} + ids2 := []int64{5, 1, 3, 4} + scores2 := []float32{-1.0, -1.0, -3.0, -4.0} + topks2 := []int64{int64(len(ids2))} + data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) + data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) + data1.GroupByFieldValue = &schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"1", "2", "3", "4"},
+ }, + }, + }, + }, + } + data2.GroupByFieldValue = &schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"1", "2", "3", "4"}, + }, + }, + }, + }, + } + dataArray := make([]*schemapb.SearchResultData, 0) + dataArray = append(dataArray, data1) + dataArray = append(dataArray, data2) + reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101} + searchReduce := InitSearchReducer(reduceInfo) + res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo) + suite.Nil(err) + suite.ElementsMatch([]int64{1, 2, 3, 4}, res.Ids.GetIntId().Data) + suite.ElementsMatch([]float32{-1.0, -2.0, -3.0, -4.0}, res.Scores) + suite.ElementsMatch([]string{"1", "2", "3", "4"}, res.GroupByFieldValue.GetScalars().GetStringData().Data) + }) + // groupSize=3 allows up to three entities to be kept per group value + suite.Run("reduce_group_by_string_with_group_size", func() { + ids1 := []int64{1, 2, 3, 4} + scores1 := []float32{-1.0, -2.0, -3.0, -4.0} + topks1 := []int64{int64(len(ids1))} + ids2 := []int64{4, 5, 6, 7} + scores2 := []float32{-1.0, -1.0, -3.0, -4.0} + topks2 := []int64{int64(len(ids2))} + data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) + data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) + data1.GroupByFieldValue = &schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"1", "2", "3", "4"}, + }, + }, + }, + }, + } + data2.GroupByFieldValue = &schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"1", "2", "3", "4"}, + }, + }, + }, + }, + } + dataArray := make([]*schemapb.SearchResultData, 0) +
dataArray = append(dataArray, data1) + dataArray = append(dataArray, data2) + reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101, groupSize: 3} + searchReduce := InitSearchReducer(reduceInfo) + res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo) + suite.Nil(err) + // pk 4 appears in both segments and is deduplicated, hence 7 hits instead of 8 + suite.ElementsMatch([]int64{1, 4, 2, 5, 3, 6, 7}, res.Ids.GetIntId().Data) + suite.ElementsMatch([]float32{-1.0, -1.0, -1.0, -2.0, -3.0, -3.0, -4.0}, res.Scores) + suite.ElementsMatch([]string{"1", "1", "2", "2", "3", "3", "4"}, res.GroupByFieldValue.GetScalars().GetStringData().Data) + }) + + // an empty input must yield an empty result that still carries the nq/topK metadata + suite.Run("reduce_group_by_empty_input", func() { + dataArray := make([]*schemapb.SearchResultData, 0) + reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101, groupSize: 3} + searchReduce := InitSearchReducer(reduceInfo) + res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo) + suite.Nil(err) + suite.Nil(res.GetIds().GetIdField()) + suite.Equal(0, len(res.GetTopks())) + suite.Equal(0, len(res.GetScores())) + suite.Equal(int64(nq), res.GetNumQueries()) + suite.Equal(int64(topk), res.GetTopK()) + suite.Equal(0, len(res.GetFieldsData())) + }) +} + + // TestSearchReduce wires the suite into go test; paramtable must be initialized because reduce reads the MaxOutputSize quota. +func TestSearchReduce(t *testing.T) { + paramtable.Init() + suite.Run(t, new(SearchReduceSuite)) +} diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index b367e03c95..e597d07672 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -790,7 +790,11 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) ( if req.GetReq().GetIsAdvanced() { result, err2 = segments.ReduceAdvancedSearchResults(ctx, toReduceResults, req.Req.GetNq()) } else { - result, err2 = segments.ReduceSearchResults(ctx, toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType()) + result, err2 = segments.ReduceSearchResults(ctx, toReduceResults, segments.NewReduceInfo(req.Req.GetNq(), + req.Req.GetTopk(), +
req.Req.GetExtraSearchParam().GetGroupByFieldId(), + req.Req.GetExtraSearchParam().GetGroupSize(), + req.Req.GetMetricType())) } if err2 != nil {