mirror of https://github.com/milvus-io/milvus.git
related: #30828 Signed-off-by: MrPresent-Han <chun.han@zilliz.com>pull/31079/head
parent
fd17a5f050
commit
d0eeea4b44
|
@ -0,0 +1,381 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
// reduceSearchResultInfo bundles every argument needed to merge the per-shard
// (sub) search results of one request into a single global result set.
type reduceSearchResultInfo struct {
	subSearchResultData []*schemapb.SearchResultData // one partial result per sub-search
	nq                  int64                        // number of queries in the request
	topK                int64                        // requested topk (offset included)
	metricType          string                       // similarity metric; decides whether scores are sign-flipped at the end
	pkType              schemapb.DataType            // primary-key type: Int64 or VarChar
	offset              int64                        // number of leading results (or groups) to skip for pagination
	queryInfo           *planpb.QueryInfo            // plan info; GroupByFieldId > 0 selects the group-by reduce path
}
|
||||
|
||||
func NewReduceSearchResultInfo(
|
||||
subSearchResultData []*schemapb.SearchResultData,
|
||||
nq int64,
|
||||
topK int64,
|
||||
metricType string,
|
||||
pkType schemapb.DataType,
|
||||
offset int64,
|
||||
queryInfo *planpb.QueryInfo,
|
||||
) *reduceSearchResultInfo {
|
||||
return &reduceSearchResultInfo{
|
||||
subSearchResultData: subSearchResultData,
|
||||
nq: nq,
|
||||
topK: topK,
|
||||
metricType: metricType,
|
||||
pkType: pkType,
|
||||
offset: offset,
|
||||
queryInfo: queryInfo,
|
||||
}
|
||||
}
|
||||
|
||||
func reduceSearchResult(ctx context.Context, reduceInfo *reduceSearchResultInfo) (*milvuspb.SearchResults, error) {
|
||||
if reduceInfo.queryInfo.GroupByFieldId > 0 {
|
||||
return reduceSearchResultDataWithGroupBy(ctx,
|
||||
reduceInfo.subSearchResultData,
|
||||
reduceInfo.nq,
|
||||
reduceInfo.topK,
|
||||
reduceInfo.metricType,
|
||||
reduceInfo.pkType,
|
||||
reduceInfo.offset)
|
||||
}
|
||||
return reduceSearchResultDataNoGroupBy(ctx,
|
||||
reduceInfo.subSearchResultData,
|
||||
reduceInfo.nq,
|
||||
reduceInfo.topK,
|
||||
reduceInfo.metricType,
|
||||
reduceInfo.pkType,
|
||||
reduceInfo.offset)
|
||||
}
|
||||
|
||||
func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64) (*milvuspb.SearchResults, error) {
|
||||
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
|
||||
defer func() {
|
||||
tr.CtxElapse(ctx, "done")
|
||||
}()
|
||||
|
||||
limit := topk - offset
|
||||
log.Ctx(ctx).Debug("reduceSearchResultData",
|
||||
zap.Int("len(subSearchResultData)", len(subSearchResultData)),
|
||||
zap.Int64("nq", nq),
|
||||
zap.Int64("offset", offset),
|
||||
zap.Int64("limit", limit),
|
||||
zap.String("metricType", metricType))
|
||||
|
||||
ret := &milvuspb.SearchResults{
|
||||
Status: merr.Success(),
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: nq,
|
||||
TopK: topk,
|
||||
FieldsData: typeutil.PrepareResultFieldData(subSearchResultData[0].GetFieldsData(), limit),
|
||||
Scores: []float32{},
|
||||
Ids: &schemapb.IDs{},
|
||||
Topks: []int64{},
|
||||
},
|
||||
}
|
||||
|
||||
switch pkType {
|
||||
case schemapb.DataType_Int64:
|
||||
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: make([]int64, 0, limit),
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_VarChar:
|
||||
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: make([]string, 0, limit),
|
||||
},
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("unsupported pk type")
|
||||
}
|
||||
for i, sData := range subSearchResultData {
|
||||
pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
|
||||
log.Ctx(ctx).Debug("subSearchResultData",
|
||||
zap.Int("result No.", i),
|
||||
zap.Int64("nq", sData.NumQueries),
|
||||
zap.Int64("topk", sData.TopK),
|
||||
zap.Int("length of pks", pkLength),
|
||||
zap.Int("length of FieldsData", len(sData.FieldsData)))
|
||||
if err := checkSearchResultData(sData, nq, topk); err != nil {
|
||||
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
||||
return ret, err
|
||||
}
|
||||
// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
|
||||
}
|
||||
|
||||
var (
|
||||
subSearchNum = len(subSearchResultData)
|
||||
// for results of each subSearchResultData, storing the start offset of each query of nq queries
|
||||
subSearchNqOffset = make([][]int64, subSearchNum)
|
||||
)
|
||||
for i := 0; i < subSearchNum; i++ {
|
||||
subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
|
||||
for j := int64(1); j < nq; j++ {
|
||||
subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
skipDupCnt int64
|
||||
realTopK int64 = -1
|
||||
)
|
||||
|
||||
var retSize int64
|
||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||
|
||||
// reducing nq * topk results
|
||||
for i := int64(0); i < nq; i++ {
|
||||
var (
|
||||
// cursor of current data of each subSearch for merging the j-th data of TopK.
|
||||
// sum(cursors) == j
|
||||
cursors = make([]int64, subSearchNum)
|
||||
|
||||
j int64
|
||||
idSet = make(map[interface{}]struct{})
|
||||
groupByValSet = make(map[interface{}]struct{})
|
||||
)
|
||||
|
||||
// keep limit results
|
||||
for j = 0; j < limit; {
|
||||
// From all the sub-query result sets of the i-th query vector,
|
||||
// find the sub-query result set index of the score j-th data,
|
||||
// and the index of the data in schemapb.SearchResultData
|
||||
subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
|
||||
if subSearchIdx == -1 {
|
||||
break
|
||||
}
|
||||
subSearchRes := subSearchResultData[subSearchIdx]
|
||||
|
||||
id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
|
||||
score := subSearchRes.Scores[resultDataIdx]
|
||||
groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))
|
||||
if groupByVal == nil {
|
||||
return nil, errors.New("get nil groupByVal from subSearchRes, wrong states, as milvus doesn't support nil value," +
|
||||
"there must be sth wrong on queryNode side")
|
||||
}
|
||||
|
||||
// remove duplicates
|
||||
if _, ok := idSet[id]; !ok {
|
||||
_, groupByValExist := groupByValSet[groupByVal]
|
||||
if !groupByValExist {
|
||||
groupByValSet[groupByVal] = struct{}{}
|
||||
if int64(len(groupByValSet)) <= offset {
|
||||
continue
|
||||
// skip offset groups
|
||||
}
|
||||
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
|
||||
typeutil.AppendPKs(ret.Results.Ids, id)
|
||||
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||
idSet[id] = struct{}{}
|
||||
if err := typeutil.AppendGroupByValue(ret.Results, groupByVal, subSearchRes.GetGroupByFieldValue().GetType()); err != nil {
|
||||
log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
|
||||
return ret, err
|
||||
}
|
||||
j++
|
||||
} else {
|
||||
// skip entity with same groupby
|
||||
skipDupCnt++
|
||||
}
|
||||
} else {
|
||||
// skip entity with same id
|
||||
skipDupCnt++
|
||||
}
|
||||
cursors[subSearchIdx]++
|
||||
}
|
||||
if realTopK != -1 && realTopK != j {
|
||||
log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
|
||||
// return nil, errors.New("the length (topk) between all result of query is different")
|
||||
}
|
||||
realTopK = j
|
||||
ret.Results.Topks = append(ret.Results.Topks, realTopK)
|
||||
|
||||
// limit search result to avoid oom
|
||||
if retSize > maxOutputSize {
|
||||
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
|
||||
}
|
||||
}
|
||||
log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
|
||||
|
||||
if skipDupCnt > 0 {
|
||||
log.Ctx(ctx).Info("skip duplicated search result", zap.Int64("count", skipDupCnt))
|
||||
}
|
||||
|
||||
ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
|
||||
if !metric.PositivelyRelated(metricType) {
|
||||
for k := range ret.Results.Scores {
|
||||
ret.Results.Scores[k] *= -1
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64) (*milvuspb.SearchResults, error) {
|
||||
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
|
||||
defer func() {
|
||||
tr.CtxElapse(ctx, "done")
|
||||
}()
|
||||
|
||||
limit := topk - offset
|
||||
log.Ctx(ctx).Debug("reduceSearchResultData",
|
||||
zap.Int("len(subSearchResultData)", len(subSearchResultData)),
|
||||
zap.Int64("nq", nq),
|
||||
zap.Int64("offset", offset),
|
||||
zap.Int64("limit", limit),
|
||||
zap.String("metricType", metricType))
|
||||
|
||||
ret := &milvuspb.SearchResults{
|
||||
Status: merr.Success(),
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: nq,
|
||||
TopK: topk,
|
||||
FieldsData: typeutil.PrepareResultFieldData(subSearchResultData[0].GetFieldsData(), limit),
|
||||
Scores: []float32{},
|
||||
Ids: &schemapb.IDs{},
|
||||
Topks: []int64{},
|
||||
},
|
||||
}
|
||||
|
||||
switch pkType {
|
||||
case schemapb.DataType_Int64:
|
||||
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: make([]int64, 0, limit),
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_VarChar:
|
||||
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: make([]string, 0, limit),
|
||||
},
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("unsupported pk type")
|
||||
}
|
||||
for i, sData := range subSearchResultData {
|
||||
pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
|
||||
log.Ctx(ctx).Debug("subSearchResultData",
|
||||
zap.Int("result No.", i),
|
||||
zap.Int64("nq", sData.NumQueries),
|
||||
zap.Int64("topk", sData.TopK),
|
||||
zap.Int("length of pks", pkLength),
|
||||
zap.Int("length of FieldsData", len(sData.FieldsData)))
|
||||
if err := checkSearchResultData(sData, nq, topk); err != nil {
|
||||
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
||||
return ret, err
|
||||
}
|
||||
// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
|
||||
}
|
||||
|
||||
var (
|
||||
subSearchNum = len(subSearchResultData)
|
||||
// for results of each subSearchResultData, storing the start offset of each query of nq queries
|
||||
subSearchNqOffset = make([][]int64, subSearchNum)
|
||||
)
|
||||
for i := 0; i < subSearchNum; i++ {
|
||||
subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
|
||||
for j := int64(1); j < nq; j++ {
|
||||
subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
skipDupCnt int64
|
||||
realTopK int64 = -1
|
||||
)
|
||||
|
||||
var retSize int64
|
||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||
|
||||
// reducing nq * topk results
|
||||
for i := int64(0); i < nq; i++ {
|
||||
var (
|
||||
// cursor of current data of each subSearch for merging the j-th data of TopK.
|
||||
// sum(cursors) == j
|
||||
cursors = make([]int64, subSearchNum)
|
||||
|
||||
j int64
|
||||
idSet = make(map[interface{}]struct{})
|
||||
)
|
||||
|
||||
// skip offset results
|
||||
for k := int64(0); k < offset; k++ {
|
||||
subSearchIdx, _ := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
|
||||
if subSearchIdx == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
cursors[subSearchIdx]++
|
||||
}
|
||||
|
||||
// keep limit results
|
||||
for j = 0; j < limit; {
|
||||
// From all the sub-query result sets of the i-th query vector,
|
||||
// find the sub-query result set index of the score j-th data,
|
||||
// and the index of the data in schemapb.SearchResultData
|
||||
subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
|
||||
if subSearchIdx == -1 {
|
||||
break
|
||||
}
|
||||
id := typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)
|
||||
score := subSearchResultData[subSearchIdx].Scores[resultDataIdx]
|
||||
|
||||
// remove duplicatessds
|
||||
if _, ok := idSet[id]; !ok {
|
||||
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
|
||||
typeutil.AppendPKs(ret.Results.Ids, id)
|
||||
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||
idSet[id] = struct{}{}
|
||||
j++
|
||||
} else {
|
||||
// skip entity with same id
|
||||
skipDupCnt++
|
||||
}
|
||||
cursors[subSearchIdx]++
|
||||
}
|
||||
if realTopK != -1 && realTopK != j {
|
||||
log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
|
||||
// return nil, errors.New("the length (topk) between all result of query is different")
|
||||
}
|
||||
realTopK = j
|
||||
ret.Results.Topks = append(ret.Results.Topks, realTopK)
|
||||
|
||||
// limit search result to avoid oom
|
||||
if retSize > maxOutputSize {
|
||||
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
|
||||
}
|
||||
}
|
||||
log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
|
||||
|
||||
if skipDupCnt > 0 {
|
||||
log.Info("skip duplicated search result", zap.Int64("count", skipDupCnt))
|
||||
}
|
||||
|
||||
ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
|
||||
if !metric.PositivelyRelated(metricType) {
|
||||
for k := range ret.Results.Scores {
|
||||
ret.Results.Scores[k] *= -1
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
}
|
|
@ -118,6 +118,7 @@ func initSearchRequest(ctx context.Context, t *searchTask) error {
|
|||
|
||||
t.SearchRequest.Topk = queryInfo.GetTopk()
|
||||
t.SearchRequest.MetricType = queryInfo.GetMetricType()
|
||||
t.queryInfo = queryInfo
|
||||
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
|
||||
|
||||
estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk)
|
||||
|
|
|
@ -70,6 +70,7 @@ type searchTask struct {
|
|||
node types.ProxyComponent
|
||||
lb LBPolicy
|
||||
queryChannelsTs map[string]Timestamp
|
||||
queryInfo *planpb.QueryInfo
|
||||
}
|
||||
|
||||
func getPartitionIDs(ctx context.Context, dbName string, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
|
||||
|
@ -443,7 +444,8 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
t.result, err = reduceSearchResultData(ctx, validSearchResults, Nq, Topk, MetricType, primaryFieldSchema.DataType, t.offset)
|
||||
t.result, err = reduceSearchResult(ctx, NewReduceSearchResultInfo(validSearchResults, Nq, Topk,
|
||||
MetricType, primaryFieldSchema.DataType, t.offset, t.queryInfo))
|
||||
if err != nil {
|
||||
log.Warn("failed to reduce search results", zap.Error(err))
|
||||
return err
|
||||
|
@ -751,173 +753,6 @@ func selectHighestScoreIndex(subSearchResultData []*schemapb.SearchResultData, s
|
|||
return subSearchIdx, resultDataIdx
|
||||
}
|
||||
|
||||
// reduceSearchResultData merges per-shard search results into one global
// top-k per query, handling both the plain and the group-by case in a single
// body: when an entity carries a non-nil group-by value, at most one entity
// per distinct group-by value is kept. For each of the nq queries the
// `offset` highest-scoring entities are discarded first, then up to `limit`
// entities are kept in score order, dropping entities whose primary key was
// already emitted.
func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64) (*milvuspb.SearchResults, error) {
	tr := timerecord.NewTimeRecorder("reduceSearchResultData")
	defer func() {
		tr.CtxElapse(ctx, "done")
	}()

	limit := topk - offset
	log.Ctx(ctx).Debug("reduceSearchResultData",
		zap.Int("len(subSearchResultData)", len(subSearchResultData)),
		zap.Int64("nq", nq),
		zap.Int64("offset", offset),
		zap.Int64("limit", limit),
		zap.String("metricType", metricType))

	ret := &milvuspb.SearchResults{
		Status: merr.Success(),
		Results: &schemapb.SearchResultData{
			NumQueries: nq,
			TopK:       topk,
			FieldsData: typeutil.PrepareResultFieldData(subSearchResultData[0].GetFieldsData(), limit),
			Scores:     []float32{},
			Ids:        &schemapb.IDs{},
			Topks:      []int64{},
		},
	}

	// Pre-size the pk container according to the pk type; anything else is unsupported.
	switch pkType {
	case schemapb.DataType_Int64:
		ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
			IntId: &schemapb.LongArray{
				Data: make([]int64, 0, limit),
			},
		}
	case schemapb.DataType_VarChar:
		ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
			StrId: &schemapb.StringArray{
				Data: make([]string, 0, limit),
			},
		}
	default:
		return nil, errors.New("unsupported pk type")
	}
	// Validate every sub-result before merging.
	for i, sData := range subSearchResultData {
		pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
		log.Ctx(ctx).Debug("subSearchResultData",
			zap.Int("result No.", i),
			zap.Int64("nq", sData.NumQueries),
			zap.Int64("topk", sData.TopK),
			zap.Int("length of pks", pkLength),
			zap.Int("length of FieldsData", len(sData.FieldsData)))
		if err := checkSearchResultData(sData, nq, topk); err != nil {
			log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
			return ret, err
		}
		// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
	}

	var (
		subSearchNum = len(subSearchResultData)
		// for results of each subSearchResultData, storing the start offset of each query of nq queries
		subSearchNqOffset = make([][]int64, subSearchNum)
	)
	for i := 0; i < subSearchNum; i++ {
		subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
		for j := int64(1); j < nq; j++ {
			subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
		}
	}

	var (
		skipDupCnt int64
		realTopK   int64 = -1
	)

	var retSize int64
	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()

	// reducing nq * topk results
	for i := int64(0); i < nq; i++ {
		var (
			// cursor of current data of each subSearch for merging the j-th data of TopK.
			// sum(cursors) == j
			cursors = make([]int64, subSearchNum)

			j             int64
			idSet         = make(map[interface{}]struct{})
			groupByValSet = make(map[interface{}]struct{})
		)

		// skip offset results
		for k := int64(0); k < offset; k++ {
			subSearchIdx, _ := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
			if subSearchIdx == -1 {
				break
			}

			cursors[subSearchIdx]++
		}

		// keep limit results
		for j = 0; j < limit; {
			// From all the sub-query result sets of the i-th query vector,
			// find the sub-query result set index of the score j-th data,
			// and the index of the data in schemapb.SearchResultData
			subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
			if subSearchIdx == -1 {
				break
			}
			subSearchRes := subSearchResultData[subSearchIdx]

			id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
			score := subSearchRes.Scores[resultDataIdx]
			// groupByVal may be nil when the query has no group-by field.
			groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))

			// remove duplicates
			if _, ok := idSet[id]; !ok {
				// Only treat the entity as a group duplicate when a group-by
				// value is present and was already seen.
				groupByValExist := false
				if groupByVal != nil {
					_, groupByValExist = groupByValSet[groupByVal]
				}
				if !groupByValExist {
					retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
					typeutil.AppendPKs(ret.Results.Ids, id)
					ret.Results.Scores = append(ret.Results.Scores, score)
					idSet[id] = struct{}{}
					if groupByVal != nil {
						groupByValSet[groupByVal] = struct{}{}
						if err := typeutil.AppendGroupByValue(ret.Results, groupByVal, subSearchRes.GetGroupByFieldValue().GetType()); err != nil {
							log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
							return ret, err
						}
					}
					j++
				}
			} else {
				// skip entity with same id
				skipDupCnt++
			}
			cursors[subSearchIdx]++
		}
		if realTopK != -1 && realTopK != j {
			log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
			// return nil, errors.New("the length (topk) between all result of query is different")
		}
		realTopK = j
		ret.Results.Topks = append(ret.Results.Topks, realTopK)

		// limit search result to avoid oom
		if retSize > maxOutputSize {
			return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
		}
	}
	log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))

	if skipDupCnt > 0 {
		log.Info("skip duplicated search result", zap.Int64("count", skipDupCnt))
	}

	ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
	if !metric.PositivelyRelated(metricType) {
		for k := range ret.Results.Scores {
			ret.Results.Scores[k] *= -1
		}
	}
	return ret, nil
}
|
||||
|
||||
type rangeSearchParams struct {
|
||||
radius float64
|
||||
rangeFilter float64
|
||||
|
|
|
@ -35,6 +35,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
|
@ -1526,9 +1527,13 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
|||
results = append(results, r)
|
||||
}
|
||||
|
||||
queryInfo := &planpb.QueryInfo{
|
||||
GroupByFieldId: -1,
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset)
|
||||
reduced, err := reduceSearchResult(context.TODO(),
|
||||
NewReduceSearchResultInfo(results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset, queryInfo))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
|
||||
assert.Equal(t, []int64{test.limit, test.limit}, reduced.GetResults().GetTopks())
|
||||
|
@ -1577,10 +1582,10 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
|||
[]int64{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range lessThanLimitTests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset)
|
||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topk,
|
||||
metric.L2, schemapb.DataType_Int64, test.offset, queryInfo))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
|
||||
assert.Equal(t, []int64{test.outLimit, test.outLimit}, reduced.GetResults().GetTopks())
|
||||
|
@ -1604,7 +1609,12 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
|||
results = append(results, r)
|
||||
}
|
||||
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_Int64, 0)
|
||||
queryInfo := &planpb.QueryInfo{
|
||||
GroupByFieldId: -1,
|
||||
}
|
||||
|
||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(
|
||||
results, nq, topk, metric.L2, schemapb.DataType_Int64, 0, queryInfo))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetIntId().GetData())
|
||||
|
@ -1630,8 +1640,12 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
|||
|
||||
results = append(results, r)
|
||||
}
|
||||
queryInfo := &planpb.QueryInfo{
|
||||
GroupByFieldId: -1,
|
||||
}
|
||||
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_VarChar, 0)
|
||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results,
|
||||
nq, topk, metric.L2, schemapb.DataType_VarChar, 0, queryInfo))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetStrId().GetData())
|
||||
|
@ -1700,8 +1714,11 @@ func TestTaskSearch_reduceGroupBySearchResultData(t *testing.T) {
|
|||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topK, metric.L2, schemapb.DataType_Int64, 0)
|
||||
queryInfo := &planpb.QueryInfo{
|
||||
GroupByFieldId: 1,
|
||||
}
|
||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topK, metric.L2,
|
||||
schemapb.DataType_Int64, 0, queryInfo))
|
||||
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
|
||||
resultScores := reduced.GetResults().GetScores()
|
||||
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
|
||||
|
@ -1713,6 +1730,63 @@ func TestTaskSearch_reduceGroupBySearchResultData(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTaskSearch_reduceGroupBySearchResultDataWithOffset(t *testing.T) {
|
||||
var (
|
||||
nq int64 = 1
|
||||
limit int64 = 5
|
||||
offset int64 = 5
|
||||
)
|
||||
ids := [][]int64{
|
||||
{1, 3, 5, 7, 9},
|
||||
{2, 4, 6, 8, 10},
|
||||
}
|
||||
scores := [][]float32{
|
||||
{10, 8, 6, 4, 2},
|
||||
{9, 7, 5, 3, 1},
|
||||
}
|
||||
groupByValuesArr := [][]int64{
|
||||
{1, 3, 5, 7, 9},
|
||||
{2, 4, 6, 8, 10},
|
||||
}
|
||||
expectedIDs := []int64{6, 7, 8, 9, 10}
|
||||
expectedScores := []float32{-5, -4, -3, -2, -1}
|
||||
expectedGroupByValues := []int64{6, 7, 8, 9, 10}
|
||||
|
||||
var results []*schemapb.SearchResultData
|
||||
for j := range ids {
|
||||
result := getSearchResultData(nq, limit+offset)
|
||||
result.Ids.IdField = &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: ids[j]}}
|
||||
result.Scores = scores[j]
|
||||
result.Topks = []int64{limit}
|
||||
result.GroupByFieldValue = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int64,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: groupByValuesArr[j],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
queryInfo := &planpb.QueryInfo{
|
||||
GroupByFieldId: 1,
|
||||
}
|
||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, limit+offset, metric.L2,
|
||||
schemapb.DataType_Int64, offset, queryInfo))
|
||||
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
|
||||
resultScores := reduced.GetResults().GetScores()
|
||||
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
|
||||
assert.EqualValues(t, expectedIDs, resultIDs)
|
||||
assert.EqualValues(t, expectedScores, resultScores)
|
||||
assert.EqualValues(t, expectedGroupByValues, resultGroupByValues)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSearchTask_ErrExecute(t *testing.T) {
|
||||
var (
|
||||
err error
|
||||
|
@ -2367,7 +2441,9 @@ func TestSearchTask_Requery(t *testing.T) {
|
|||
qt.resultBuf.Insert(&internalpb.SearchResults{
|
||||
SlicedBlob: bytes,
|
||||
})
|
||||
|
||||
qt.queryInfo = &planpb.QueryInfo{
|
||||
GroupByFieldId: -1,
|
||||
}
|
||||
err = qt.PostExecute(ctx)
|
||||
t.Logf("err = %s", err)
|
||||
assert.Error(t, err)
|
||||
|
|
Loading…
Reference in New Issue