mirror of https://github.com/milvus-io/milvus.git
enhance: simplify reduction on single search result (#36334)
See: #36122
Signed-off-by: Ted Xu <ted.xu@zilliz.com>
parent 89397d1e66
commit 363004fd44
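At its core, the change adds a fast path to both reduce functions: when the proxy holds exactly one sub-search result (a single shard answered) and there is no offset to skip, the cursor-based k-way merge is bypassed and the single result is reused as-is. A minimal, self-contained sketch of that pattern (hypothetical names and plain float32 scores, not the proxy code itself):

// reduceSorted merges pre-sorted (descending) score lists into one list,
// skipping `offset` leading results. With a single input and no offset,
// the merge is unnecessary and the input can be returned directly.
func reduceSorted(inputs [][]float32, offset int) []float32 {
    if len(inputs) == 1 && offset == 0 {
        return inputs[0] // fast path: nothing to merge, nothing to skip
    }
    out := make([]float32, 0)
    cursors := make([]int, len(inputs))
    for skipped := 0; ; {
        // pick the input whose current head has the best score
        best := -1
        for k := range inputs {
            if cursors[k] < len(inputs[k]) &&
                (best == -1 || inputs[k][cursors[k]] > inputs[best][cursors[best]]) {
                best = k
            }
        }
        if best == -1 {
            break // every input exhausted
        }
        if skipped < offset {
            skipped++ // consume the offset before keeping results
        } else {
            out = append(out, inputs[best][cursors[best]])
        }
        cursors[best]++
    }
    return out
}

With one input the function returns in constant time instead of walking every element through the cursor loop, which is exactly the saving the diff below applies to the proxy's reduce paths.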
@@ -218,55 +218,49 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
         totalResCount += subSearchNqOffset[i][nq-1]
     }
 
-    var (
-        skipDupCnt int64
-        realTopK   int64 = -1
-    )
-
+    if subSearchNum == 1 && offset == 0 {
+        ret.Results = subSearchResultData[0]
+    } else {
+        var realTopK int64 = -1
         var retSize int64
+
         maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
-
         // reducing nq * topk results
         for i := int64(0); i < nq; i++ {
             var (
                 // cursor of current data of each subSearch for merging the j-th data of TopK.
                 // sum(cursors) == j
                 cursors = make([]int64, subSearchNum)
 
                 j              int64
-                pkSet          = make(map[interface{}]struct{})
                 groupByValMap  = make(map[interface{}][]*groupReduceInfo)
                 skipOffsetMap  = make(map[interface{}]bool)
                 groupByValList = make([]interface{}, limit)
                 groupByValIdx  = 0
             )
 
             for j = 0; j < groupBound; {
                 subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
                 if subSearchIdx == -1 {
                     break
                 }
                 subSearchRes := subSearchResultData[subSearchIdx]
 
                 id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
                 score := subSearchRes.GetScores()[resultDataIdx]
                 groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))
                 if groupByVal == nil {
                     return nil, errors.New("get nil groupByVal from subSearchRes, wrong states, as milvus doesn't support nil value," +
                         "there must be sth wrong on queryNode side")
                 }
 
-                // remove duplicates
-                if _, ok := pkSet[id]; !ok {
                 if int64(len(skipOffsetMap)) < offset || skipOffsetMap[groupByVal] {
                     skipOffsetMap[groupByVal] = true
                     // the first offset's group will be ignored
-                    skipDupCnt++
                 } else if len(groupByValMap[groupByVal]) == 0 && int64(len(groupByValMap)) >= limit {
                     // skip when groupbyMap has been full and found new groupByVal
-                    skipDupCnt++
                 } else if int64(len(groupByValMap[groupByVal])) >= groupSize {
                     // skip when target group has been full
-                    skipDupCnt++
                 } else {
                     if len(groupByValMap[groupByVal]) == 0 {
                         groupByValList[groupByValIdx] = groupByVal
@@ -276,55 +270,43 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
                         subSearchIdx: subSearchIdx,
                         resultIdx:    resultDataIdx, id: id, score: score,
                     })
-                    pkSet[id] = struct{}{}
                     j++
                 }
-            } else {
-                skipDupCnt++
-            }
-
                 cursors[subSearchIdx]++
             }
 
             // assemble all eligible values in group
             // values in groupByValList is sorted by the highest score in each group
             for _, groupVal := range groupByValList {
                 if groupVal != nil {
                     groupEntities := groupByValMap[groupVal]
                     for _, groupEntity := range groupEntities {
                         subResData := subSearchResultData[groupEntity.subSearchIdx]
                         retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx)
                         typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
                         ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
                         if err := typeutil.AppendGroupByValue(ret.Results, groupVal, subResData.GetGroupByFieldValue().GetType()); err != nil {
                             log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
                             return ret, err
                         }
                     }
                 }
             }
 
             if realTopK != -1 && realTopK != j {
                 log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
             }
             realTopK = j
             ret.Results.Topks = append(ret.Results.Topks, realTopK)
 
             // limit search result to avoid oom
             if retSize > maxOutputSize {
                 return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
             }
         }
+        ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
     }
-    log.Ctx(ctx).Debug("skip duplicated search result when doing group by", zap.Int64("count", skipDupCnt))
-
-    if float64(skipDupCnt) >= float64(totalResCount)*0.3 {
-        log.Warn("GroupBy reduce skipped too many results, "+
-            "this may influence the final result seriously",
-            zap.Int64("skipDupCnt", skipDupCnt),
-            zap.Int64("groupBound", groupBound))
-    }
-
-    ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
     if !metric.PositivelyRelated(metricType) {
         for k := range ret.Results.Scores {
             ret.Results.Scores[k] *= -1
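selectHighestScoreIndex drives the merge loops above and below but is not itself part of this diff. As a mental model only, a simplified sketch with a made-up signature, not the actual proxy implementation: the selector looks at the current cursor position of every sub-result for the query being reduced and returns the index of the sub-result whose candidate has the best score, or -1 once all cursors are exhausted.

// A simplified stand-in for the selector (hypothetical): lists holds each
// sub-result's scores for the current query, already sorted descending;
// cursors tracks how far each list has been consumed so far.
func selectHighestScore(lists [][]float32, cursors []int64) int {
    best := -1
    for k := range lists {
        if cursors[k] >= int64(len(lists[k])) {
            continue // this sub-result is exhausted
        }
        if best == -1 || lists[k][cursors[k]] > lists[best][cursors[best]] {
            best = k
        }
    }
    return best // -1 once every cursor has run off the end
}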
@@ -370,91 +352,79 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
         ret.GetResults().AllSearchCount = allSearchCount
     }
 
-    var (
-        subSearchNum = len(subSearchResultData)
+    subSearchNum := len(subSearchResultData)
+    if subSearchNum == 1 && offset == 0 {
+        // sorting is not needed if there is only one shard and no offset, assigning the result directly.
+        // we still need to adjust the scores later.
+        ret.Results = subSearchResultData[0]
+        // realTopK is the topK of the nq-th query, it is used in proxy but not handled by delegator.
+        topks := subSearchResultData[0].Topks
+        if len(topks) > 0 {
+            ret.Results.TopK = topks[len(topks)-1]
+        }
+    } else {
+        var realTopK int64 = -1
+        var retSize int64
+
         // for results of each subSearchResultData, storing the start offset of each query of nq queries
-        subSearchNqOffset = make([][]int64, subSearchNum)
-    )
+        subSearchNqOffset := make([][]int64, subSearchNum)
         for i := 0; i < subSearchNum; i++ {
             subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
             for j := int64(1); j < nq; j++ {
                 subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
             }
         }
-
-    var (
-        skipDupCnt int64
-        realTopK   int64 = -1
-    )
-
-    var retSize int64
         maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
-
         // reducing nq * topk results
         for i := int64(0); i < nq; i++ {
             var (
                 // cursor of current data of each subSearch for merging the j-th data of TopK.
                 // sum(cursors) == j
                 cursors = make([]int64, subSearchNum)
-
-                j     int64
-                idSet = make(map[interface{}]struct{}, limit)
+                j int64
             )
 
             // skip offset results
             for k := int64(0); k < offset; k++ {
                 subSearchIdx, _ := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
                 if subSearchIdx == -1 {
                     break
                 }
                 cursors[subSearchIdx]++
             }
 
             // keep limit results
-            for j = 0; j < limit; {
+            for j = 0; j < limit; j++ {
                 // From all the sub-query result sets of the i-th query vector,
                 // find the sub-query result set index of the score j-th data,
                 // and the index of the data in schemapb.SearchResultData
                 subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
                 if subSearchIdx == -1 {
                     break
                 }
-                id := typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)
                 score := subSearchResultData[subSearchIdx].Scores[resultDataIdx]
 
-                // remove duplicatessds
-                if _, ok := idSet[id]; !ok {
                 retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
-                typeutil.AppendPKs(ret.Results.Ids, id)
+                typeutil.CopyPk(ret.Results.Ids, subSearchResultData[subSearchIdx].GetIds(), int(resultDataIdx))
                 ret.Results.Scores = append(ret.Results.Scores, score)
-                idSet[id] = struct{}{}
-                j++
-                } else {
-                    // skip entity with same id
-                    skipDupCnt++
-                }
                 cursors[subSearchIdx]++
             }
             if realTopK != -1 && realTopK != j {
                 log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
                 // return nil, errors.New("the length (topk) between all result of query is different")
             }
             realTopK = j
             ret.Results.Topks = append(ret.Results.Topks, realTopK)
 
             // limit search result to avoid oom
             if retSize > maxOutputSize {
                 return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
             }
         }
+        ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
     }
-    log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
-
-    if skipDupCnt > 0 {
-        log.Info("skip duplicated search result", zap.Int64("count", skipDupCnt))
-    }
-
-    ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
     if !metric.PositivelyRelated(metricType) {
         for k := range ret.Results.Scores {
             ret.Results.Scores[k] *= -1
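The subSearchNqOffset table built at the top of the else branch is a per-sub-result prefix sum over its Topks array: query q's rows start at offsets[q] inside that sub-result's flattened Ids/Scores/FieldsData. The same computation in isolation (nqOffsets is a name chosen here for illustration):

// Given one sub-result's per-query topks, offsets[q] is the index where
// query q's rows begin in the flattened arrays; query 0 starts at 0.
func nqOffsets(topks []int64) []int64 {
    offsets := make([]int64, len(topks))
    for q := 1; q < len(topks); q++ {
        offsets[q] = offsets[q-1] + topks[q-1]
    }
    return offsets
}

// nqOffsets([]int64{3, 2, 4}) == []int64{0, 3, 5}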
@@ -14,7 +14,7 @@ type SearchReduceUtilTestSuite struct {
     suite.Suite
 }
 
-func (struts *SearchReduceUtilTestSuite) TestRankByGroup() {
+func genTestDataSearchResultsData() []*schemapb.SearchResultData {
     var searchResultData1 *schemapb.SearchResultData
     var searchResultData2 *schemapb.SearchResultData
 
@@ -49,10 +49,14 @@ func (struts *SearchReduceUtilTestSuite) TestRankByGroup() {
             GroupByFieldValue: getFieldData("string", int64(101), schemapb.DataType_VarChar, groupFieldValue, 1),
         }
     }
+    return []*schemapb.SearchResultData{searchResultData1, searchResultData2}
+}
+
+func (struts *SearchReduceUtilTestSuite) TestRankByGroup() {
+    data := genTestDataSearchResultsData()
     searchResults := []*milvuspb.SearchResults{
-        {Results: searchResultData1},
-        {Results: searchResultData2},
+        {Results: data[0]},
+        {Results: data[1]},
     }
 
     nq := int64(1)
@@ -128,6 +132,16 @@ func (struts *SearchReduceUtilTestSuite) TestRankByGroup() {
     }
 }
 
+func (struts *SearchReduceUtilTestSuite) TestReduceSearchResult() {
+    data := genTestDataSearchResultsData()
+
+    {
+        results, err := reduceSearchResultDataNoGroupBy(context.Background(), []*schemapb.SearchResultData{data[0]}, 0, 0, "L2", schemapb.DataType_Int64, 0)
+        struts.NoError(err)
+        struts.Equal([]string{"7", "5", "4", "2", "3", "6", "1", "9", "8"}, results.Results.GetIds().GetStrId().Data)
+    }
+}
+
 func TestSearchReduceUtilTestSuite(t *testing.T) {
     suite.Run(t, new(SearchReduceUtilTestSuite))
 }
@@ -591,10 +591,8 @@ func IDs2Expr(fieldName string, ids *schemapb.IDs) string {
 func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.RetrieveResults, queryParams *queryParams) (*milvuspb.QueryResults, error) {
     log.Ctx(ctx).Debug("reduceInternalRetrieveResults", zap.Int("len(retrieveResults)", len(retrieveResults)))
     var (
         ret     = &milvuspb.QueryResults{}
-        skipDupCnt int64
-        loopEnd    int
+        loopEnd int
     )
 
     validRetrieveResults := []*internalpb.RetrieveResults{}
@@ -611,7 +609,6 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
         return ret, nil
     }
 
-    idSet := make(map[interface{}]struct{})
     cursors := make([]int64, len(validRetrieveResults))
 
     if queryParams != nil && queryParams.limit != typeutil.Unlimited {
@@ -636,21 +633,12 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
     ret.FieldsData = typeutil.PrepareResultFieldData(validRetrieveResults[0].GetFieldsData(), int64(loopEnd))
     var retSize int64
     maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
-    for j := 0; j < loopEnd; {
+    for j := 0; j < loopEnd; j++ {
         sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors)
         if sel == -1 || (queryParams.reduceStopForBest && drainOneResult) {
             break
         }
-
-        pk := typeutil.GetPK(validRetrieveResults[sel].GetIds(), cursors[sel])
-        if _, ok := idSet[pk]; !ok {
-            retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel])
-            idSet[pk] = struct{}{}
-            j++
-        } else {
-            // primary keys duplicate
-            skipDupCnt++
-        }
+        retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel])
 
         // limit retrieve result to avoid oom
         if retSize > maxOutputSize {
@@ -660,10 +648,6 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
         cursors[sel]++
     }
 
-    if skipDupCnt > 0 {
-        log.Ctx(ctx).Debug("skip duplicated query result while reducing QueryResults", zap.Int64("count", skipDupCnt))
-    }
-
     return ret, nil
 }
@@ -461,34 +461,6 @@ func TestTaskQuery_functions(t *testing.T) {
         fieldDataArray2 = append(fieldDataArray2, getFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1))
         fieldDataArray2 = append(fieldDataArray2, getFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))
 
-        t.Run("test skip dupPK 2", func(t *testing.T) {
-            result1 := &internalpb.RetrieveResults{
-                Ids: &schemapb.IDs{
-                    IdField: &schemapb.IDs_IntId{
-                        IntId: &schemapb.LongArray{
-                            Data: []int64{0, 1},
-                        },
-                    },
-                },
-                FieldsData: fieldDataArray1,
-            }
-            result2 := &internalpb.RetrieveResults{
-                Ids: &schemapb.IDs{
-                    IdField: &schemapb.IDs_IntId{
-                        IntId: &schemapb.LongArray{
-                            Data: []int64{0, 1},
-                        },
-                    },
-                },
-                FieldsData: fieldDataArray2,
-            }
-
-            result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{result1, result2}, &queryParams{limit: 2})
-            assert.NoError(t, err)
-            assert.Equal(t, 2, len(result.GetFieldsData()))
-            assert.Equal(t, Int64Array, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
-            assert.InDeltaSlice(t, FloatVector, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
-        })
-
         t.Run("test nil results", func(t *testing.T) {
             ret, err := reduceRetrieveResults(context.Background(), nil, &queryParams{})
             assert.NoError(t, err)
@@ -0,0 +1,71 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package typeutil
+
+import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+
+type GenericSchemaSlice[T any] interface {
+    Append(T)
+    Get(int) T
+}
+
+type int64PkSlice struct {
+    data *schemapb.IDs
+}
+
+func (s *int64PkSlice) Append(v int64) {
+    s.data.GetIntId().Data = append(s.data.GetIntId().Data, v)
+}
+
+func (s *int64PkSlice) Get(i int) int64 {
+    return s.data.GetIntId().Data[i]
+}
+
+type stringPkSlice struct {
+    data *schemapb.IDs
+}
+
+func (s *stringPkSlice) Append(v string) {
+    s.data.GetStrId().Data = append(s.data.GetStrId().Data, v)
+}
+
+func (s *stringPkSlice) Get(i int) string {
+    return s.data.GetStrId().Data[i]
+}
+
+func NewInt64PkSchemaSlice(data *schemapb.IDs) GenericSchemaSlice[int64] {
+    return &int64PkSlice{
+        data: data,
+    }
+}
+
+func NewStringPkSchemaSlice(data *schemapb.IDs) GenericSchemaSlice[string] {
+    return &stringPkSlice{
+        data: data,
+    }
+}
+
+func CopyPk(dst *schemapb.IDs, src *schemapb.IDs, offset int) {
+    switch dst.GetIdField().(type) {
+    case *schemapb.IDs_IntId:
+        v := src.GetIntId().Data[offset]
+        dst.GetIntId().Data = append(dst.GetIntId().Data, v)
+    case *schemapb.IDs_StrId:
+        v := src.GetStrId().Data[offset]
+        dst.GetStrId().Data = append(dst.GetStrId().Data, v)
+    }
+}
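For orientation, a usage sketch of the new helpers. The typeutil import path below is an assumption about where this file lands in the repository; the behavior follows directly from the code above:

package main

import (
    "fmt"

    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    "github.com/milvus-io/milvus/pkg/util/typeutil" // assumed import path
)

func main() {
    src := &schemapb.IDs{
        IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{10, 20, 30}}},
    }
    dst := &schemapb.IDs{
        IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}},
    }

    // Copy the PK at offset 1 from src into dst without boxing it into an
    // interface{} first; the concrete ID type is taken from dst's oneof.
    typeutil.CopyPk(dst, src, 1)
    fmt.Println(dst.GetIntId().GetData()) // [20]

    // The generic slice view gives typed Get/Append over the same wrapper.
    ints := typeutil.NewInt64PkSchemaSlice(dst)
    ints.Append(30)
    fmt.Println(ints.Get(0), ints.Get(1)) // 20 30
}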
@@ -0,0 +1,176 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package typeutil
+
+import (
+    "testing"
+
+    "github.com/stretchr/testify/assert"
+
+    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+)
+
+func TestPKSlice(t *testing.T) {
+    data1 := &schemapb.IDs{
+        IdField: &schemapb.IDs_IntId{
+            IntId: &schemapb.LongArray{
+                Data: []int64{1, 2, 3},
+            },
+        },
+    }
+
+    ints := NewInt64PkSchemaSlice(data1)
+    assert.Equal(t, int64(1), ints.Get(0))
+    assert.Equal(t, int64(2), ints.Get(1))
+    assert.Equal(t, int64(3), ints.Get(2))
+    ints.Append(4)
+    assert.Equal(t, int64(4), ints.Get(3))
+
+    data2 := &schemapb.IDs{
+        IdField: &schemapb.IDs_StrId{
+            StrId: &schemapb.StringArray{
+                Data: []string{"1", "2", "3"},
+            },
+        },
+    }
+    strs := NewStringPkSchemaSlice(data2)
+    assert.Equal(t, "1", strs.Get(0))
+    assert.Equal(t, "2", strs.Get(1))
+    assert.Equal(t, "3", strs.Get(2))
+    strs.Append("4")
+    assert.Equal(t, "4", strs.Get(3))
+}
+
+func TestCopyPk(t *testing.T) {
+    type args struct {
+        dst      *schemapb.IDs
+        src      *schemapb.IDs
+        offset   int
+        dstAfter *schemapb.IDs
+    }
+    tests := []struct {
+        name string
+        args args
+    }{
+        {
+            name: "ints",
+            args: args{
+                dst: &schemapb.IDs{
+                    IdField: &schemapb.IDs_IntId{
+                        IntId: &schemapb.LongArray{
+                            Data: []int64{1, 2, 3},
+                        },
+                    },
+                },
+                src: &schemapb.IDs{
+                    IdField: &schemapb.IDs_IntId{
+                        IntId: &schemapb.LongArray{
+                            Data: []int64{1, 2, 3},
+                        },
+                    },
+                },
+                offset: 0,
+                dstAfter: &schemapb.IDs{
+                    IdField: &schemapb.IDs_IntId{
+                        IntId: &schemapb.LongArray{
+                            Data: []int64{1, 2, 3, 1},
+                        },
+                    },
+                },
+            },
+        },
+        {
+            name: "strs",
+            args: args{
+                dst: &schemapb.IDs{
+                    IdField: &schemapb.IDs_StrId{
+                        StrId: &schemapb.StringArray{
+                            Data: []string{"1", "2", "3"},
+                        },
+                    },
+                },
+                src: &schemapb.IDs{
+                    IdField: &schemapb.IDs_StrId{
+                        StrId: &schemapb.StringArray{
+                            Data: []string{"1", "2", "3"},
+                        },
+                    },
+                },
+                offset: 0,
+                dstAfter: &schemapb.IDs{
+                    IdField: &schemapb.IDs_StrId{
+                        StrId: &schemapb.StringArray{
+                            Data: []string{"1", "2", "3", "1"},
+                        },
+                    },
+                },
+            },
+        },
+    }
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            CopyPk(tt.args.dst, tt.args.src, tt.args.offset)
+            assert.Equal(t, tt.args.dst, tt.args.dstAfter)
+        })
+    }
+}
+
+func BenchmarkCopyPK(b *testing.B) {
+    internal := make([]int64, 1000)
+    for i := 0; i < 1000; i++ {
+        internal[i] = int64(i)
+    }
+    src := &schemapb.IDs{
+        IdField: &schemapb.IDs_IntId{
+            IntId: &schemapb.LongArray{
+                Data: internal,
+            },
+        },
+    }
+
+    b.ResetTimer()
+
+    b.Run("Typed", func(b *testing.B) {
+        dst := &schemapb.IDs{
+            IdField: &schemapb.IDs_IntId{
+                IntId: &schemapb.LongArray{
+                    Data: make([]int64, 0, 1000),
+                },
+            },
+        }
+        for i := 0; i < b.N; i++ {
+            for j := 0; j < GetSizeOfIDs(src); j++ {
+                CopyPk(dst, src, j)
+            }
+        }
+    })
+    b.Run("Any", func(b *testing.B) {
+        dst := &schemapb.IDs{
+            IdField: &schemapb.IDs_IntId{
+                IntId: &schemapb.LongArray{
+                    Data: make([]int64, 0, 1000),
+                },
+            },
+        }
+        for i := 0; i < b.N; i++ {
+            for j := 0; j < GetSizeOfIDs(src); j++ {
+                pk := GetPK(src, int64(j))
+                AppendPKs(dst, pk)
+            }
+        }
+    })
+}
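The Typed/Any split in BenchmarkCopyPK contrasts the two append paths used in the reduce code: the old GetPK/AppendPKs pair moves every key through an interface{} and must inspect it per element, while CopyPk resolves the concrete ID type once from the oneof wrapper and appends the typed value directly. A stripped-down, runnable illustration of the difference (hypothetical helpers, not the milvus API):

package main

import "fmt"

// appendAny mimics the GetPK/AppendPKs path: the key travels through an
// interface{}, which costs boxing plus a type assertion on every element.
func appendAny(dst []int64, v interface{}) []int64 {
    if iv, ok := v.(int64); ok {
        dst = append(dst, iv)
    }
    return dst
}

// appendTyped mimics the CopyPk path: once the concrete type is known,
// each copy is a plain typed append with no per-element boxing.
func appendTyped(dst []int64, v int64) []int64 {
    return append(dst, v)
}

func main() {
    dst := make([]int64, 0, 2)
    dst = appendAny(dst, int64(1)) // boxes 1 into an interface{}
    dst = appendTyped(dst, 2)      // stays on the typed fast path
    fmt.Println(dst)               // [1 2]
}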