mirror of https://github.com/milvus-io/milvus.git

enhance: simplify reduction on single search result (#36334)

See: #36122

Signed-off-by: Ted Xu <ted.xu@zilliz.com>

parent 89397d1e66
commit 363004fd44
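For context, the heart of the change in both reduce paths: when the proxy gets exactly one sub-search result and no offset, the nq*topk merge loop (and the per-PK dedup bookkeeping it carried) can be skipped and the sub-result handed back directly. A minimal standalone sketch of the idea, using a simplified stand-in type rather than the real schemapb.SearchResultData and signatures:

package main

import "fmt"

// searchResult is a hypothetical, trimmed-down stand-in for
// schemapb.SearchResultData: just enough fields to show the fast path.
type searchResult struct {
	Topks  []int64   // per-query result counts
	Scores []float32 // flattened nq*topk scores
}

// reduceResults sketches the shortcut: a single sub-result with no offset
// needs no k-way merge, it is already the final ordering for its shard.
func reduceResults(sub []*searchResult, offset int64) *searchResult {
	if len(sub) == 1 && offset == 0 {
		return sub[0] // assign directly, skip the merge loop entirely
	}
	// Fallback: naive concatenation standing in for the real cursor-based
	// k-way merge over every sub-result (omitted for brevity).
	merged := &searchResult{}
	for _, s := range sub {
		merged.Topks = append(merged.Topks, s.Topks...)
		merged.Scores = append(merged.Scores, s.Scores...)
	}
	return merged
}

func main() {
	single := []*searchResult{{Topks: []int64{3}, Scores: []float32{0.9, 0.7, 0.1}}}
	fmt.Println(reduceResults(single, 0).Scores) // [0.9 0.7 0.1]
}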
@@ -218,14 +218,13 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
 		totalResCount += subSearchNqOffset[i][nq-1]
 	}
 
-	var (
-		skipDupCnt int64
-		realTopK   int64 = -1
-	)
-
+	if subSearchNum == 1 && offset == 0 {
+		ret.Results = subSearchResultData[0]
+	} else {
+		var realTopK int64 = -1
 		var retSize int64
-	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
 
+		maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
 		// reducing nq * topk results
 		for i := int64(0); i < nq; i++ {
 			var (

@@ -234,7 +233,6 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
 				cursors = make([]int64, subSearchNum)
 
 				j              int64
-			pkSet          = make(map[interface{}]struct{})
 				groupByValMap  = make(map[interface{}][]*groupReduceInfo)
 				skipOffsetMap  = make(map[interface{}]bool)
 				groupByValList = make([]interface{}, limit)

@@ -255,18 +253,14 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
 					return nil, errors.New("get nil groupByVal from subSearchRes, wrong states, as milvus doesn't support nil value," +
 						"there must be sth wrong on queryNode side")
 				}
-			// remove duplicates
-			if _, ok := pkSet[id]; !ok {
+
 				if int64(len(skipOffsetMap)) < offset || skipOffsetMap[groupByVal] {
 					skipOffsetMap[groupByVal] = true
 					// the first offset's group will be ignored
-					skipDupCnt++
 				} else if len(groupByValMap[groupByVal]) == 0 && int64(len(groupByValMap)) >= limit {
 					// skip when groupbyMap has been full and found new groupByVal
-					skipDupCnt++
 				} else if int64(len(groupByValMap[groupByVal])) >= groupSize {
 					// skip when target group has been full
-					skipDupCnt++
 				} else {
 					if len(groupByValMap[groupByVal]) == 0 {
 						groupByValList[groupByValIdx] = groupByVal

@@ -276,12 +270,8 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
 						subSearchIdx: subSearchIdx,
 						resultIdx:    resultDataIdx, id: id, score: score,
 					})
-					pkSet[id] = struct{}{}
 					j++
 				}
-			} else {
-				skipDupCnt++
-			}
 
 				cursors[subSearchIdx]++
 			}

@@ -315,16 +305,8 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
 				return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
 			}
 		}
-	log.Ctx(ctx).Debug("skip duplicated search result when doing group by", zap.Int64("count", skipDupCnt))
-
-	if float64(skipDupCnt) >= float64(totalResCount)*0.3 {
-		log.Warn("GroupBy reduce skipped too many results, "+
-			"this may influence the final result seriously",
-			zap.Int64("skipDupCnt", skipDupCnt),
-			zap.Int64("groupBound", groupBound))
-	}
-
 		ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
+	}
 	if !metric.PositivelyRelated(metricType) {
 		for k := range ret.Results.Scores {
 			ret.Results.Scores[k] *= -1

@@ -370,35 +352,36 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
 		ret.GetResults().AllSearchCount = allSearchCount
 	}
 
-	var (
-		subSearchNum = len(subSearchResultData)
+	subSearchNum := len(subSearchResultData)
+	if subSearchNum == 1 && offset == 0 {
+		// sorting is not needed if there is only one shard and no offset, assigning the result directly.
+		//  we still need to adjust the scores later.
+		ret.Results = subSearchResultData[0]
+		// realTopK is the topK of the nq-th query, it is used in proxy but not handled by delegator.
+		topks := subSearchResultData[0].Topks
+		if len(topks) > 0 {
+			ret.Results.TopK = topks[len(topks)-1]
+		}
+	} else {
+		var realTopK int64 = -1
+		var retSize int64
+
 		// for results of each subSearchResultData, storing the start offset of each query of nq queries
-		subSearchNqOffset = make([][]int64, subSearchNum)
-	)
+		subSearchNqOffset := make([][]int64, subSearchNum)
 		for i := 0; i < subSearchNum; i++ {
 			subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
 			for j := int64(1); j < nq; j++ {
 				subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
 			}
 		}
-
-	var (
-		skipDupCnt int64
-		realTopK   int64 = -1
-	)
-
-	var retSize int64
 		maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
-
 		// reducing nq * topk results
 		for i := int64(0); i < nq; i++ {
 			var (
 				// cursor of current data of each subSearch for merging the j-th data of TopK.
 				// sum(cursors) == j
 				cursors = make([]int64, subSearchNum)
-
 				j       int64
-			idSet = make(map[interface{}]struct{}, limit)
 			)
 
 			// skip offset results

@@ -412,7 +395,7 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
 			}
 
 			// keep limit results
-			for j = 0; j < limit; {
+			for j = 0; j < limit; j++ {
 				// From all the sub-query result sets of the i-th query vector,
 				//   find the sub-query result set index of the score j-th data,
 				//   and the index of the data in schemapb.SearchResultData

@@ -420,20 +403,11 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
 				if subSearchIdx == -1 {
 					break
 				}
-			id := typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)
 				score := subSearchResultData[subSearchIdx].Scores[resultDataIdx]
 
-			// remove duplicatessds
-			if _, ok := idSet[id]; !ok {
 				retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
-				typeutil.AppendPKs(ret.Results.Ids, id)
+				typeutil.CopyPk(ret.Results.Ids, subSearchResultData[subSearchIdx].GetIds(), int(resultDataIdx))
 				ret.Results.Scores = append(ret.Results.Scores, score)
-				idSet[id] = struct{}{}
-				j++
-			} else {
-				// skip entity with same id
-				skipDupCnt++
-			}
 				cursors[subSearchIdx]++
 			}
 			if realTopK != -1 && realTopK != j {

@@ -448,13 +422,9 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
 				return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
 			}
 		}
-	log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
-
-	if skipDupCnt > 0 {
-		log.Info("skip duplicated search result", zap.Int64("count", skipDupCnt))
+		ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
 	}
 
-	ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
 	if !metric.PositivelyRelated(metricType) {
 		for k := range ret.Results.Scores {
 			ret.Results.Scores[k] *= -1
@@ -14,7 +14,7 @@ type SearchReduceUtilTestSuite struct {
 	suite.Suite
 }
 
-func (struts *SearchReduceUtilTestSuite) TestRankByGroup() {
+func genTestDataSearchResultsData() []*schemapb.SearchResultData {
 	var searchResultData1 *schemapb.SearchResultData
 	var searchResultData2 *schemapb.SearchResultData
 

@@ -49,10 +49,14 @@ func (struts *SearchReduceUtilTestSuite) TestRankByGroup() {
 			GroupByFieldValue: getFieldData("string", int64(101), schemapb.DataType_VarChar, groupFieldValue, 1),
 		}
 	}
+	return []*schemapb.SearchResultData{searchResultData1, searchResultData2}
+}
 
+func (struts *SearchReduceUtilTestSuite) TestRankByGroup() {
+	data := genTestDataSearchResultsData()
 	searchResults := []*milvuspb.SearchResults{
-		{Results: searchResultData1},
-		{Results: searchResultData2},
+		{Results: data[0]},
+		{Results: data[1]},
 	}
 
 	nq := int64(1)

@@ -128,6 +132,16 @@ func (struts *SearchReduceUtilTestSuite) TestRankByGroup() {
 	}
 }
 
+func (struts *SearchReduceUtilTestSuite) TestReduceSearchResult() {
+	data := genTestDataSearchResultsData()
+
+	{
+		results, err := reduceSearchResultDataNoGroupBy(context.Background(), []*schemapb.SearchResultData{data[0]}, 0, 0, "L2", schemapb.DataType_Int64, 0)
+		struts.NoError(err)
+		struts.Equal([]string{"7", "5", "4", "2", "3", "6", "1", "9", "8"}, results.Results.GetIds().GetStrId().Data)
+	}
+}
+
 func TestSearchReduceUtilTestSuite(t *testing.T) {
 	suite.Run(t, new(SearchReduceUtilTestSuite))
 }
@@ -592,8 +592,6 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
 	log.Ctx(ctx).Debug("reduceInternalRetrieveResults", zap.Int("len(retrieveResults)", len(retrieveResults)))
 	var (
 		ret     = &milvuspb.QueryResults{}
-
-		skipDupCnt int64
 		loopEnd int
 	)
 

@@ -611,7 +609,6 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
 		return ret, nil
 	}
 
-	idSet := make(map[interface{}]struct{})
 	cursors := make([]int64, len(validRetrieveResults))
 
 	if queryParams != nil && queryParams.limit != typeutil.Unlimited {

@@ -636,21 +633,12 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
 	ret.FieldsData = typeutil.PrepareResultFieldData(validRetrieveResults[0].GetFieldsData(), int64(loopEnd))
 	var retSize int64
 	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
-	for j := 0; j < loopEnd; {
+	for j := 0; j < loopEnd; j++ {
 		sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors)
 		if sel == -1 || (queryParams.reduceStopForBest && drainOneResult) {
 			break
 		}
-
-		pk := typeutil.GetPK(validRetrieveResults[sel].GetIds(), cursors[sel])
-		if _, ok := idSet[pk]; !ok {
 		retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel])
-			idSet[pk] = struct{}{}
-			j++
-		} else {
-			// primary keys duplicate
-			skipDupCnt++
-		}
 
 		// limit retrieve result to avoid oom
 		if retSize > maxOutputSize {

@@ -660,10 +648,6 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
 		cursors[sel]++
 	}
 
-	if skipDupCnt > 0 {
-		log.Ctx(ctx).Debug("skip duplicated query result while reducing QueryResults", zap.Int64("count", skipDupCnt))
-	}
-
 	return ret, nil
 }
 
@@ -461,34 +461,6 @@ func TestTaskQuery_functions(t *testing.T) {
 		fieldDataArray2 = append(fieldDataArray2, getFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1))
 		fieldDataArray2 = append(fieldDataArray2, getFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))
 
-		t.Run("test skip dupPK 2", func(t *testing.T) {
-			result1 := &internalpb.RetrieveResults{
-				Ids: &schemapb.IDs{
-					IdField: &schemapb.IDs_IntId{
-						IntId: &schemapb.LongArray{
-							Data: []int64{0, 1},
-						},
-					},
-				},
-				FieldsData: fieldDataArray1,
-			}
-			result2 := &internalpb.RetrieveResults{
-				Ids: &schemapb.IDs{
-					IdField: &schemapb.IDs_IntId{
-						IntId: &schemapb.LongArray{
-							Data: []int64{0, 1},
-						},
-					},
-				},
-				FieldsData: fieldDataArray2,
-			}
-			result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{result1, result2}, &queryParams{limit: 2})
-			assert.NoError(t, err)
-			assert.Equal(t, 2, len(result.GetFieldsData()))
-			assert.Equal(t, Int64Array, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
-			assert.InDeltaSlice(t, FloatVector, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
-		})
-
 		t.Run("test nil results", func(t *testing.T) {
 			ret, err := reduceRetrieveResults(context.Background(), nil, &queryParams{})
 			assert.NoError(t, err)
@@ -0,0 +1,71 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package typeutil
+
+import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+
+type GenericSchemaSlice[T any] interface {
+	Append(T)
+	Get(int) T
+}
+
+type int64PkSlice struct {
+	data *schemapb.IDs
+}
+
+func (s *int64PkSlice) Append(v int64) {
+	s.data.GetIntId().Data = append(s.data.GetIntId().Data, v)
+}
+
+func (s *int64PkSlice) Get(i int) int64 {
+	return s.data.GetIntId().Data[i]
+}
+
+type stringPkSlice struct {
+	data *schemapb.IDs
+}
+
+func (s *stringPkSlice) Append(v string) {
+	s.data.GetStrId().Data = append(s.data.GetStrId().Data, v)
+}
+
+func (s *stringPkSlice) Get(i int) string {
+	return s.data.GetStrId().Data[i]
+}
+
+func NewInt64PkSchemaSlice(data *schemapb.IDs) GenericSchemaSlice[int64] {
+	return &int64PkSlice{
+		data: data,
+	}
+}
+
+func NewStringPkSchemaSlice(data *schemapb.IDs) GenericSchemaSlice[string] {
+	return &stringPkSlice{
+		data: data,
+	}
+}
+
+func CopyPk(dst *schemapb.IDs, src *schemapb.IDs, offset int) {
+	switch dst.GetIdField().(type) {
+	case *schemapb.IDs_IntId:
+		v := src.GetIntId().Data[offset]
+		dst.GetIntId().Data = append(dst.GetIntId().Data, v)
+	case *schemapb.IDs_StrId:
+		v := src.GetStrId().Data[offset]
+		dst.GetStrId().Data = append(dst.GetStrId().Data, v)
+	}
+}
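In the reduce loops above, typeutil.CopyPk replaces the GetPK/AppendPKs pair so a single primary key is appended without being boxed into interface{}; the benchmark in the next hunk compares exactly those two paths. A standalone usage sketch of the same idea follows (copyPk is a local copy of the added helper so the snippet builds outside the Milvus tree; the schemapb import is the proto package already used in the diff):

package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

// copyPk mirrors the CopyPk helper added above: it appends the offset-th
// primary key of src onto dst without boxing the key into interface{}.
func copyPk(dst, src *schemapb.IDs, offset int) {
	switch dst.GetIdField().(type) {
	case *schemapb.IDs_IntId:
		dst.GetIntId().Data = append(dst.GetIntId().Data, src.GetIntId().Data[offset])
	case *schemapb.IDs_StrId:
		dst.GetStrId().Data = append(dst.GetStrId().Data, src.GetStrId().Data[offset])
	}
}

func main() {
	src := &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{7, 8, 9}}}}
	dst := &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}

	// Copy the keys at offsets 2 and 0, the way the merge loop picks winners.
	copyPk(dst, src, 2)
	copyPk(dst, src, 0)
	fmt.Println(dst.GetIntId().Data) // [9 7]
}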
@@ -0,0 +1,176 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package typeutil
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+)
+
+func TestPKSlice(t *testing.T) {
+	data1 := &schemapb.IDs{
+		IdField: &schemapb.IDs_IntId{
+			IntId: &schemapb.LongArray{
+				Data: []int64{1, 2, 3},
+			},
+		},
+	}
+
+	ints := NewInt64PkSchemaSlice(data1)
+	assert.Equal(t, int64(1), ints.Get(0))
+	assert.Equal(t, int64(2), ints.Get(1))
+	assert.Equal(t, int64(3), ints.Get(2))
+	ints.Append(4)
+	assert.Equal(t, int64(4), ints.Get(3))
+
+	data2 := &schemapb.IDs{
+		IdField: &schemapb.IDs_StrId{
+			StrId: &schemapb.StringArray{
+				Data: []string{"1", "2", "3"},
+			},
+		},
+	}
+	strs := NewStringPkSchemaSlice(data2)
+	assert.Equal(t, "1", strs.Get(0))
+	assert.Equal(t, "2", strs.Get(1))
+	assert.Equal(t, "3", strs.Get(2))
+	strs.Append("4")
+	assert.Equal(t, "4", strs.Get(3))
+}
+
+func TestCopyPk(t *testing.T) {
+	type args struct {
+		dst      *schemapb.IDs
+		src      *schemapb.IDs
+		offset   int
+		dstAfter *schemapb.IDs
+	}
+	tests := []struct {
+		name string
+		args args
+	}{
+		{
+			name: "ints",
+			args: args{
+				dst: &schemapb.IDs{
+					IdField: &schemapb.IDs_IntId{
+						IntId: &schemapb.LongArray{
+							Data: []int64{1, 2, 3},
+						},
+					},
+				},
+				src: &schemapb.IDs{
+					IdField: &schemapb.IDs_IntId{
+						IntId: &schemapb.LongArray{
+							Data: []int64{1, 2, 3},
+						},
+					},
+				},
+				offset: 0,
+				dstAfter: &schemapb.IDs{
+					IdField: &schemapb.IDs_IntId{
+						IntId: &schemapb.LongArray{
+							Data: []int64{1, 2, 3, 1},
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "strs",
+			args: args{
+				dst: &schemapb.IDs{
+					IdField: &schemapb.IDs_StrId{
+						StrId: &schemapb.StringArray{
+							Data: []string{"1", "2", "3"},
+						},
+					},
+				},
+				src: &schemapb.IDs{
+					IdField: &schemapb.IDs_StrId{
+						StrId: &schemapb.StringArray{
+							Data: []string{"1", "2", "3"},
+						},
+					},
+				},
+				offset: 0,
+				dstAfter: &schemapb.IDs{
+					IdField: &schemapb.IDs_StrId{
+						StrId: &schemapb.StringArray{
+							Data: []string{"1", "2", "3", "1"},
+						},
+					},
+				},
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			CopyPk(tt.args.dst, tt.args.src, tt.args.offset)
+			assert.Equal(t, tt.args.dst, tt.args.dstAfter)
+		})
+	}
+}
+
+func BenchmarkCopyPK(b *testing.B) {
+	internal := make([]int64, 1000)
+	for i := 0; i < 1000; i++ {
+		internal[i] = int64(i)
+	}
+	src := &schemapb.IDs{
+		IdField: &schemapb.IDs_IntId{
+			IntId: &schemapb.LongArray{
+				Data: internal,
+			},
+		},
+	}
+
+	b.ResetTimer()
+
+	b.Run("Typed", func(b *testing.B) {
+		dst := &schemapb.IDs{
+			IdField: &schemapb.IDs_IntId{
+				IntId: &schemapb.LongArray{
+					Data: make([]int64, 0, 1000),
+				},
+			},
+		}
+		for i := 0; i < b.N; i++ {
+			for j := 0; j < GetSizeOfIDs(src); j++ {
+				CopyPk(dst, src, j)
+			}
+		}
+	})
+	b.Run("Any", func(b *testing.B) {
+		dst := &schemapb.IDs{
+			IdField: &schemapb.IDs_IntId{
+				IntId: &schemapb.LongArray{
+					Data: make([]int64, 0, 1000),
+				},
+			},
+		}
+		for i := 0; i < b.N; i++ {
+			for j := 0; j < GetSizeOfIDs(src); j++ {
+				pk := GetPK(src, int64(j))
+				AppendPKs(dst, pk)
+			}
+		}
+	})
+}