mirror of https://github.com/milvus-io/milvus.git
related: #33544 Signed-off-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: MrPresent-Han <chun.han@gmail.com>pull/35730/head
parent
fe83805d56
commit
bfd9d86fe9
|
@ -96,6 +96,11 @@ message SubSearchRequest {
|
|||
string metricType = 9;
|
||||
}
|
||||
|
||||
message ExtraSearchParam {
|
||||
int64 group_by_field_id = 1;
|
||||
int64 group_size = 2;
|
||||
}
|
||||
|
||||
message SearchRequest {
|
||||
common.MsgBase base = 1;
|
||||
int64 reqID = 2;
|
||||
|
@ -120,6 +125,7 @@ message SearchRequest {
|
|||
bool is_advanced = 20;
|
||||
int64 offset = 21;
|
||||
common.ConsistencyLevel consistency_level = 22;
|
||||
ExtraSearchParam extra_search_param = 23;
|
||||
}
|
||||
|
||||
message SubSearchResults {
|
||||
|
|
|
@ -58,7 +58,8 @@ func reduceSearchResult(ctx context.Context, reduceInfo *reduceSearchResultInfo)
|
|||
reduceInfo.topK,
|
||||
reduceInfo.metricType,
|
||||
reduceInfo.pkType,
|
||||
reduceInfo.offset)
|
||||
reduceInfo.offset,
|
||||
reduceInfo.queryInfo.GroupSize)
|
||||
}
|
||||
return reduceSearchResultDataNoGroupBy(ctx,
|
||||
reduceInfo.subSearchResultData,
|
||||
|
@ -69,7 +70,21 @@ func reduceSearchResult(ctx context.Context, reduceInfo *reduceSearchResultInfo)
|
|||
reduceInfo.offset)
|
||||
}
|
||||
|
||||
func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64) (*milvuspb.SearchResults, error) {
|
||||
type MilvusPKType interface{}
|
||||
|
||||
type groupReduceInfo struct {
|
||||
subSearchIdx int
|
||||
resultIdx int64
|
||||
score float32
|
||||
id MilvusPKType
|
||||
}
|
||||
|
||||
func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
|
||||
nq int64, topk int64, metricType string,
|
||||
pkType schemapb.DataType,
|
||||
offset int64,
|
||||
groupSize int64,
|
||||
) (*milvuspb.SearchResults, error) {
|
||||
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
|
||||
defer func() {
|
||||
tr.CtxElapse(ctx, "done")
|
||||
|
@ -131,12 +146,14 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
|
|||
subSearchNum = len(subSearchResultData)
|
||||
// for results of each subSearchResultData, storing the start offset of each query of nq queries
|
||||
subSearchNqOffset = make([][]int64, subSearchNum)
|
||||
totalResCount int64 = 0
|
||||
)
|
||||
for i := 0; i < subSearchNum; i++ {
|
||||
subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
|
||||
for j := int64(1); j < nq; j++ {
|
||||
subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
|
||||
}
|
||||
totalResCount += subSearchNqOffset[i][nq-1]
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -146,6 +163,7 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
|
|||
|
||||
var retSize int64
|
||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||
groupBound := groupSize * limit
|
||||
|
||||
// reducing nq * topk results
|
||||
for i := int64(0); i < nq; i++ {
|
||||
|
@ -155,15 +173,14 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
|
|||
cursors = make([]int64, subSearchNum)
|
||||
|
||||
j int64
|
||||
idSet = make(map[interface{}]struct{})
|
||||
groupByValSet = make(map[interface{}]struct{})
|
||||
pkSet = make(map[interface{}]struct{})
|
||||
groupByValMap = make(map[interface{}][]*groupReduceInfo)
|
||||
skipOffsetMap = make(map[interface{}]bool)
|
||||
groupByValList = make([]interface{}, limit)
|
||||
groupByValIdx = 0
|
||||
)
|
||||
|
||||
// keep limit results
|
||||
for j = 0; j < limit; {
|
||||
// From all the sub-query result sets of the i-th query vector,
|
||||
// find the sub-query result set index of the score j-th data,
|
||||
// and the index of the data in schemapb.SearchResultData
|
||||
for j = 0; j < groupBound; {
|
||||
subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
|
||||
if subSearchIdx == -1 {
|
||||
break
|
||||
|
@ -171,44 +188,63 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
|
|||
subSearchRes := subSearchResultData[subSearchIdx]
|
||||
|
||||
id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
|
||||
score := subSearchRes.Scores[resultDataIdx]
|
||||
score := subSearchRes.GetScores()[resultDataIdx]
|
||||
groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))
|
||||
if groupByVal == nil {
|
||||
return nil, errors.New("get nil groupByVal from subSearchRes, wrong states, as milvus doesn't support nil value," +
|
||||
"there must be sth wrong on queryNode side")
|
||||
}
|
||||
|
||||
// remove duplicates
|
||||
if _, ok := idSet[id]; !ok {
|
||||
_, groupByValExist := groupByValSet[groupByVal]
|
||||
if !groupByValExist {
|
||||
groupByValSet[groupByVal] = struct{}{}
|
||||
if int64(len(groupByValSet)) <= offset {
|
||||
continue
|
||||
// skip offset groups
|
||||
if _, ok := pkSet[id]; !ok {
|
||||
if int64(len(skipOffsetMap)) < offset || skipOffsetMap[groupByVal] {
|
||||
skipOffsetMap[groupByVal] = true
|
||||
// the first offset's group will be ignored
|
||||
skipDupCnt++
|
||||
} else if len(groupByValMap[groupByVal]) == 0 && int64(len(groupByValMap)) >= limit {
|
||||
// skip when groupbyMap has been full and found new groupByVal
|
||||
skipDupCnt++
|
||||
} else if int64(len(groupByValMap[groupByVal])) >= groupSize {
|
||||
// skip when target group has been full
|
||||
skipDupCnt++
|
||||
} else {
|
||||
if len(groupByValMap[groupByVal]) == 0 {
|
||||
groupByValList[groupByValIdx] = groupByVal
|
||||
groupByValIdx++
|
||||
}
|
||||
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
|
||||
typeutil.AppendPKs(ret.Results.Ids, id)
|
||||
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||
idSet[id] = struct{}{}
|
||||
if err := typeutil.AppendGroupByValue(ret.Results, groupByVal, subSearchRes.GetGroupByFieldValue().GetType()); err != nil {
|
||||
groupByValMap[groupByVal] = append(groupByValMap[groupByVal], &groupReduceInfo{
|
||||
subSearchIdx: subSearchIdx,
|
||||
resultIdx: resultDataIdx, id: id, score: score,
|
||||
})
|
||||
pkSet[id] = struct{}{}
|
||||
j++
|
||||
}
|
||||
} else {
|
||||
skipDupCnt++
|
||||
}
|
||||
|
||||
cursors[subSearchIdx]++
|
||||
}
|
||||
|
||||
// assemble all eligible values in group
|
||||
// values in groupByValList is sorted by the highest score in each group
|
||||
for _, groupVal := range groupByValList {
|
||||
if groupVal != nil {
|
||||
groupEntities := groupByValMap[groupVal]
|
||||
for _, groupEntity := range groupEntities {
|
||||
subResData := subSearchResultData[groupEntity.subSearchIdx]
|
||||
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx)
|
||||
typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
|
||||
ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
|
||||
if err := typeutil.AppendGroupByValue(ret.Results, groupVal, subResData.GetGroupByFieldValue().GetType()); err != nil {
|
||||
log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
|
||||
return ret, err
|
||||
}
|
||||
j++
|
||||
} else {
|
||||
// skip entity with same groupby
|
||||
skipDupCnt++
|
||||
}
|
||||
} else {
|
||||
// skip entity with same id
|
||||
skipDupCnt++
|
||||
}
|
||||
cursors[subSearchIdx]++
|
||||
}
|
||||
|
||||
if realTopK != -1 && realTopK != j {
|
||||
log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
|
||||
// return nil, errors.New("the length (topk) between all result of query is different")
|
||||
}
|
||||
realTopK = j
|
||||
ret.Results.Topks = append(ret.Results.Topks, realTopK)
|
||||
|
@ -218,10 +254,13 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
|
|||
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
|
||||
}
|
||||
}
|
||||
log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
|
||||
log.Ctx(ctx).Debug("skip duplicated search result when doing group by", zap.Int64("count", skipDupCnt))
|
||||
|
||||
if skipDupCnt > 0 {
|
||||
log.Ctx(ctx).Info("skip duplicated search result", zap.Int64("count", skipDupCnt))
|
||||
if float64(skipDupCnt) >= float64(totalResCount)*0.3 {
|
||||
log.Warn("GroupBy reduce skipped too many results, "+
|
||||
"this may influence the final result seriously",
|
||||
zap.Int64("skipDupCnt", skipDupCnt),
|
||||
zap.Int64("groupBound", groupBound))
|
||||
}
|
||||
|
||||
ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
|
||||
|
|
|
@ -99,7 +99,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
searchParamStr = ""
|
||||
}
|
||||
|
||||
// 5. parse group by field
|
||||
// 5. parse group by field and group by size
|
||||
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
|
||||
if err != nil {
|
||||
groupByFieldName = ""
|
||||
|
@ -118,7 +118,18 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
}
|
||||
}
|
||||
|
||||
// 6. disable groupBy for iterator and range search
|
||||
var groupSize int64
|
||||
groupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupSizeKey, searchParamsPair)
|
||||
if err != nil {
|
||||
groupSize = 1
|
||||
} else {
|
||||
groupSize, err = strconv.ParseInt(groupSizeStr, 0, 64)
|
||||
if err != nil || groupSize <= 0 {
|
||||
groupSize = 1
|
||||
}
|
||||
}
|
||||
|
||||
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
|
||||
if isIterator == "True" && groupByFieldId > 0 {
|
||||
return nil, 0, merr.WrapErrParameterInvalid("", "",
|
||||
"Not allowed to do groupBy when doing iteration")
|
||||
|
@ -134,6 +145,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
SearchParams: searchParamStr,
|
||||
RoundDecimal: roundDecimal,
|
||||
GroupByFieldId: groupByFieldId,
|
||||
GroupSize: groupSize,
|
||||
}, offset, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -49,6 +49,7 @@ const (
|
|||
ReduceStopForBestKey = "reduce_stop_for_best"
|
||||
IteratorField = "iterator"
|
||||
GroupByFieldKey = "group_by_field"
|
||||
GroupSizeKey = "group_size"
|
||||
AnnsFieldKey = "anns_field"
|
||||
TopKKey = "topk"
|
||||
NQKey = "nq"
|
||||
|
|
|
@ -456,12 +456,12 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
|
||||
t.SearchRequest.Topk = queryInfo.GetTopk()
|
||||
t.SearchRequest.MetricType = queryInfo.GetMetricType()
|
||||
t.queryInfos = append(t.queryInfos, queryInfo)
|
||||
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
|
||||
t.SearchRequest.ExtraSearchParam = &internalpb.ExtraSearchParam{GroupByFieldId: queryInfo.GroupByFieldId, GroupSize: queryInfo.GroupSize}
|
||||
log.Debug("proxy init search request",
|
||||
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
||||
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
||||
|
|
|
@ -1706,6 +1706,7 @@ func TestTaskSearch_reduceGroupBySearchResultData(t *testing.T) {
|
|||
}
|
||||
queryInfo := &planpb.QueryInfo{
|
||||
GroupByFieldId: 1,
|
||||
GroupSize: 1,
|
||||
}
|
||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topK, metric.L2,
|
||||
schemapb.DataType_Int64, 0, queryInfo))
|
||||
|
@ -1765,6 +1766,7 @@ func TestTaskSearch_reduceGroupBySearchResultDataWithOffset(t *testing.T) {
|
|||
|
||||
queryInfo := &planpb.QueryInfo{
|
||||
GroupByFieldId: 1,
|
||||
GroupSize: 1,
|
||||
}
|
||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, limit+offset, metric.L2,
|
||||
schemapb.DataType_Int64, offset, queryInfo))
|
||||
|
@ -1777,6 +1779,82 @@ func TestTaskSearch_reduceGroupBySearchResultDataWithOffset(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestTaskSearch_reduceGroupBySearchWithGroupSizeMoreThanOne(t *testing.T) {
|
||||
var (
|
||||
nq int64 = 2
|
||||
topK int64 = 5
|
||||
)
|
||||
ids := [][]int64{
|
||||
{1, 3, 5, 7, 9, 1, 3, 5, 7, 9},
|
||||
{2, 4, 6, 8, 10, 2, 4, 6, 8, 10},
|
||||
}
|
||||
scores := [][]float32{
|
||||
{10, 8, 6, 4, 2, 10, 8, 6, 4, 2},
|
||||
{9, 7, 5, 3, 1, 9, 7, 5, 3, 1},
|
||||
}
|
||||
|
||||
groupByValuesArr := [][][]int64{
|
||||
{
|
||||
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
|
||||
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
|
||||
},
|
||||
{
|
||||
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
|
||||
{6, 8, 3, 4, 5, 6, 8, 3, 4, 5},
|
||||
},
|
||||
}
|
||||
expectedIDs := [][]int64{
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
{1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
|
||||
}
|
||||
expectedScores := [][]float32{
|
||||
{-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1},
|
||||
{-10, -9, -8, -7, -6, -5, -10, -9, -8, -7, -6, -5},
|
||||
}
|
||||
expectedGroupByValues := [][]int64{
|
||||
{1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5},
|
||||
{1, 6, 2, 8, 3, 3, 1, 6, 2, 8, 3, 3},
|
||||
}
|
||||
|
||||
for i, groupByValues := range groupByValuesArr {
|
||||
t.Run("Group By correctness", func(t *testing.T) {
|
||||
var results []*schemapb.SearchResultData
|
||||
for j := range ids {
|
||||
result := getSearchResultData(nq, topK)
|
||||
result.Ids.IdField = &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: ids[j]}}
|
||||
result.Scores = scores[j]
|
||||
result.Topks = []int64{topK, topK}
|
||||
result.GroupByFieldValue = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int64,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: groupByValues[j],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
queryInfo := &planpb.QueryInfo{
|
||||
GroupByFieldId: 1,
|
||||
GroupSize: 2,
|
||||
}
|
||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topK, metric.L2,
|
||||
schemapb.DataType_Int64, 0, queryInfo))
|
||||
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
|
||||
resultScores := reduced.GetResults().GetScores()
|
||||
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
|
||||
assert.EqualValues(t, expectedIDs[i], resultIDs)
|
||||
assert.EqualValues(t, expectedScores[i], resultScores)
|
||||
assert.EqualValues(t, expectedGroupByValues[i], resultGroupByValues)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchTask_ErrExecute(t *testing.T) {
|
||||
var (
|
||||
err error
|
||||
|
|
|
@ -351,9 +351,12 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
|
|||
|
||||
return segments.ReduceSearchResults(ctx,
|
||||
results,
|
||||
searchReq.Req.GetNq(),
|
||||
segments.NewReduceInfo(searchReq.Req.GetNq(),
|
||||
searchReq.Req.GetTopk(),
|
||||
searchReq.Req.GetMetricType())
|
||||
searchReq.Req.GetExtraSearchParam().GetGroupByFieldId(),
|
||||
searchReq.Req.GetExtraSearchParam().GetGroupSize(),
|
||||
searchReq.Req.GetMetricType()),
|
||||
)
|
||||
})
|
||||
futures[index] = future
|
||||
}
|
||||
|
|
|
@ -388,7 +388,11 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
|
|||
if req.GetReq().GetIsAdvanced() {
|
||||
resp, err = segments.ReduceAdvancedSearchResults(ctx, results, req.Req.GetNq())
|
||||
} else {
|
||||
resp, err = segments.ReduceSearchResults(ctx, results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
|
||||
resp, err = segments.ReduceSearchResults(ctx, results, segments.NewReduceInfo(req.Req.GetNq(),
|
||||
req.Req.GetTopk(),
|
||||
req.Req.GetExtraSearchParam().GetGroupByFieldId(),
|
||||
req.Req.GetExtraSearchParam().GetGroupSize(),
|
||||
req.Req.GetMetricType()))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -42,7 +42,7 @@ var _ typeutil.ResultWithID = &internalpb.RetrieveResults{}
|
|||
|
||||
var _ typeutil.ResultWithID = &segcorepb.RetrieveResults{}
|
||||
|
||||
func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResults, nq int64, topk int64, metricType string) (*internalpb.SearchResults, error) {
|
||||
func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResults, info *ReduceInfo) (*internalpb.SearchResults, error) {
|
||||
results = lo.Filter(results, func(result *internalpb.SearchResults, _ int) bool {
|
||||
return result != nil && result.GetSlicedBlob() != nil
|
||||
})
|
||||
|
@ -60,8 +60,8 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
|
|||
channelsMvcc[ch] = ts
|
||||
}
|
||||
// shouldn't let new SearchResults.MetricType to be empty, though the req.MetricType is empty
|
||||
if metricType == "" {
|
||||
metricType = r.MetricType
|
||||
if info.metricType == "" {
|
||||
info.metricType = r.MetricType
|
||||
}
|
||||
}
|
||||
log := log.Ctx(ctx)
|
||||
|
@ -80,12 +80,13 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
|
|||
zap.Int64("topk", sData.TopK))
|
||||
}
|
||||
|
||||
reducedResultData, err := ReduceSearchResultData(ctx, searchResultData, nq, topk)
|
||||
searchReduce := InitSearchReducer(info)
|
||||
reducedResultData, err := searchReduce.ReduceSearchResultData(ctx, searchResultData, info)
|
||||
if err != nil {
|
||||
log.Warn("shard leader reduce errors", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
searchResults, err := EncodeSearchResultData(ctx, reducedResultData, nq, topk, metricType)
|
||||
searchResults, err := EncodeSearchResultData(ctx, reducedResultData, info.nq, info.topK, info.metricType)
|
||||
if err != nil {
|
||||
log.Warn("shard leader encode search result errors", zap.Error(err))
|
||||
return nil, err
|
||||
|
@ -207,101 +208,6 @@ func MergeToAdvancedResults(ctx context.Context, results []*internalpb.SearchRes
|
|||
return searchResults, nil
|
||||
}
|
||||
|
||||
func ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, nq int64, topk int64) (*schemapb.SearchResultData, error) {
|
||||
ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData")
|
||||
defer sp.End()
|
||||
log := log.Ctx(ctx)
|
||||
|
||||
if len(searchResultData) == 0 {
|
||||
return &schemapb.SearchResultData{
|
||||
NumQueries: nq,
|
||||
TopK: topk,
|
||||
FieldsData: make([]*schemapb.FieldData, 0),
|
||||
Scores: make([]float32, 0),
|
||||
Ids: &schemapb.IDs{},
|
||||
Topks: make([]int64, 0),
|
||||
}, nil
|
||||
}
|
||||
ret := &schemapb.SearchResultData{
|
||||
NumQueries: nq,
|
||||
TopK: topk,
|
||||
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
|
||||
Scores: make([]float32, 0),
|
||||
Ids: &schemapb.IDs{},
|
||||
Topks: make([]int64, 0),
|
||||
}
|
||||
|
||||
resultOffsets := make([][]int64, len(searchResultData))
|
||||
for i := 0; i < len(searchResultData); i++ {
|
||||
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
|
||||
for j := int64(1); j < nq; j++ {
|
||||
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
|
||||
}
|
||||
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
|
||||
}
|
||||
|
||||
var skipDupCnt int64
|
||||
var retSize int64
|
||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||
for i := int64(0); i < nq; i++ {
|
||||
offsets := make([]int64, len(searchResultData))
|
||||
|
||||
idSet := make(map[interface{}]struct{})
|
||||
groupByValueSet := make(map[interface{}]struct{})
|
||||
var j int64
|
||||
for j = 0; j < topk; {
|
||||
sel := SelectSearchResultData(searchResultData, resultOffsets, offsets, i)
|
||||
if sel == -1 {
|
||||
break
|
||||
}
|
||||
idx := resultOffsets[sel][i] + offsets[sel]
|
||||
|
||||
id := typeutil.GetPK(searchResultData[sel].GetIds(), idx)
|
||||
groupByVal := typeutil.GetData(searchResultData[sel].GetGroupByFieldValue(), int(idx))
|
||||
score := searchResultData[sel].Scores[idx]
|
||||
|
||||
// remove duplicates
|
||||
if _, ok := idSet[id]; !ok {
|
||||
groupByValExist := false
|
||||
if groupByVal != nil {
|
||||
_, groupByValExist = groupByValueSet[groupByVal]
|
||||
}
|
||||
if !groupByValExist {
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
|
||||
typeutil.AppendPKs(ret.Ids, id)
|
||||
ret.Scores = append(ret.Scores, score)
|
||||
if groupByVal != nil {
|
||||
groupByValueSet[groupByVal] = struct{}{}
|
||||
if err := typeutil.AppendGroupByValue(ret, groupByVal, searchResultData[sel].GetGroupByFieldValue().GetType()); err != nil {
|
||||
log.Error("Failed to append groupByValues", zap.Error(err))
|
||||
return ret, err
|
||||
}
|
||||
}
|
||||
idSet[id] = struct{}{}
|
||||
j++
|
||||
}
|
||||
} else {
|
||||
// skip entity with same id
|
||||
skipDupCnt++
|
||||
}
|
||||
offsets[sel]++
|
||||
}
|
||||
|
||||
// if realTopK != -1 && realTopK != j {
|
||||
// log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
|
||||
// // return nil, errors.New("the length (topk) between all result of query is different")
|
||||
// }
|
||||
ret.Topks = append(ret.Topks, j)
|
||||
|
||||
// limit search result to avoid oom
|
||||
if retSize > maxOutputSize {
|
||||
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
|
||||
}
|
||||
}
|
||||
log.Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func SelectSearchResultData(dataArray []*schemapb.SearchResultData, resultOffsets [][]int64, offsets []int64, qi int64) int {
|
||||
var (
|
||||
sel = -1
|
||||
|
|
|
@ -702,177 +702,6 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
|
|||
})
|
||||
}
|
||||
|
||||
func (suite *ResultSuite) TestResult_ReduceSearchResultData() {
|
||||
const (
|
||||
nq = 1
|
||||
topk = 4
|
||||
metricType = "L2"
|
||||
)
|
||||
suite.Run("case1", func() {
|
||||
ids := []int64{1, 2, 3, 4}
|
||||
scores := []float32{-1.0, -2.0, -3.0, -4.0}
|
||||
topks := []int64{int64(len(ids))}
|
||||
data1 := genSearchResultData(nq, topk, ids, scores, topks)
|
||||
data2 := genSearchResultData(nq, topk, ids, scores, topks)
|
||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
res, err := ReduceSearchResultData(context.TODO(), dataArray, nq, topk)
|
||||
suite.Nil(err)
|
||||
suite.Equal(ids, res.Ids.GetIntId().Data)
|
||||
suite.Equal(scores, res.Scores)
|
||||
})
|
||||
suite.Run("case2", func() {
|
||||
ids1 := []int64{1, 2, 3, 4}
|
||||
scores1 := []float32{-1.0, -2.0, -3.0, -4.0}
|
||||
topks1 := []int64{int64(len(ids1))}
|
||||
ids2 := []int64{5, 1, 3, 4}
|
||||
scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
|
||||
topks2 := []int64{int64(len(ids2))}
|
||||
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1)
|
||||
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2)
|
||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
res, err := ReduceSearchResultData(context.TODO(), dataArray, nq, topk)
|
||||
suite.Nil(err)
|
||||
suite.ElementsMatch([]int64{1, 5, 2, 3}, res.Ids.GetIntId().Data)
|
||||
})
|
||||
}
|
||||
|
||||
func (suite *ResultSuite) TestResult_SearchGroupByResult() {
|
||||
const (
|
||||
nq = 1
|
||||
topk = 4
|
||||
)
|
||||
suite.Run("reduce_group_by_int", func() {
|
||||
ids1 := []int64{1, 2, 3, 4}
|
||||
scores1 := []float32{-1.0, -2.0, -3.0, -4.0}
|
||||
topks1 := []int64{int64(len(ids1))}
|
||||
ids2 := []int64{5, 1, 3, 4}
|
||||
scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
|
||||
topks2 := []int64{int64(len(ids2))}
|
||||
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1)
|
||||
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2)
|
||||
data1.GroupByFieldValue = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int8,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: []int32{2, 3, 4, 5},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data2.GroupByFieldValue = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int8,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: []int32{2, 3, 4, 5},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
res, err := ReduceSearchResultData(context.TODO(), dataArray, nq, topk)
|
||||
suite.Nil(err)
|
||||
suite.ElementsMatch([]int64{1, 2, 3, 4}, res.Ids.GetIntId().Data)
|
||||
suite.ElementsMatch([]float32{-1.0, -2.0, -3.0, -4.0}, res.Scores)
|
||||
suite.ElementsMatch([]int32{2, 3, 4, 5}, res.GroupByFieldValue.GetScalars().GetIntData().Data)
|
||||
})
|
||||
suite.Run("reduce_group_by_bool", func() {
|
||||
ids1 := []int64{1, 2}
|
||||
scores1 := []float32{-1.0, -2.0}
|
||||
topks1 := []int64{int64(len(ids1))}
|
||||
ids2 := []int64{3, 4}
|
||||
scores2 := []float32{-1.0, -1.0}
|
||||
topks2 := []int64{int64(len(ids2))}
|
||||
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1)
|
||||
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2)
|
||||
data1.GroupByFieldValue = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Bool,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: []bool{true, false},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data2.GroupByFieldValue = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Bool,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: []bool{true, false},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
res, err := ReduceSearchResultData(context.TODO(), dataArray, nq, topk)
|
||||
suite.Nil(err)
|
||||
suite.ElementsMatch([]int64{1, 4}, res.Ids.GetIntId().Data)
|
||||
suite.ElementsMatch([]float32{-1.0, -1.0}, res.Scores)
|
||||
suite.ElementsMatch([]bool{true, false}, res.GroupByFieldValue.GetScalars().GetBoolData().Data)
|
||||
})
|
||||
suite.Run("reduce_group_by_string", func() {
|
||||
ids1 := []int64{1, 2, 3, 4}
|
||||
scores1 := []float32{-1.0, -2.0, -3.0, -4.0}
|
||||
topks1 := []int64{int64(len(ids1))}
|
||||
ids2 := []int64{5, 1, 3, 4}
|
||||
scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
|
||||
topks2 := []int64{int64(len(ids2))}
|
||||
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1)
|
||||
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2)
|
||||
data1.GroupByFieldValue = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{"1", "2", "3", "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data2.GroupByFieldValue = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{"1", "2", "3", "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
res, err := ReduceSearchResultData(context.TODO(), dataArray, nq, topk)
|
||||
suite.Nil(err)
|
||||
suite.ElementsMatch([]int64{1, 2, 3, 4}, res.Ids.GetIntId().Data)
|
||||
suite.ElementsMatch([]float32{-1.0, -2.0, -3.0, -4.0}, res.Scores)
|
||||
suite.ElementsMatch([]string{"1", "2", "3", "4"}, res.GroupByFieldValue.GetScalars().GetStringData().Data)
|
||||
})
|
||||
}
|
||||
|
||||
func (suite *ResultSuite) TestResult_SelectSearchResultData_int() {
|
||||
type args struct {
|
||||
dataArray []*schemapb.SearchResultData
|
||||
|
|
|
@ -0,0 +1,227 @@
|
|||
package segments
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type ReduceInfo struct {
|
||||
nq int64
|
||||
topK int64
|
||||
groupByFieldID int64
|
||||
groupSize int64
|
||||
metricType string
|
||||
}
|
||||
|
||||
func NewReduceInfo(nq int64, topK int64, groupByFieldID int64, groupSize int64, metric string) *ReduceInfo {
|
||||
return &ReduceInfo{nq, topK, groupByFieldID, groupSize, metric}
|
||||
}
|
||||
|
||||
type SearchReduce interface {
|
||||
ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error)
|
||||
}
|
||||
|
||||
type SearchCommonReduce struct{}
|
||||
|
||||
func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error) {
|
||||
ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData")
|
||||
defer sp.End()
|
||||
log := log.Ctx(ctx)
|
||||
|
||||
if len(searchResultData) == 0 {
|
||||
return &schemapb.SearchResultData{
|
||||
NumQueries: info.nq,
|
||||
TopK: info.topK,
|
||||
FieldsData: make([]*schemapb.FieldData, 0),
|
||||
Scores: make([]float32, 0),
|
||||
Ids: &schemapb.IDs{},
|
||||
Topks: make([]int64, 0),
|
||||
}, nil
|
||||
}
|
||||
ret := &schemapb.SearchResultData{
|
||||
NumQueries: info.nq,
|
||||
TopK: info.topK,
|
||||
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
|
||||
Scores: make([]float32, 0),
|
||||
Ids: &schemapb.IDs{},
|
||||
Topks: make([]int64, 0),
|
||||
}
|
||||
|
||||
resultOffsets := make([][]int64, len(searchResultData))
|
||||
for i := 0; i < len(searchResultData); i++ {
|
||||
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
|
||||
for j := int64(1); j < info.nq; j++ {
|
||||
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
|
||||
}
|
||||
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
|
||||
}
|
||||
|
||||
var skipDupCnt int64
|
||||
var retSize int64
|
||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||
for i := int64(0); i < info.nq; i++ {
|
||||
offsets := make([]int64, len(searchResultData))
|
||||
idSet := make(map[interface{}]struct{})
|
||||
var j int64
|
||||
for j = 0; j < info.topK; {
|
||||
sel := SelectSearchResultData(searchResultData, resultOffsets, offsets, i)
|
||||
if sel == -1 {
|
||||
break
|
||||
}
|
||||
idx := resultOffsets[sel][i] + offsets[sel]
|
||||
|
||||
id := typeutil.GetPK(searchResultData[sel].GetIds(), idx)
|
||||
score := searchResultData[sel].Scores[idx]
|
||||
|
||||
// remove duplicates
|
||||
if _, ok := idSet[id]; !ok {
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
|
||||
typeutil.AppendPKs(ret.Ids, id)
|
||||
ret.Scores = append(ret.Scores, score)
|
||||
idSet[id] = struct{}{}
|
||||
j++
|
||||
} else {
|
||||
// skip entity with same id
|
||||
skipDupCnt++
|
||||
}
|
||||
offsets[sel]++
|
||||
}
|
||||
|
||||
// if realTopK != -1 && realTopK != j {
|
||||
// log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
|
||||
// // return nil, errors.New("the length (topk) between all result of query is different")
|
||||
// }
|
||||
ret.Topks = append(ret.Topks, j)
|
||||
|
||||
// limit search result to avoid oom
|
||||
if retSize > maxOutputSize {
|
||||
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
|
||||
}
|
||||
}
|
||||
log.Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
type SearchGroupByReduce struct{}
|
||||
|
||||
func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error) {
|
||||
ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData")
|
||||
defer sp.End()
|
||||
log := log.Ctx(ctx)
|
||||
|
||||
if len(searchResultData) == 0 {
|
||||
return &schemapb.SearchResultData{
|
||||
NumQueries: info.nq,
|
||||
TopK: info.topK,
|
||||
FieldsData: make([]*schemapb.FieldData, 0),
|
||||
Scores: make([]float32, 0),
|
||||
Ids: &schemapb.IDs{},
|
||||
Topks: make([]int64, 0),
|
||||
}, nil
|
||||
}
|
||||
ret := &schemapb.SearchResultData{
|
||||
NumQueries: info.nq,
|
||||
TopK: info.topK,
|
||||
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
|
||||
Scores: make([]float32, 0),
|
||||
Ids: &schemapb.IDs{},
|
||||
Topks: make([]int64, 0),
|
||||
}
|
||||
|
||||
resultOffsets := make([][]int64, len(searchResultData))
|
||||
for i := 0; i < len(searchResultData); i++ {
|
||||
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
|
||||
for j := int64(1); j < info.nq; j++ {
|
||||
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
|
||||
}
|
||||
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
|
||||
}
|
||||
|
||||
var filteredCount int64
|
||||
var retSize int64
|
||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||
groupSize := info.groupSize
|
||||
if groupSize <= 0 {
|
||||
groupSize = 1
|
||||
}
|
||||
groupBound := info.topK * groupSize
|
||||
|
||||
for i := int64(0); i < info.nq; i++ {
|
||||
offsets := make([]int64, len(searchResultData))
|
||||
|
||||
idSet := make(map[interface{}]struct{})
|
||||
groupByValueMap := make(map[interface{}]int64)
|
||||
|
||||
var j int64
|
||||
for j = 0; j < groupBound; {
|
||||
sel := SelectSearchResultData(searchResultData, resultOffsets, offsets, i)
|
||||
if sel == -1 {
|
||||
break
|
||||
}
|
||||
idx := resultOffsets[sel][i] + offsets[sel]
|
||||
|
||||
id := typeutil.GetPK(searchResultData[sel].GetIds(), idx)
|
||||
groupByVal := typeutil.GetData(searchResultData[sel].GetGroupByFieldValue(), int(idx))
|
||||
score := searchResultData[sel].Scores[idx]
|
||||
if _, ok := idSet[id]; !ok {
|
||||
if groupByVal == nil {
|
||||
return ret, merr.WrapErrParameterMissing("GroupByVal returned from segment cannot be null")
|
||||
}
|
||||
|
||||
groupCount := groupByValueMap[groupByVal]
|
||||
if groupCount == 0 && int64(len(groupByValueMap)) >= info.topK {
|
||||
// exceed the limit for group count, filter this entity
|
||||
filteredCount++
|
||||
} else if groupCount >= groupSize {
|
||||
// exceed the limit for each group, filter this entity
|
||||
filteredCount++
|
||||
} else {
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
|
||||
typeutil.AppendPKs(ret.Ids, id)
|
||||
ret.Scores = append(ret.Scores, score)
|
||||
if err := typeutil.AppendGroupByValue(ret, groupByVal, searchResultData[sel].GetGroupByFieldValue().GetType()); err != nil {
|
||||
log.Error("Failed to append groupByValues", zap.Error(err))
|
||||
return ret, err
|
||||
}
|
||||
groupByValueMap[groupByVal] += 1
|
||||
idSet[id] = struct{}{}
|
||||
j++
|
||||
}
|
||||
} else {
|
||||
// skip entity with same pk
|
||||
filteredCount++
|
||||
}
|
||||
offsets[sel]++
|
||||
}
|
||||
ret.Topks = append(ret.Topks, j)
|
||||
|
||||
// limit search result to avoid oom
|
||||
if retSize > maxOutputSize {
|
||||
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
|
||||
}
|
||||
}
|
||||
if float64(filteredCount) >= 0.3*float64(groupBound) {
|
||||
log.Warn("GroupBy reduce filtered too many results, "+
|
||||
"this may influence the final result seriously",
|
||||
zap.Int64("filteredCount", filteredCount),
|
||||
zap.Int64("groupBound", groupBound))
|
||||
}
|
||||
log.Debug("skip duplicated search result", zap.Int64("count", filteredCount))
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func InitSearchReducer(info *ReduceInfo) SearchReduce {
|
||||
if info.groupByFieldID > 0 {
|
||||
return &SearchGroupByReduce{}
|
||||
}
|
||||
return &SearchCommonReduce{}
|
||||
}
|
|
@ -0,0 +1,258 @@
|
|||
package segments
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
type SearchReduceSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (suite *SearchReduceSuite) TestResult_ReduceSearchResultData() {
|
||||
const (
|
||||
nq = 1
|
||||
topk = 4
|
||||
)
|
||||
suite.Run("case1", func() {
|
||||
ids := []int64{1, 2, 3, 4}
|
||||
scores := []float32{-1.0, -2.0, -3.0, -4.0}
|
||||
topks := []int64{int64(len(ids))}
|
||||
data1 := genSearchResultData(nq, topk, ids, scores, topks)
|
||||
data2 := genSearchResultData(nq, topk, ids, scores, topks)
|
||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk}
|
||||
searchReduce := InitSearchReducer(reduceInfo)
|
||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||
suite.Nil(err)
|
||||
suite.Equal(ids, res.Ids.GetIntId().Data)
|
||||
suite.Equal(scores, res.Scores)
|
||||
})
|
||||
suite.Run("case2", func() {
|
||||
ids1 := []int64{1, 2, 3, 4}
|
||||
scores1 := []float32{-1.0, -2.0, -3.0, -4.0}
|
||||
topks1 := []int64{int64(len(ids1))}
|
||||
ids2 := []int64{5, 1, 3, 4}
|
||||
scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
|
||||
topks2 := []int64{int64(len(ids2))}
|
||||
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1)
|
||||
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2)
|
||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk}
|
||||
searchReduce := InitSearchReducer(reduceInfo)
|
||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||
suite.Nil(err)
|
||||
suite.ElementsMatch([]int64{1, 5, 2, 3}, res.Ids.GetIntId().Data)
|
||||
})
|
||||
}
|
||||
|
||||
func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
|
||||
const (
|
||||
nq = 1
|
||||
topk = 4
|
||||
)
|
||||
suite.Run("reduce_group_by_int", func() {
|
||||
ids1 := []int64{1, 2, 3, 4}
|
||||
scores1 := []float32{-1.0, -2.0, -3.0, -4.0}
|
||||
topks1 := []int64{int64(len(ids1))}
|
||||
ids2 := []int64{5, 1, 3, 4}
|
||||
scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
|
||||
topks2 := []int64{int64(len(ids2))}
|
||||
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1)
|
||||
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2)
|
||||
data1.GroupByFieldValue = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int8,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: []int32{2, 3, 4, 5},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data2.GroupByFieldValue = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int8,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: []int32{2, 3, 4, 5},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101}
|
||||
searchReduce := InitSearchReducer(reduceInfo)
|
||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||
suite.Nil(err)
|
||||
suite.ElementsMatch([]int64{1, 2, 3, 4}, res.Ids.GetIntId().Data)
|
||||
suite.ElementsMatch([]float32{-1.0, -2.0, -3.0, -4.0}, res.Scores)
|
||||
suite.ElementsMatch([]int32{2, 3, 4, 5}, res.GroupByFieldValue.GetScalars().GetIntData().Data)
|
||||
})
|
||||
suite.Run("reduce_group_by_bool", func() {
|
||||
ids1 := []int64{1, 2}
|
||||
scores1 := []float32{-1.0, -2.0}
|
||||
topks1 := []int64{int64(len(ids1))}
|
||||
ids2 := []int64{3, 4}
|
||||
scores2 := []float32{-1.0, -1.0}
|
||||
topks2 := []int64{int64(len(ids2))}
|
||||
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1)
|
||||
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2)
|
||||
data1.GroupByFieldValue = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Bool,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: []bool{true, false},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data2.GroupByFieldValue = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Bool,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: []bool{true, false},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101}
|
||||
searchReduce := InitSearchReducer(reduceInfo)
|
||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||
suite.Nil(err)
|
||||
suite.ElementsMatch([]int64{1, 4}, res.Ids.GetIntId().Data)
|
||||
suite.ElementsMatch([]float32{-1.0, -1.0}, res.Scores)
|
||||
suite.ElementsMatch([]bool{true, false}, res.GroupByFieldValue.GetScalars().GetBoolData().Data)
|
||||
})
|
||||
suite.Run("reduce_group_by_string", func() {
|
||||
ids1 := []int64{1, 2, 3, 4}
|
||||
scores1 := []float32{-1.0, -2.0, -3.0, -4.0}
|
||||
topks1 := []int64{int64(len(ids1))}
|
||||
ids2 := []int64{5, 1, 3, 4}
|
||||
scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
|
||||
topks2 := []int64{int64(len(ids2))}
|
||||
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1)
|
||||
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2)
|
||||
data1.GroupByFieldValue = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{"1", "2", "3", "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data2.GroupByFieldValue = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{"1", "2", "3", "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101}
|
||||
searchReduce := InitSearchReducer(reduceInfo)
|
||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||
suite.Nil(err)
|
||||
suite.ElementsMatch([]int64{1, 2, 3, 4}, res.Ids.GetIntId().Data)
|
||||
suite.ElementsMatch([]float32{-1.0, -2.0, -3.0, -4.0}, res.Scores)
|
||||
suite.ElementsMatch([]string{"1", "2", "3", "4"}, res.GroupByFieldValue.GetScalars().GetStringData().Data)
|
||||
})
|
||||
suite.Run("reduce_group_by_string_with_group_size", func() {
|
||||
ids1 := []int64{1, 2, 3, 4}
|
||||
scores1 := []float32{-1.0, -2.0, -3.0, -4.0}
|
||||
topks1 := []int64{int64(len(ids1))}
|
||||
ids2 := []int64{4, 5, 6, 7}
|
||||
scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
|
||||
topks2 := []int64{int64(len(ids2))}
|
||||
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1)
|
||||
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2)
|
||||
data1.GroupByFieldValue = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{"1", "2", "3", "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data2.GroupByFieldValue = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{"1", "2", "3", "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101, groupSize: 3}
|
||||
searchReduce := InitSearchReducer(reduceInfo)
|
||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||
suite.Nil(err)
|
||||
suite.ElementsMatch([]int64{1, 4, 2, 5, 3, 6, 7}, res.Ids.GetIntId().Data)
|
||||
suite.ElementsMatch([]float32{-1.0, -1.0, -1.0, -2.0, -3.0, -3.0, -4.0}, res.Scores)
|
||||
suite.ElementsMatch([]string{"1", "1", "2", "2", "3", "3", "4"}, res.GroupByFieldValue.GetScalars().GetStringData().Data)
|
||||
})
|
||||
|
||||
suite.Run("reduce_group_by_empty_input", func() {
|
||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101, groupSize: 3}
|
||||
searchReduce := InitSearchReducer(reduceInfo)
|
||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||
suite.Nil(err)
|
||||
suite.Nil(res.GetIds().GetIdField())
|
||||
suite.Equal(0, len(res.GetTopks()))
|
||||
suite.Equal(0, len(res.GetScores()))
|
||||
suite.Equal(int64(nq), res.GetNumQueries())
|
||||
suite.Equal(int64(topk), res.GetTopK())
|
||||
suite.Equal(0, len(res.GetFieldsData()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestSearchReduce(t *testing.T) {
|
||||
paramtable.Init()
|
||||
suite.Run(t, new(SearchReduceSuite))
|
||||
}
|
|
@ -790,7 +790,11 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
|
|||
if req.GetReq().GetIsAdvanced() {
|
||||
result, err2 = segments.ReduceAdvancedSearchResults(ctx, toReduceResults, req.Req.GetNq())
|
||||
} else {
|
||||
result, err2 = segments.ReduceSearchResults(ctx, toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
|
||||
result, err2 = segments.ReduceSearchResults(ctx, toReduceResults, segments.NewReduceInfo(req.Req.GetNq(),
|
||||
req.Req.GetTopk(),
|
||||
req.Req.GetExtraSearchParam().GetGroupByFieldId(),
|
||||
req.Req.GetExtraSearchParam().GetGroupSize(),
|
||||
req.Req.GetMetricType()))
|
||||
}
|
||||
|
||||
if err2 != nil {
|
||||
|
|
Loading…
Reference in New Issue