mirror of https://github.com/milvus-io/milvus.git
feat: supporing hybrid search group_by (#35982)
related: #35096 Signed-off-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: MrPresent-Han <chun.han@gmail.com>pull/36113/head
parent
62f4a6a112
commit
e480b103bd
|
@ -170,7 +170,8 @@ func TestMeta_ScalarAutoIndex(t *testing.T) {
|
|||
{
|
||||
Key: common.IndexTypeKey,
|
||||
Value: "HYBRID",
|
||||
}},
|
||||
},
|
||||
},
|
||||
Timestamp: 0,
|
||||
IsAutoIndex: true,
|
||||
UserIndexParams: userIndexParams,
|
||||
|
@ -205,7 +206,6 @@ func TestMeta_ScalarAutoIndex(t *testing.T) {
|
|||
assert.Equal(t, newIndexParams[0].Key, common.IndexTypeKey)
|
||||
assert.Equal(t, newIndexParams[0].Value, "INVERTED")
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestMeta_CanCreateIndex(t *testing.T) {
|
||||
|
|
|
@ -769,7 +769,7 @@ func (s *taskSchedulerSuite) scheduler(handler Handler) {
|
|||
return nil
|
||||
})
|
||||
catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil)
|
||||
//catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil)
|
||||
// catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
in := mocks.NewMockIndexNodeClient(s.T())
|
||||
in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(merr.Success(), nil)
|
||||
|
|
|
@ -63,7 +63,7 @@ func mergeSortMultipleSegments(ctx context.Context,
|
|||
return nil, err
|
||||
}
|
||||
|
||||
//SegmentDeserializeReaderTest(binlogPaths, t.binlogIO, writer.GetPkID())
|
||||
// SegmentDeserializeReaderTest(binlogPaths, t.binlogIO, writer.GetPkID())
|
||||
segmentReaders := make([]*SegmentDeserializeReader, len(binlogs))
|
||||
for i, s := range binlogs {
|
||||
var binlogBatchCount int
|
||||
|
|
|
@ -119,7 +119,6 @@ func (s *PriorityQueueSuite) PriorityQueueMergeSort() {
|
|||
heap.Push(&pq, next)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestNewPriorityQueueSuite(t *testing.T) {
|
||||
|
|
|
@ -94,11 +94,8 @@ message SubSearchRequest {
|
|||
int64 topk = 7;
|
||||
int64 offset = 8;
|
||||
string metricType = 9;
|
||||
}
|
||||
|
||||
message ExtraSearchParam {
|
||||
int64 group_by_field_id = 1;
|
||||
int64 group_size = 2;
|
||||
int64 group_by_field_id = 10;
|
||||
int64 group_size = 11;
|
||||
}
|
||||
|
||||
message SearchRequest {
|
||||
|
@ -125,7 +122,8 @@ message SearchRequest {
|
|||
bool is_advanced = 20;
|
||||
int64 offset = 21;
|
||||
common.ConsistencyLevel consistency_level = 22;
|
||||
ExtraSearchParam extra_search_param = 23;
|
||||
int64 group_by_field_id = 23;
|
||||
int64 group_size = 24;
|
||||
}
|
||||
|
||||
message SubSearchResults {
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
|
@ -20,54 +20,137 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type reduceSearchResultInfo struct {
|
||||
subSearchResultData []*schemapb.SearchResultData
|
||||
nq int64
|
||||
topK int64
|
||||
metricType string
|
||||
pkType schemapb.DataType
|
||||
offset int64
|
||||
queryInfo *planpb.QueryInfo
|
||||
}
|
||||
|
||||
func NewReduceSearchResultInfo(
|
||||
subSearchResultData []*schemapb.SearchResultData,
|
||||
nq int64,
|
||||
topK int64,
|
||||
metricType string,
|
||||
pkType schemapb.DataType,
|
||||
offset int64,
|
||||
queryInfo *planpb.QueryInfo,
|
||||
) *reduceSearchResultInfo {
|
||||
return &reduceSearchResultInfo{
|
||||
subSearchResultData: subSearchResultData,
|
||||
nq: nq,
|
||||
topK: topK,
|
||||
metricType: metricType,
|
||||
pkType: pkType,
|
||||
offset: offset,
|
||||
queryInfo: queryInfo,
|
||||
}
|
||||
}
|
||||
|
||||
func reduceSearchResult(ctx context.Context, reduceInfo *reduceSearchResultInfo) (*milvuspb.SearchResults, error) {
|
||||
if reduceInfo.queryInfo.GroupByFieldId > 0 {
|
||||
func reduceSearchResult(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, reduceInfo *reduce.ResultInfo) (*milvuspb.SearchResults, error) {
|
||||
if reduceInfo.GetGroupByFieldId() > 0 {
|
||||
if reduceInfo.GetIsAdvance() {
|
||||
// for hybrid search group by, we cannot reduce result for results from one single search path,
|
||||
// because the final score has not been accumulated, also, offset cannot be applied
|
||||
return reduceAdvanceGroupBY(ctx,
|
||||
subSearchResultData, reduceInfo.GetNq(), reduceInfo.GetTopK(), reduceInfo.GetPkType(), reduceInfo.GetMetricType())
|
||||
}
|
||||
return reduceSearchResultDataWithGroupBy(ctx,
|
||||
reduceInfo.subSearchResultData,
|
||||
reduceInfo.nq,
|
||||
reduceInfo.topK,
|
||||
reduceInfo.metricType,
|
||||
reduceInfo.pkType,
|
||||
reduceInfo.offset,
|
||||
reduceInfo.queryInfo.GroupSize)
|
||||
subSearchResultData,
|
||||
reduceInfo.GetNq(),
|
||||
reduceInfo.GetTopK(),
|
||||
reduceInfo.GetMetricType(),
|
||||
reduceInfo.GetPkType(),
|
||||
reduceInfo.GetOffset(),
|
||||
reduceInfo.GetGroupSize())
|
||||
}
|
||||
return reduceSearchResultDataNoGroupBy(ctx,
|
||||
reduceInfo.subSearchResultData,
|
||||
reduceInfo.nq,
|
||||
reduceInfo.topK,
|
||||
reduceInfo.metricType,
|
||||
reduceInfo.pkType,
|
||||
reduceInfo.offset)
|
||||
subSearchResultData,
|
||||
reduceInfo.GetNq(),
|
||||
reduceInfo.GetTopK(),
|
||||
reduceInfo.GetMetricType(),
|
||||
reduceInfo.GetPkType(),
|
||||
reduceInfo.GetOffset())
|
||||
}
|
||||
|
||||
func checkResultDatas(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
|
||||
nq int64, topK int64,
|
||||
) (int64, int, error) {
|
||||
var allSearchCount int64
|
||||
var hitNum int
|
||||
for i, sData := range subSearchResultData {
|
||||
pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
|
||||
log.Ctx(ctx).Debug("subSearchResultData",
|
||||
zap.Int("result No.", i),
|
||||
zap.Int64("nq", sData.NumQueries),
|
||||
zap.Int64("topk", sData.TopK),
|
||||
zap.Int("length of pks", pkLength),
|
||||
zap.Int("length of FieldsData", len(sData.FieldsData)))
|
||||
allSearchCount += sData.GetAllSearchCount()
|
||||
hitNum += pkLength
|
||||
if err := checkSearchResultData(sData, nq, topK, pkLength); err != nil {
|
||||
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
||||
return allSearchCount, hitNum, err
|
||||
}
|
||||
}
|
||||
return allSearchCount, hitNum, nil
|
||||
}
|
||||
|
||||
func reduceAdvanceGroupBY(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
|
||||
nq int64, topK int64, pkType schemapb.DataType, metricType string,
|
||||
) (*milvuspb.SearchResults, error) {
|
||||
log.Ctx(ctx).Debug("reduceAdvanceGroupBY", zap.Int("len(subSearchResultData)", len(subSearchResultData)), zap.Int64("nq", nq))
|
||||
// for advance group by, offset is not applied, so just return when there's only one channel
|
||||
if len(subSearchResultData) == 1 {
|
||||
return &milvuspb.SearchResults{
|
||||
Status: merr.Success(),
|
||||
Results: subSearchResultData[0],
|
||||
}, nil
|
||||
}
|
||||
|
||||
ret := &milvuspb.SearchResults{
|
||||
Status: merr.Success(),
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: nq,
|
||||
TopK: topK,
|
||||
Scores: []float32{},
|
||||
Ids: &schemapb.IDs{},
|
||||
Topks: []int64{},
|
||||
},
|
||||
}
|
||||
|
||||
var limit int64
|
||||
if allSearchCount, hitNum, err := checkResultDatas(ctx, subSearchResultData, nq, topK); err != nil {
|
||||
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
||||
return ret, err
|
||||
} else {
|
||||
ret.GetResults().AllSearchCount = allSearchCount
|
||||
limit = int64(hitNum)
|
||||
ret.GetResults().FieldsData = typeutil.PrepareResultFieldData(subSearchResultData[0].GetFieldsData(), limit)
|
||||
}
|
||||
|
||||
if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
var (
|
||||
subSearchNum = len(subSearchResultData)
|
||||
// for results of each subSearchResultData, storing the start offset of each query of nq queries
|
||||
subSearchNqOffset = make([][]int64, subSearchNum)
|
||||
)
|
||||
for i := 0; i < subSearchNum; i++ {
|
||||
subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
|
||||
for j := int64(1); j < nq; j++ {
|
||||
subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
|
||||
}
|
||||
}
|
||||
// reducing nq * topk results
|
||||
for nqIdx := int64(0); nqIdx < nq; nqIdx++ {
|
||||
dataCount := int64(0)
|
||||
for subIdx := 0; subIdx < subSearchNum; subIdx += 1 {
|
||||
subData := subSearchResultData[subIdx]
|
||||
subPks := subData.GetIds()
|
||||
subScores := subData.GetScores()
|
||||
subGroupByVals := subData.GetGroupByFieldValue()
|
||||
|
||||
nqTopK := subData.Topks[nqIdx]
|
||||
for i := int64(0); i < nqTopK; i++ {
|
||||
innerIdx := subSearchNqOffset[subIdx][nqIdx] + i
|
||||
pk := typeutil.GetPK(subPks, innerIdx)
|
||||
score := subScores[innerIdx]
|
||||
groupByVal := typeutil.GetData(subData.GetGroupByFieldValue(), int(innerIdx))
|
||||
typeutil.AppendPKs(ret.Results.Ids, pk)
|
||||
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||
if err := typeutil.AppendGroupByValue(ret.Results, groupByVal, subGroupByVals.GetType()); err != nil {
|
||||
log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
|
||||
return ret, err
|
||||
}
|
||||
dataCount += 1
|
||||
}
|
||||
}
|
||||
ret.Results.Topks = append(ret.Results.Topks, dataCount)
|
||||
}
|
||||
|
||||
ret.Results.TopK = topK // realTopK is the topK of the nq-th query
|
||||
if !metric.PositivelyRelated(metricType) {
|
||||
for k := range ret.Results.Scores {
|
||||
ret.Results.Scores[k] *= -1
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
type MilvusPKType interface{}
|
||||
|
@ -109,37 +192,16 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
|
|||
Topks: []int64{},
|
||||
},
|
||||
}
|
||||
|
||||
switch pkType {
|
||||
case schemapb.DataType_Int64:
|
||||
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: make([]int64, 0, limit),
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_VarChar:
|
||||
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: make([]string, 0, limit),
|
||||
},
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("unsupported pk type")
|
||||
groupBound := groupSize * limit
|
||||
if err := setupIdListForSearchResult(ret, pkType, groupBound); err != nil {
|
||||
return ret, nil
|
||||
}
|
||||
for i, sData := range subSearchResultData {
|
||||
pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
|
||||
log.Ctx(ctx).Debug("subSearchResultData",
|
||||
zap.Int("result No.", i),
|
||||
zap.Int64("nq", sData.NumQueries),
|
||||
zap.Int64("topk", sData.TopK),
|
||||
zap.Int("length of pks", pkLength),
|
||||
zap.Int("length of FieldsData", len(sData.FieldsData)))
|
||||
ret.Results.AllSearchCount += sData.GetAllSearchCount()
|
||||
if err := checkSearchResultData(sData, nq, topk); err != nil {
|
||||
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
||||
return ret, err
|
||||
}
|
||||
// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
|
||||
|
||||
if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil {
|
||||
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
||||
return ret, err
|
||||
} else {
|
||||
ret.GetResults().AllSearchCount = allSearchCount
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -163,7 +225,6 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
|
|||
|
||||
var retSize int64
|
||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||
groupBound := groupSize * limit
|
||||
|
||||
// reducing nq * topk results
|
||||
for i := int64(0); i < nq; i++ {
|
||||
|
@ -298,36 +359,15 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
|
|||
},
|
||||
}
|
||||
|
||||
switch pkType {
|
||||
case schemapb.DataType_Int64:
|
||||
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: make([]int64, 0, limit),
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_VarChar:
|
||||
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: make([]string, 0, limit),
|
||||
},
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("unsupported pk type")
|
||||
if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
|
||||
return ret, nil
|
||||
}
|
||||
for i, sData := range subSearchResultData {
|
||||
pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
|
||||
log.Ctx(ctx).Debug("subSearchResultData",
|
||||
zap.Int("result No.", i),
|
||||
zap.Int64("nq", sData.NumQueries),
|
||||
zap.Int64("topk", sData.TopK),
|
||||
zap.Int("length of pks", pkLength),
|
||||
zap.Int("length of FieldsData", len(sData.FieldsData)))
|
||||
ret.Results.AllSearchCount += sData.GetAllSearchCount()
|
||||
if err := checkSearchResultData(sData, nq, topk); err != nil {
|
||||
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
||||
return ret, err
|
||||
}
|
||||
// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
|
||||
|
||||
if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil {
|
||||
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
||||
return ret, err
|
||||
} else {
|
||||
ret.GetResults().AllSearchCount = allSearchCount
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -428,23 +468,215 @@ func rankSearchResultData(ctx context.Context,
|
|||
params *rankParams,
|
||||
pkType schemapb.DataType,
|
||||
searchResults []*milvuspb.SearchResults,
|
||||
groupByFieldID int64,
|
||||
groupSize int64,
|
||||
groupScorer func(group *Group) error,
|
||||
) (*milvuspb.SearchResults, error) {
|
||||
tr := timerecord.NewTimeRecorder("rankSearchResultData")
|
||||
if groupByFieldID > 0 {
|
||||
return rankSearchResultDataByGroup(ctx, nq, params, pkType, searchResults, groupScorer, groupSize)
|
||||
}
|
||||
return rankSearchResultDataByPk(ctx, nq, params, pkType, searchResults)
|
||||
}
|
||||
|
||||
func compareKey(keyI interface{}, keyJ interface{}) bool {
|
||||
switch keyI.(type) {
|
||||
case int64:
|
||||
return keyI.(int64) < keyJ.(int64)
|
||||
case string:
|
||||
return keyI.(string) < keyJ.(string)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func GetGroupScorer(scorerType string) (func(group *Group) error, error) {
|
||||
switch scorerType {
|
||||
case MaxScorer:
|
||||
return func(group *Group) error {
|
||||
group.finalScore = group.maxScore
|
||||
return nil
|
||||
}, nil
|
||||
case SumScorer:
|
||||
return func(group *Group) error {
|
||||
group.finalScore = group.sumScore
|
||||
return nil
|
||||
}, nil
|
||||
case AvgScorer:
|
||||
return func(group *Group) error {
|
||||
if len(group.idList) == 0 {
|
||||
return merr.WrapErrParameterInvalid(1, len(group.idList),
|
||||
"input group for score must have at least one id, must be sth wrong within code")
|
||||
}
|
||||
group.finalScore = group.sumScore / float32(len(group.idList))
|
||||
return nil
|
||||
}, nil
|
||||
default:
|
||||
return nil, merr.WrapErrParameterInvalidMsg("input group scorer type: %s is not supported!", scorerType)
|
||||
}
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
idList []interface{}
|
||||
scoreList []float32
|
||||
groupVal interface{}
|
||||
maxScore float32
|
||||
sumScore float32
|
||||
finalScore float32
|
||||
}
|
||||
|
||||
func rankSearchResultDataByGroup(ctx context.Context,
|
||||
nq int64,
|
||||
params *rankParams,
|
||||
pkType schemapb.DataType,
|
||||
searchResults []*milvuspb.SearchResults,
|
||||
groupScorer func(group *Group) error,
|
||||
groupSize int64,
|
||||
) (*milvuspb.SearchResults, error) {
|
||||
tr := timerecord.NewTimeRecorder("rankSearchResultDataByGroup")
|
||||
defer func() {
|
||||
tr.CtxElapse(ctx, "done")
|
||||
}()
|
||||
|
||||
offset := params.offset
|
||||
limit := params.limit
|
||||
topk := limit + offset
|
||||
roundDecimal := params.roundDecimal
|
||||
log.Ctx(ctx).Debug("rankSearchResultData",
|
||||
offset, limit, roundDecimal := params.offset, params.limit, params.roundDecimal
|
||||
// in the context of group by, the meaning for offset/limit/top refers to related numbers of group
|
||||
groupTopK := limit + offset
|
||||
log.Ctx(ctx).Debug("rankSearchResultDataByGroup",
|
||||
zap.Int("len(searchResults)", len(searchResults)),
|
||||
zap.Int64("nq", nq),
|
||||
zap.Int64("offset", offset),
|
||||
zap.Int64("limit", limit))
|
||||
|
||||
ret := &milvuspb.SearchResults{
|
||||
var ret *milvuspb.SearchResults
|
||||
if ret = initSearchResults(nq, limit); len(searchResults) == 0 {
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
totalCount := limit * groupSize
|
||||
if err := setupIdListForSearchResult(ret, pkType, totalCount); err != nil {
|
||||
return ret, err
|
||||
}
|
||||
|
||||
type accumulateIDGroupVal struct {
|
||||
accumulatedScore float32
|
||||
groupVal interface{}
|
||||
}
|
||||
|
||||
accumulatedScores := make([]map[interface{}]*accumulateIDGroupVal, nq)
|
||||
for i := int64(0); i < nq; i++ {
|
||||
accumulatedScores[i] = make(map[interface{}]*accumulateIDGroupVal)
|
||||
}
|
||||
groupByDataType := searchResults[0].GetResults().GetGroupByFieldValue().GetType()
|
||||
for _, result := range searchResults {
|
||||
scores := result.GetResults().GetScores()
|
||||
start := 0
|
||||
// milvus has limits for the value range of nq and limit
|
||||
// no matter on 32-bit and 64-bit platform, converting nq and topK into int is safe
|
||||
for i := 0; i < int(nq); i++ {
|
||||
realTopK := int(result.GetResults().Topks[i])
|
||||
for j := start; j < start+realTopK; j++ {
|
||||
id := typeutil.GetPK(result.GetResults().GetIds(), int64(j))
|
||||
groupByVal := typeutil.GetData(result.GetResults().GetGroupByFieldValue(), j)
|
||||
if accumulatedScores[i][id] != nil {
|
||||
accumulatedScores[i][id].accumulatedScore += scores[j]
|
||||
} else {
|
||||
accumulatedScores[i][id] = &accumulateIDGroupVal{accumulatedScore: scores[j], groupVal: groupByVal}
|
||||
}
|
||||
}
|
||||
start += realTopK
|
||||
}
|
||||
}
|
||||
|
||||
for i := int64(0); i < nq; i++ {
|
||||
idSet := accumulatedScores[i]
|
||||
keys := make([]interface{}, 0)
|
||||
for key := range idSet {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
// sort id by score
|
||||
big := func(i, j int) bool {
|
||||
scoreItemI := idSet[keys[i]]
|
||||
scoreItemJ := idSet[keys[j]]
|
||||
if scoreItemI.accumulatedScore == scoreItemJ.accumulatedScore {
|
||||
return compareKey(keys[i], keys[j])
|
||||
}
|
||||
return scoreItemI.accumulatedScore > scoreItemJ.accumulatedScore
|
||||
}
|
||||
sort.Slice(keys, big)
|
||||
|
||||
// separate keys into buckets according to groupVal
|
||||
buckets := make(map[interface{}]*Group)
|
||||
for _, key := range keys {
|
||||
scoreItem := idSet[key]
|
||||
groupVal := scoreItem.groupVal
|
||||
if buckets[groupVal] == nil {
|
||||
buckets[groupVal] = &Group{
|
||||
idList: make([]interface{}, 0),
|
||||
scoreList: make([]float32, 0),
|
||||
groupVal: groupVal,
|
||||
}
|
||||
}
|
||||
if int64(len(buckets[groupVal].idList)) >= groupSize {
|
||||
// only consider group size results in each group
|
||||
continue
|
||||
}
|
||||
buckets[groupVal].idList = append(buckets[groupVal].idList, key)
|
||||
buckets[groupVal].scoreList = append(buckets[groupVal].scoreList, scoreItem.accumulatedScore)
|
||||
if scoreItem.accumulatedScore > buckets[groupVal].maxScore {
|
||||
buckets[groupVal].maxScore = scoreItem.accumulatedScore
|
||||
}
|
||||
buckets[groupVal].sumScore += scoreItem.accumulatedScore
|
||||
}
|
||||
if int64(len(buckets)) <= offset {
|
||||
ret.Results.Topks = append(ret.Results.Topks, 0)
|
||||
continue
|
||||
}
|
||||
|
||||
groupList := make([]*Group, len(buckets))
|
||||
idx := 0
|
||||
for _, group := range buckets {
|
||||
groupScorer(group)
|
||||
groupList[idx] = group
|
||||
idx += 1
|
||||
}
|
||||
sort.Slice(groupList, func(i, j int) bool {
|
||||
if groupList[i].finalScore == groupList[j].finalScore {
|
||||
if len(groupList[i].idList) == len(groupList[j].idList) {
|
||||
// if final score and size of group are both equal
|
||||
// choose the group with smaller first key
|
||||
// here, it's guaranteed all group having at least one id in the idList
|
||||
return compareKey(groupList[i].idList[0], groupList[j].idList[0])
|
||||
}
|
||||
// choose the larger group when scores are equal
|
||||
return len(groupList[i].idList) > len(groupList[j].idList)
|
||||
}
|
||||
return groupList[i].finalScore > groupList[j].finalScore
|
||||
})
|
||||
|
||||
if int64(len(groupList)) > groupTopK {
|
||||
groupList = groupList[:groupTopK]
|
||||
}
|
||||
returnedRowNum := 0
|
||||
for index := int(offset); index < len(groupList); index++ {
|
||||
group := groupList[index]
|
||||
for i, score := range group.scoreList {
|
||||
// idList and scoreList must have same length
|
||||
typeutil.AppendPKs(ret.Results.Ids, group.idList[i])
|
||||
if roundDecimal != -1 {
|
||||
multiplier := math.Pow(10.0, float64(roundDecimal))
|
||||
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
|
||||
}
|
||||
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||
typeutil.AppendGroupByValue(ret.Results, group.groupVal, groupByDataType)
|
||||
}
|
||||
returnedRowNum += len(group.idList)
|
||||
}
|
||||
ret.Results.Topks = append(ret.Results.Topks, int64(returnedRowNum))
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func initSearchResults(nq int64, limit int64) *milvuspb.SearchResults {
|
||||
return &milvuspb.SearchResults{
|
||||
Status: merr.Success(),
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: nq,
|
||||
|
@ -455,22 +687,54 @@ func rankSearchResultData(ctx context.Context,
|
|||
Topks: []int64{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func setupIdListForSearchResult(searchResult *milvuspb.SearchResults, pkType schemapb.DataType, capacity int64) error {
|
||||
switch pkType {
|
||||
case schemapb.DataType_Int64:
|
||||
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
|
||||
searchResult.GetResults().Ids.IdField = &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: make([]int64, 0),
|
||||
Data: make([]int64, 0, capacity),
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_VarChar:
|
||||
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
|
||||
searchResult.GetResults().Ids.IdField = &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: make([]string, 0),
|
||||
Data: make([]string, 0, capacity),
|
||||
},
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("unsupported pk type")
|
||||
return errors.New("unsupported pk type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func rankSearchResultDataByPk(ctx context.Context,
|
||||
nq int64,
|
||||
params *rankParams,
|
||||
pkType schemapb.DataType,
|
||||
searchResults []*milvuspb.SearchResults,
|
||||
) (*milvuspb.SearchResults, error) {
|
||||
tr := timerecord.NewTimeRecorder("rankSearchResultDataByPk")
|
||||
defer func() {
|
||||
tr.CtxElapse(ctx, "done")
|
||||
}()
|
||||
|
||||
offset, limit, roundDecimal := params.offset, params.limit, params.roundDecimal
|
||||
topk := limit + offset
|
||||
log.Ctx(ctx).Debug("rankSearchResultDataByPk",
|
||||
zap.Int("len(searchResults)", len(searchResults)),
|
||||
zap.Int64("nq", nq),
|
||||
zap.Int64("offset", offset),
|
||||
zap.Int64("limit", limit))
|
||||
|
||||
var ret *milvuspb.SearchResults
|
||||
if ret = initSearchResults(nq, limit); len(searchResults) == 0 {
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// []map[id]score
|
||||
|
@ -503,20 +767,10 @@ func rankSearchResultData(ctx context.Context,
|
|||
continue
|
||||
}
|
||||
|
||||
compareKeys := func(keyI, keyJ interface{}) bool {
|
||||
switch keyI.(type) {
|
||||
case int64:
|
||||
return keyI.(int64) < keyJ.(int64)
|
||||
case string:
|
||||
return keyI.(string) < keyJ.(string)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// sort id by score
|
||||
big := func(i, j int) bool {
|
||||
if idSet[keys[i]] == idSet[keys[j]] {
|
||||
return compareKeys(keys[i], keys[j])
|
||||
return compareKey(keys[i], keys[j])
|
||||
}
|
||||
return idSet[keys[i]] > idSet[keys[j]]
|
||||
}
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
type SearchReduceUtilTestSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (struts *SearchReduceUtilTestSuite) TestRankByGroup() {
|
||||
var searchResultData1 *schemapb.SearchResultData
|
||||
var searchResultData2 *schemapb.SearchResultData
|
||||
|
||||
{
|
||||
groupFieldValue := []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa"}
|
||||
searchResultData1 = &schemapb.SearchResultData{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"7", "5", "4", "2", "3", "6", "1", "9", "8"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Topks: []int64{9},
|
||||
Scores: []float32{0.6, 0.53, 0.52, 0.43, 0.41, 0.33, 0.30, 0.27, 0.22},
|
||||
GroupByFieldValue: getFieldData("string", int64(101), schemapb.DataType_VarChar, groupFieldValue, 1),
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
groupFieldValue := []string{"www", "aaa", "ccc", "www", "www", "ccc", "aaa", "ccc", "aaa"}
|
||||
searchResultData2 = &schemapb.SearchResultData{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"17", "15", "14", "12", "13", "16", "11", "19", "18"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Topks: []int64{9},
|
||||
Scores: []float32{0.7, 0.43, 0.32, 0.32, 0.31, 0.31, 0.30, 0.30, 0.30},
|
||||
GroupByFieldValue: getFieldData("string", int64(101), schemapb.DataType_VarChar, groupFieldValue, 1),
|
||||
}
|
||||
}
|
||||
|
||||
searchResults := []*milvuspb.SearchResults{
|
||||
{Results: searchResultData1},
|
||||
{Results: searchResultData2},
|
||||
}
|
||||
|
||||
nq := int64(1)
|
||||
limit := int64(3)
|
||||
offset := int64(0)
|
||||
roundDecimal := int64(1)
|
||||
groupSize := int64(3)
|
||||
groupByFieldId := int64(101)
|
||||
rankParams := &rankParams{limit: limit, offset: offset, roundDecimal: roundDecimal}
|
||||
|
||||
{
|
||||
// test for sum group scorer
|
||||
scorerType := "sum"
|
||||
groupScorer, _ := GetGroupScorer(scorerType)
|
||||
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
|
||||
struts.NoError(err)
|
||||
struts.Equal([]string{"5", "2", "3", "17", "12", "13", "7", "15", "1"}, rankedRes.GetResults().GetIds().GetStrId().Data)
|
||||
struts.Equal([]float32{0.5, 0.4, 0.4, 0.7, 0.3, 0.3, 0.6, 0.4, 0.3}, rankedRes.GetResults().GetScores())
|
||||
struts.Equal([]string{"bbb", "bbb", "bbb", "www", "www", "www", "aaa", "aaa", "aaa"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
|
||||
}
|
||||
|
||||
{
|
||||
// test for max group scorer
|
||||
scorerType := "max"
|
||||
groupScorer, _ := GetGroupScorer(scorerType)
|
||||
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
|
||||
struts.NoError(err)
|
||||
struts.Equal([]string{"17", "12", "13", "7", "15", "1", "5", "2", "3"}, rankedRes.GetResults().GetIds().GetStrId().Data)
|
||||
struts.Equal([]float32{0.7, 0.3, 0.3, 0.6, 0.4, 0.3, 0.5, 0.4, 0.4}, rankedRes.GetResults().GetScores())
|
||||
struts.Equal([]string{"www", "www", "www", "aaa", "aaa", "aaa", "bbb", "bbb", "bbb"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
|
||||
}
|
||||
|
||||
{
|
||||
// test for avg group scorer
|
||||
scorerType := "avg"
|
||||
groupScorer, _ := GetGroupScorer(scorerType)
|
||||
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
|
||||
struts.NoError(err)
|
||||
struts.Equal([]string{"5", "2", "3", "17", "12", "13", "7", "15", "1"}, rankedRes.GetResults().GetIds().GetStrId().Data)
|
||||
struts.Equal([]float32{0.5, 0.4, 0.4, 0.7, 0.3, 0.3, 0.6, 0.4, 0.3}, rankedRes.GetResults().GetScores())
|
||||
struts.Equal([]string{"bbb", "bbb", "bbb", "www", "www", "www", "aaa", "aaa", "aaa"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
|
||||
}
|
||||
|
||||
{
|
||||
// test for offset for ranking group
|
||||
scorerType := "avg"
|
||||
groupScorer, _ := GetGroupScorer(scorerType)
|
||||
rankParams.offset = 2
|
||||
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
|
||||
struts.NoError(err)
|
||||
struts.Equal([]string{"7", "15", "1", "4", "6", "14"}, rankedRes.GetResults().GetIds().GetStrId().Data)
|
||||
struts.Equal([]float32{0.6, 0.4, 0.3, 0.5, 0.3, 0.3}, rankedRes.GetResults().GetScores())
|
||||
struts.Equal([]string{"aaa", "aaa", "aaa", "ccc", "ccc", "ccc"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
|
||||
}
|
||||
|
||||
{
|
||||
// test for offset exceeding the count of final groups
|
||||
scorerType := "avg"
|
||||
groupScorer, _ := GetGroupScorer(scorerType)
|
||||
rankParams.offset = 4
|
||||
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
|
||||
struts.NoError(err)
|
||||
struts.Equal([]string{}, rankedRes.GetResults().GetIds().GetStrId().Data)
|
||||
struts.Equal([]float32{}, rankedRes.GetResults().GetScores())
|
||||
}
|
||||
|
||||
{
|
||||
// test for invalid group scorer
|
||||
scorerType := "xxx"
|
||||
groupScorer, err := GetGroupScorer(scorerType)
|
||||
struts.Error(err)
|
||||
struts.Nil(groupScorer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchReduceUtilTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(SearchReduceUtilTestSuite))
|
||||
}
|
|
@ -310,13 +310,15 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro
|
|||
}
|
||||
|
||||
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
|
||||
searchParams := make([]*commonpb.KeyValuePair, len(req.GetRankParams()))
|
||||
copy(searchParams, req.GetRankParams())
|
||||
ret := &milvuspb.SearchRequest{
|
||||
Base: req.GetBase(),
|
||||
DbName: req.GetDbName(),
|
||||
CollectionName: req.GetCollectionName(),
|
||||
PartitionNames: req.GetPartitionNames(),
|
||||
OutputFields: req.GetOutputFields(),
|
||||
SearchParams: req.GetRankParams(),
|
||||
SearchParams: searchParams,
|
||||
TravelTimestamp: req.GetTravelTimestamp(),
|
||||
GuaranteeTimestamp: req.GetGuaranteeTimestamp(),
|
||||
Nq: 0,
|
||||
|
|
|
@ -42,6 +42,12 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
const (
|
||||
SumScorer string = "sum"
|
||||
MaxScorer string = "max"
|
||||
AvgScorer string = "avg"
|
||||
)
|
||||
|
||||
const (
|
||||
IgnoreGrowingKey = "ignore_growing"
|
||||
ReduceStopForBestKey = "reduce_stop_for_best"
|
||||
|
@ -49,6 +55,7 @@ const (
|
|||
GroupByFieldKey = "group_by_field"
|
||||
GroupSizeKey = "group_size"
|
||||
GroupStrictSize = "group_strict_size"
|
||||
RankGroupScorer = "rank_group_scorer"
|
||||
AnnsFieldKey = "anns_field"
|
||||
TopKKey = "topk"
|
||||
NQKey = "nq"
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/exprutil"
|
||||
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
|
@ -76,8 +77,9 @@ type searchTask struct {
|
|||
queryInfos []*planpb.QueryInfo
|
||||
relatedDataSize int64
|
||||
|
||||
reScorers []reScorer
|
||||
rankParams *rankParams
|
||||
reScorers []reScorer
|
||||
rankParams *rankParams
|
||||
groupScorer func(group *Group) error
|
||||
}
|
||||
|
||||
func (t *searchTask) CanSkipAllocTimestamp() bool {
|
||||
|
@ -339,10 +341,9 @@ func setQueryInfoIfMvEnable(queryInfo *planpb.QueryInfo, t *searchTask, plan *pl
|
|||
func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init advanced search request")
|
||||
defer sp.End()
|
||||
|
||||
t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]()
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
|
||||
|
||||
// fetch search_growing from search param
|
||||
t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
|
||||
t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
|
||||
|
@ -351,9 +352,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if queryInfo.GetGroupByFieldId() != -1 {
|
||||
return errors.New("not support search_group_by operation in the hybrid search")
|
||||
}
|
||||
|
||||
internalSubReq := &internalpb.SubSearchRequest{
|
||||
Dsl: subReq.GetDsl(),
|
||||
PlaceholderGroup: subReq.GetPlaceholderGroup(),
|
||||
|
@ -364,6 +363,8 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||
Topk: queryInfo.GetTopk(),
|
||||
Offset: offset,
|
||||
MetricType: queryInfo.GetMetricType(),
|
||||
GroupByFieldId: queryInfo.GetGroupByFieldId(),
|
||||
GroupSize: queryInfo.GetGroupSize(),
|
||||
}
|
||||
|
||||
// set PartitionIDs for sub search
|
||||
|
@ -403,6 +404,11 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
||||
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
||||
}
|
||||
if len(t.queryInfos) > 0 {
|
||||
t.SearchRequest.GroupByFieldId = t.queryInfos[0].GetGroupByFieldId()
|
||||
t.SearchRequest.GroupSize = t.queryInfos[0].GetGroupSize()
|
||||
}
|
||||
|
||||
// used for requery
|
||||
if t.partitionKeyMode {
|
||||
t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect()
|
||||
|
@ -413,6 +419,18 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||
log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// set up groupScorer for hybridsearch+groupBy
|
||||
groupScorerStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankGroupScorer, t.request.GetSearchParams())
|
||||
if err != nil {
|
||||
groupScorerStr = MaxScorer
|
||||
}
|
||||
groupScorer, err := GetGroupScorer(groupScorerStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.groupScorer = groupScorer
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -461,7 +479,8 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
|||
t.SearchRequest.MetricType = queryInfo.GetMetricType()
|
||||
t.queryInfos = append(t.queryInfos, queryInfo)
|
||||
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
|
||||
t.SearchRequest.ExtraSearchParam = &internalpb.ExtraSearchParam{GroupByFieldId: queryInfo.GroupByFieldId, GroupSize: queryInfo.GroupSize}
|
||||
t.SearchRequest.GroupByFieldId = queryInfo.GroupByFieldId
|
||||
t.SearchRequest.GroupSize = queryInfo.GroupSize
|
||||
log.Debug("proxy init search request",
|
||||
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
||||
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
||||
|
@ -554,7 +573,7 @@ func (t *searchTask) Execute(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, queryInfo *planpb.QueryInfo) (*milvuspb.SearchResults, error) {
|
||||
func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, queryInfo *planpb.QueryInfo, isAdvance bool) (*milvuspb.SearchResults, error) {
|
||||
metricType := ""
|
||||
if len(toReduceResults) >= 1 {
|
||||
metricType = toReduceResults[0].GetMetricType()
|
||||
|
@ -585,8 +604,8 @@ func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*inter
|
|||
return nil, err
|
||||
}
|
||||
var result *milvuspb.SearchResults
|
||||
result, err = reduceSearchResult(ctx, NewReduceSearchResultInfo(validSearchResults, nq, topK,
|
||||
metricType, primaryFieldSchema.DataType, offset, queryInfo))
|
||||
result, err = reduceSearchResult(ctx, validSearchResults, reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metricType).WithPkType(primaryFieldSchema.GetDataType()).
|
||||
WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()).WithAdvance(isAdvance))
|
||||
if err != nil {
|
||||
log.Warn("failed to reduce search results", zap.Error(err))
|
||||
return nil, err
|
||||
|
@ -647,7 +666,6 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||
multipleInternalResults[reqIndex] = append(multipleInternalResults[reqIndex], internalResults)
|
||||
}
|
||||
}
|
||||
|
||||
multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs()))
|
||||
for index, internalResults := range multipleInternalResults {
|
||||
subReq := t.SearchRequest.GetSubReqs()[index]
|
||||
|
@ -656,7 +674,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||
if len(internalResults) >= 1 {
|
||||
metricType = internalResults[0].GetMetricType()
|
||||
}
|
||||
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), t.queryInfos[index])
|
||||
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), t.queryInfos[index], true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -667,13 +685,16 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||
t.result, err = rankSearchResultData(ctx, t.SearchRequest.GetNq(),
|
||||
t.rankParams,
|
||||
primaryFieldSchema.GetDataType(),
|
||||
multipleMilvusResults)
|
||||
multipleMilvusResults,
|
||||
t.SearchRequest.GetGroupByFieldId(),
|
||||
t.SearchRequest.GetGroupSize(),
|
||||
t.groupScorer)
|
||||
if err != nil {
|
||||
log.Warn("rank search result failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), t.queryInfos[0])
|
||||
t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), t.queryInfos[0], false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -914,7 +935,7 @@ func decodeSearchResults(ctx context.Context, searchResults []*internalpb.Search
|
|||
return results, nil
|
||||
}
|
||||
|
||||
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64) error {
|
||||
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64, pkHitNum int) error {
|
||||
if data.NumQueries != nq {
|
||||
return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
|
||||
}
|
||||
|
@ -922,7 +943,6 @@ func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64
|
|||
return fmt.Errorf("search result's topk(%d) mis-match with %d", data.TopK, topk)
|
||||
}
|
||||
|
||||
pkHitNum := typeutil.GetSizeOfIDs(data.GetIds())
|
||||
if len(data.Scores) != pkHitNum {
|
||||
return fmt.Errorf("search result's score length invalid, score length=%d, expectedLength=%d",
|
||||
len(data.Scores), pkHitNum)
|
||||
|
|
|
@ -39,6 +39,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
|
@ -1247,7 +1248,8 @@ func Test_checkSearchResultData(t *testing.T) {
|
|||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
err := checkSearchResultData(test.args.data, test.args.nq, test.args.topk)
|
||||
pkLength := typeutil.GetSizeOfIDs(test.args.data.GetIds())
|
||||
err := checkSearchResultData(test.args.data, test.args.nq, test.args.topk, pkLength)
|
||||
|
||||
if test.wantErr {
|
||||
assert.Error(t, err)
|
||||
|
@ -1522,8 +1524,9 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
|||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
reduced, err := reduceSearchResult(context.TODO(),
|
||||
NewReduceSearchResultInfo(results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset, queryInfo))
|
||||
reduced, err := reduceSearchResult(context.TODO(), results,
|
||||
reduce.NewReduceSearchResultInfo(nq, topk).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).
|
||||
WithOffset(test.offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
|
||||
assert.Equal(t, []int64{test.limit, test.limit}, reduced.GetResults().GetTopks())
|
||||
|
@ -1574,8 +1577,9 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
|||
}
|
||||
for _, test := range lessThanLimitTests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topk,
|
||||
metric.L2, schemapb.DataType_Int64, test.offset, queryInfo))
|
||||
reduced, err := reduceSearchResult(context.TODO(), results,
|
||||
reduce.NewReduceSearchResultInfo(nq, topk).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithOffset(test.offset).
|
||||
WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
|
||||
assert.Equal(t, []int64{test.outLimit, test.outLimit}, reduced.GetResults().GetTopks())
|
||||
|
@ -1603,9 +1607,8 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
|||
GroupByFieldId: -1,
|
||||
}
|
||||
|
||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(
|
||||
results, nq, topk, metric.L2, schemapb.DataType_Int64, 0, queryInfo))
|
||||
|
||||
reduced, err := reduceSearchResult(context.TODO(), results,
|
||||
reduce.NewReduceSearchResultInfo(nq, topk).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetIntId().GetData())
|
||||
assert.Equal(t, []int64{5, 5}, reduced.GetResults().GetTopks())
|
||||
|
@ -1633,9 +1636,8 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
|||
queryInfo := &planpb.QueryInfo{
|
||||
GroupByFieldId: -1,
|
||||
}
|
||||
|
||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results,
|
||||
nq, topk, metric.L2, schemapb.DataType_VarChar, 0, queryInfo))
|
||||
reduced, err := reduceSearchResult(context.TODO(), results,
|
||||
reduce.NewReduceSearchResultInfo(nq, topk).WithMetricType(metric.L2).WithPkType(schemapb.DataType_VarChar).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetStrId().GetData())
|
||||
|
@ -1708,8 +1710,8 @@ func TestTaskSearch_reduceGroupBySearchResultData(t *testing.T) {
|
|||
GroupByFieldId: 1,
|
||||
GroupSize: 1,
|
||||
}
|
||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topK, metric.L2,
|
||||
schemapb.DataType_Int64, 0, queryInfo))
|
||||
reduced, err := reduceSearchResult(context.TODO(), results,
|
||||
reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
|
||||
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
|
||||
resultScores := reduced.GetResults().GetScores()
|
||||
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
|
||||
|
@ -1768,8 +1770,8 @@ func TestTaskSearch_reduceGroupBySearchResultDataWithOffset(t *testing.T) {
|
|||
GroupByFieldId: 1,
|
||||
GroupSize: 1,
|
||||
}
|
||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, limit+offset, metric.L2,
|
||||
schemapb.DataType_Int64, offset, queryInfo))
|
||||
reduced, err := reduceSearchResult(context.TODO(), results,
|
||||
reduce.NewReduceSearchResultInfo(nq, limit+offset).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
|
||||
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
|
||||
resultScores := reduced.GetResults().GetScores()
|
||||
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
|
||||
|
@ -1842,8 +1844,9 @@ func TestTaskSearch_reduceGroupBySearchWithGroupSizeMoreThanOne(t *testing.T) {
|
|||
GroupByFieldId: 1,
|
||||
GroupSize: 2,
|
||||
}
|
||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topK, metric.L2,
|
||||
schemapb.DataType_Int64, 0, queryInfo))
|
||||
reduced, err := reduceSearchResult(context.TODO(), results,
|
||||
reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
|
||||
|
||||
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
|
||||
resultScores := reduced.GetResults().GetScores()
|
||||
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
|
||||
|
@ -1855,6 +1858,188 @@ func TestTaskSearch_reduceGroupBySearchWithGroupSizeMoreThanOne(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTaskSearch_reduceAdvanceSearchGroupBy(t *testing.T) {
|
||||
groupByField := int64(101)
|
||||
nq := int64(1)
|
||||
subSearchResultData := make([]*schemapb.SearchResultData, 0)
|
||||
topK := int64(3)
|
||||
{
|
||||
scores := []float32{0.9, 0.7, 0.65, 0.55, 0.52, 0.51, 0.5, 0.45, 0.43}
|
||||
ids := []int64{7, 5, 6, 11, 22, 14, 31, 23, 37}
|
||||
tops := []int64{9}
|
||||
groupFieldValue := []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa"}
|
||||
groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1)
|
||||
result1 := &schemapb.SearchResultData{
|
||||
Scores: scores,
|
||||
TopK: topK,
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: ids,
|
||||
},
|
||||
},
|
||||
},
|
||||
NumQueries: nq,
|
||||
Topks: tops,
|
||||
GroupByFieldValue: groupByVals,
|
||||
}
|
||||
subSearchResultData = append(subSearchResultData, result1)
|
||||
}
|
||||
{
|
||||
scores := []float32{0.83, 0.72, 0.72, 0.65, 0.63, 0.55, 0.52, 0.51, 0.48}
|
||||
ids := []int64{17, 15, 16, 21, 32, 24, 41, 33, 27}
|
||||
tops := []int64{9}
|
||||
groupFieldValue := []string{"xxx", "bbb", "ddd", "bbb", "bbb", "ddd", "xxx", "ddd", "xxx"}
|
||||
groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1)
|
||||
result2 := &schemapb.SearchResultData{
|
||||
TopK: topK,
|
||||
Scores: scores,
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: ids,
|
||||
},
|
||||
},
|
||||
},
|
||||
Topks: tops,
|
||||
NumQueries: nq,
|
||||
GroupByFieldValue: groupByVals,
|
||||
}
|
||||
subSearchResultData = append(subSearchResultData, result2)
|
||||
}
|
||||
groupSize := int64(3)
|
||||
|
||||
reducedRes, err := reduceSearchResult(context.Background(), subSearchResultData,
|
||||
reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.IP).WithPkType(schemapb.DataType_Int64).WithGroupByField(groupByField).WithGroupSize(groupSize).WithAdvance(true))
|
||||
assert.NoError(t, err)
|
||||
// reduce_advance_groupby will only merge results from different delegator without reducing any result
|
||||
assert.Equal(t, 18, len(reducedRes.GetResults().Ids.GetIntId().Data))
|
||||
assert.Equal(t, 18, len(reducedRes.GetResults().GetScores()))
|
||||
assert.Equal(t, 18, len(reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data))
|
||||
assert.Equal(t, topK, reducedRes.GetResults().GetTopK())
|
||||
assert.Equal(t, []int64{18}, reducedRes.GetResults().GetTopks())
|
||||
|
||||
assert.Equal(t, []int64{7, 5, 6, 11, 22, 14, 31, 23, 37, 17, 15, 16, 21, 32, 24, 41, 33, 27}, reducedRes.GetResults().Ids.GetIntId().Data)
|
||||
assert.Equal(t, []float32{0.9, 0.7, 0.65, 0.55, 0.52, 0.51, 0.5, 0.45, 0.43, 0.83, 0.72, 0.72, 0.65, 0.63, 0.55, 0.52, 0.51, 0.48}, reducedRes.GetResults().GetScores())
|
||||
assert.Equal(t, []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa", "xxx", "bbb", "ddd", "bbb", "bbb", "ddd", "xxx", "ddd", "xxx"}, reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
|
||||
}
|
||||
|
||||
func TestTaskSearch_reduceAdvanceSearchGroupByShortCut(t *testing.T) {
|
||||
groupByField := int64(101)
|
||||
nq := int64(1)
|
||||
subSearchResultData := make([]*schemapb.SearchResultData, 0)
|
||||
topK := int64(3)
|
||||
{
|
||||
scores := []float32{0.9, 0.7, 0.65, 0.55, 0.52, 0.51, 0.5, 0.45, 0.43}
|
||||
ids := []int64{7, 5, 6, 11, 22, 14, 31, 23, 37}
|
||||
tops := []int64{9}
|
||||
groupFieldValue := []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa"}
|
||||
groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1)
|
||||
result1 := &schemapb.SearchResultData{
|
||||
Scores: scores,
|
||||
TopK: topK,
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: ids,
|
||||
},
|
||||
},
|
||||
},
|
||||
NumQueries: nq,
|
||||
Topks: tops,
|
||||
GroupByFieldValue: groupByVals,
|
||||
}
|
||||
subSearchResultData = append(subSearchResultData, result1)
|
||||
}
|
||||
groupSize := int64(3)
|
||||
|
||||
reducedRes, err := reduceSearchResult(context.Background(), subSearchResultData,
|
||||
reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(groupByField).WithGroupSize(groupSize).WithAdvance(true))
|
||||
|
||||
assert.NoError(t, err)
|
||||
// reduce_advance_groupby will only merge results from different delegator without reducing any result
|
||||
assert.Equal(t, 9, len(reducedRes.GetResults().Ids.GetIntId().Data))
|
||||
assert.Equal(t, 9, len(reducedRes.GetResults().GetScores()))
|
||||
assert.Equal(t, 9, len(reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data))
|
||||
assert.Equal(t, topK, reducedRes.GetResults().GetTopK())
|
||||
assert.Equal(t, []int64{9}, reducedRes.GetResults().GetTopks())
|
||||
|
||||
assert.Equal(t, []int64{7, 5, 6, 11, 22, 14, 31, 23, 37}, reducedRes.GetResults().Ids.GetIntId().Data)
|
||||
assert.Equal(t, []float32{0.9, 0.7, 0.65, 0.55, 0.52, 0.51, 0.5, 0.45, 0.43}, reducedRes.GetResults().GetScores())
|
||||
assert.Equal(t, []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa"}, reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
|
||||
}
|
||||
|
||||
func TestTaskSearch_reduceAdvanceSearchGroupByMultipleNq(t *testing.T) {
|
||||
groupByField := int64(101)
|
||||
nq := int64(2)
|
||||
subSearchResultData := make([]*schemapb.SearchResultData, 0)
|
||||
topK := int64(2)
|
||||
groupSize := int64(2)
|
||||
{
|
||||
scores := []float32{0.9, 0.7, 0.65, 0.55, 0.51, 0.5, 0.45, 0.43}
|
||||
ids := []int64{7, 5, 6, 11, 14, 31, 23, 37}
|
||||
tops := []int64{4, 4}
|
||||
groupFieldValue := []string{"ccc", "bbb", "ccc", "bbb", "aaa", "xxx", "xxx", "aaa"}
|
||||
groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1)
|
||||
result1 := &schemapb.SearchResultData{
|
||||
Scores: scores,
|
||||
TopK: topK,
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: ids,
|
||||
},
|
||||
},
|
||||
},
|
||||
NumQueries: nq,
|
||||
Topks: tops,
|
||||
GroupByFieldValue: groupByVals,
|
||||
}
|
||||
subSearchResultData = append(subSearchResultData, result1)
|
||||
}
|
||||
{
|
||||
scores := []float32{0.83, 0.72, 0.72, 0.65, 0.63, 0.55, 0.52, 0.51}
|
||||
ids := []int64{17, 15, 16, 21, 32, 24, 41, 33}
|
||||
tops := []int64{4, 4}
|
||||
groupFieldValue := []string{"ddd", "bbb", "ddd", "bbb", "rrr", "sss", "rrr", "sss"}
|
||||
groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1)
|
||||
result2 := &schemapb.SearchResultData{
|
||||
TopK: topK,
|
||||
Scores: scores,
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: ids,
|
||||
},
|
||||
},
|
||||
},
|
||||
Topks: tops,
|
||||
NumQueries: nq,
|
||||
GroupByFieldValue: groupByVals,
|
||||
}
|
||||
subSearchResultData = append(subSearchResultData, result2)
|
||||
}
|
||||
|
||||
reducedRes, err := reduceSearchResult(context.Background(), subSearchResultData,
|
||||
reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.IP).WithPkType(schemapb.DataType_Int64).WithGroupByField(groupByField).WithGroupSize(groupSize).WithAdvance(true))
|
||||
assert.NoError(t, err)
|
||||
// reduce_advance_groupby will only merge results from different delegator without reducing any result
|
||||
assert.Equal(t, 16, len(reducedRes.GetResults().Ids.GetIntId().Data))
|
||||
assert.Equal(t, 16, len(reducedRes.GetResults().GetScores()))
|
||||
assert.Equal(t, 16, len(reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data))
|
||||
|
||||
assert.Equal(t, topK, reducedRes.GetResults().GetTopK())
|
||||
assert.Equal(t, []int64{8, 8}, reducedRes.GetResults().GetTopks())
|
||||
|
||||
assert.Equal(t, []int64{7, 5, 6, 11, 17, 15, 16, 21, 14, 31, 23, 37, 32, 24, 41, 33}, reducedRes.GetResults().Ids.GetIntId().Data)
|
||||
assert.Equal(t, []float32{0.9, 0.7, 0.65, 0.55, 0.83, 0.72, 0.72, 0.65, 0.51, 0.5, 0.45, 0.43, 0.63, 0.55, 0.52, 0.51}, reducedRes.GetResults().GetScores())
|
||||
assert.Equal(t, []string{"ccc", "bbb", "ccc", "bbb", "ddd", "bbb", "ddd", "bbb", "aaa", "xxx", "xxx", "aaa", "rrr", "sss", "rrr", "sss"}, reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
|
||||
|
||||
fmt.Println(reducedRes.GetResults().Ids.GetIntId().Data)
|
||||
fmt.Println(reducedRes.GetResults().GetScores())
|
||||
fmt.Println(reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
|
||||
}
|
||||
|
||||
func TestSearchTask_ErrExecute(t *testing.T) {
|
||||
var (
|
||||
err error
|
||||
|
|
|
@ -43,6 +43,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||
"github.com/milvus-io/milvus/internal/util/streamrpc"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
|
@ -332,6 +333,8 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
|
|||
IgnoreGrowing: req.GetReq().GetIgnoreGrowing(),
|
||||
Username: req.GetReq().GetUsername(),
|
||||
IsAdvanced: false,
|
||||
GroupByFieldId: subReq.GetGroupByFieldId(),
|
||||
GroupSize: subReq.GetGroupSize(),
|
||||
}
|
||||
future := conc.Go(func() (*internalpb.SearchResults, error) {
|
||||
searchReq := &querypb.SearchRequest{
|
||||
|
@ -350,14 +353,12 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return segments.ReduceSearchResults(ctx,
|
||||
return segments.ReduceSearchOnQueryNode(ctx,
|
||||
results,
|
||||
segments.NewReduceInfo(searchReq.Req.GetNq(),
|
||||
searchReq.Req.GetTopk(),
|
||||
searchReq.Req.GetExtraSearchParam().GetGroupByFieldId(),
|
||||
searchReq.Req.GetExtraSearchParam().GetGroupSize(),
|
||||
searchReq.Req.GetMetricType()),
|
||||
)
|
||||
reduce.NewReduceSearchResultInfo(searchReq.GetReq().GetNq(),
|
||||
searchReq.GetReq().GetTopk()).WithMetricType(searchReq.GetReq().GetMetricType()).
|
||||
WithGroupByField(searchReq.GetReq().GetGroupByFieldId()).
|
||||
WithGroupSize(searchReq.GetReq().GetGroupSize()))
|
||||
})
|
||||
futures[index] = future
|
||||
}
|
||||
|
@ -376,12 +377,7 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
|
|||
}
|
||||
results[i] = result
|
||||
}
|
||||
var ret *internalpb.SearchResults
|
||||
ret, err = segments.MergeToAdvancedResults(ctx, results)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []*internalpb.SearchResults{ret}, nil
|
||||
return results, nil
|
||||
}
|
||||
return sd.search(ctx, req, sealed, growing)
|
||||
}
|
||||
|
|
|
@ -32,6 +32,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/tasks"
|
||||
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||
"github.com/milvus-io/milvus/internal/util/streamrpc"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
|
@ -384,16 +385,10 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
|
|||
req.GetSegmentIDs(),
|
||||
))
|
||||
|
||||
var resp *internalpb.SearchResults
|
||||
if req.GetReq().GetIsAdvanced() {
|
||||
resp, err = segments.ReduceAdvancedSearchResults(ctx, results, req.Req.GetNq())
|
||||
} else {
|
||||
resp, err = segments.ReduceSearchResults(ctx, results, segments.NewReduceInfo(req.Req.GetNq(),
|
||||
req.Req.GetTopk(),
|
||||
req.Req.GetExtraSearchParam().GetGroupByFieldId(),
|
||||
req.Req.GetExtraSearchParam().GetGroupSize(),
|
||||
req.Req.GetMetricType()))
|
||||
}
|
||||
resp, err := segments.ReduceSearchOnQueryNode(ctx, results,
|
||||
reduce.NewReduceSearchResultInfo(req.GetReq().GetNq(),
|
||||
req.GetReq().GetTopk()).WithMetricType(req.GetReq().GetMetricType()).WithGroupByField(req.GetReq().GetGroupByFieldId()).
|
||||
WithGroupSize(req.GetReq().GetGroupByFieldId()).WithAdvance(req.GetReq().GetIsAdvanced()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/segcorepb"
|
||||
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
|
@ -42,7 +43,14 @@ var _ typeutil.ResultWithID = &internalpb.RetrieveResults{}
|
|||
|
||||
var _ typeutil.ResultWithID = &segcorepb.RetrieveResults{}
|
||||
|
||||
func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResults, info *ReduceInfo) (*internalpb.SearchResults, error) {
|
||||
func ReduceSearchOnQueryNode(ctx context.Context, results []*internalpb.SearchResults, info *reduce.ResultInfo) (*internalpb.SearchResults, error) {
|
||||
if info.GetIsAdvance() {
|
||||
return ReduceAdvancedSearchResults(ctx, results)
|
||||
}
|
||||
return ReduceSearchResults(ctx, results, info)
|
||||
}
|
||||
|
||||
func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResults, info *reduce.ResultInfo) (*internalpb.SearchResults, error) {
|
||||
results = lo.Filter(results, func(result *internalpb.SearchResults, _ int) bool {
|
||||
return result != nil && result.GetSlicedBlob() != nil
|
||||
})
|
||||
|
@ -60,8 +68,8 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
|
|||
channelsMvcc[ch] = ts
|
||||
}
|
||||
// shouldn't let new SearchResults.MetricType to be empty, though the req.MetricType is empty
|
||||
if info.metricType == "" {
|
||||
info.metricType = r.MetricType
|
||||
if info.GetMetricType() == "" {
|
||||
info.SetMetricType(r.MetricType)
|
||||
}
|
||||
}
|
||||
log := log.Ctx(ctx)
|
||||
|
@ -86,7 +94,7 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
|
|||
log.Warn("shard leader reduce errors", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
searchResults, err := EncodeSearchResultData(ctx, reducedResultData, info.nq, info.topK, info.metricType)
|
||||
searchResults, err := EncodeSearchResultData(ctx, reducedResultData, info.GetNq(), info.GetTopK(), info.GetMetricType())
|
||||
if err != nil {
|
||||
log.Warn("shard leader encode search result errors", zap.Error(err))
|
||||
return nil, err
|
||||
|
@ -115,7 +123,7 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
|
|||
return searchResults, nil
|
||||
}
|
||||
|
||||
func ReduceAdvancedSearchResults(ctx context.Context, results []*internalpb.SearchResults, nq int64) (*internalpb.SearchResults, error) {
|
||||
func ReduceAdvancedSearchResults(ctx context.Context, results []*internalpb.SearchResults) (*internalpb.SearchResults, error) {
|
||||
_, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceAdvancedSearchResults")
|
||||
defer sp.End()
|
||||
|
||||
|
@ -129,53 +137,14 @@ func ReduceAdvancedSearchResults(ctx context.Context, results []*internalpb.Sear
|
|||
IsAdvanced: true,
|
||||
}
|
||||
|
||||
for _, result := range results {
|
||||
relatedDataSize += result.GetCostAggregation().GetTotalRelatedDataSize()
|
||||
for ch, ts := range result.GetChannelsMvcc() {
|
||||
channelsMvcc[ch] = ts
|
||||
}
|
||||
if !result.GetIsAdvanced() {
|
||||
continue
|
||||
}
|
||||
// we just append here, no need to split subResult and reduce
|
||||
// defer this reduce to proxy
|
||||
searchResults.SubResults = append(searchResults.SubResults, result.GetSubResults()...)
|
||||
searchResults.NumQueries = result.GetNumQueries()
|
||||
}
|
||||
searchResults.ChannelsMvcc = channelsMvcc
|
||||
requestCosts := lo.FilterMap(results, func(result *internalpb.SearchResults, _ int) (*internalpb.CostAggregation, bool) {
|
||||
if paramtable.Get().QueryNodeCfg.EnableWorkerSQCostMetrics.GetAsBool() {
|
||||
return result.GetCostAggregation(), true
|
||||
}
|
||||
|
||||
if result.GetBase().GetSourceID() == paramtable.GetNodeID() {
|
||||
return result.GetCostAggregation(), true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
})
|
||||
searchResults.CostAggregation = mergeRequestCost(requestCosts)
|
||||
if searchResults.CostAggregation == nil {
|
||||
searchResults.CostAggregation = &internalpb.CostAggregation{}
|
||||
}
|
||||
searchResults.CostAggregation.TotalRelatedDataSize = relatedDataSize
|
||||
return searchResults, nil
|
||||
}
|
||||
|
||||
func MergeToAdvancedResults(ctx context.Context, results []*internalpb.SearchResults) (*internalpb.SearchResults, error) {
|
||||
searchResults := &internalpb.SearchResults{
|
||||
IsAdvanced: true,
|
||||
}
|
||||
|
||||
channelsMvcc := make(map[string]uint64)
|
||||
relatedDataSize := int64(0)
|
||||
for index, result := range results {
|
||||
relatedDataSize += result.GetCostAggregation().GetTotalRelatedDataSize()
|
||||
for ch, ts := range result.GetChannelsMvcc() {
|
||||
channelsMvcc[ch] = ts
|
||||
}
|
||||
searchResults.NumQueries = result.GetNumQueries()
|
||||
// we just append here, no need to split subResult and reduce
|
||||
// defer this reduce to proxy
|
||||
// defer this reduction to proxy
|
||||
subResult := &internalpb.SubSearchResults{
|
||||
MetricType: result.GetMetricType(),
|
||||
NumQueries: result.GetNumQueries(),
|
||||
|
@ -185,7 +154,6 @@ func MergeToAdvancedResults(ctx context.Context, results []*internalpb.SearchRes
|
|||
SlicedOffset: result.GetSlicedOffset(),
|
||||
ReqIndex: int64(index),
|
||||
}
|
||||
searchResults.NumQueries = result.GetNumQueries()
|
||||
searchResults.SubResults = append(searchResults.SubResults, subResult)
|
||||
}
|
||||
searchResults.ChannelsMvcc = channelsMvcc
|
||||
|
|
|
@ -29,7 +29,9 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/segcorepb"
|
||||
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
@ -886,6 +888,42 @@ func (suite *ResultSuite) TestSort() {
|
|||
}, result.FieldsData[9].GetScalars().GetArrayData().GetData())
|
||||
}
|
||||
|
||||
func (suite *ResultSuite) TestReduceSearchOnQueryNode() {
|
||||
results := make([]*internalpb.SearchResults, 0)
|
||||
metricType := metric.IP
|
||||
nq := int64(1)
|
||||
topK := int64(1)
|
||||
mockBlob := []byte{65, 66, 67, 65, 66, 67}
|
||||
{
|
||||
subRes1 := &internalpb.SearchResults{
|
||||
MetricType: metricType,
|
||||
NumQueries: nq,
|
||||
TopK: topK,
|
||||
SlicedBlob: mockBlob,
|
||||
}
|
||||
results = append(results, subRes1)
|
||||
}
|
||||
{
|
||||
subRes2 := &internalpb.SearchResults{
|
||||
MetricType: metricType,
|
||||
NumQueries: nq,
|
||||
TopK: topK,
|
||||
SlicedBlob: mockBlob,
|
||||
}
|
||||
results = append(results, subRes2)
|
||||
}
|
||||
reducedRes, err := ReduceSearchOnQueryNode(context.Background(), results, reduce.NewReduceSearchResultInfo(nq, topK).
|
||||
WithMetricType(metricType).WithPkType(schemapb.DataType_Int8).WithAdvance(true))
|
||||
suite.NoError(err)
|
||||
suite.Equal(2, len(reducedRes.GetSubResults()))
|
||||
|
||||
subRes1 := reducedRes.GetSubResults()[0]
|
||||
suite.Equal(metricType, subRes1.GetMetricType())
|
||||
suite.Equal(nq, subRes1.GetNumQueries())
|
||||
suite.Equal(topK, subRes1.GetTopK())
|
||||
suite.Equal(mockBlob, subRes1.GetSlicedBlob())
|
||||
}
|
||||
|
||||
func TestResult_MergeRequestCost(t *testing.T) {
|
||||
costs := []*internalpb.CostAggregation{
|
||||
{
|
||||
|
|
|
@ -8,39 +8,28 @@ import (
|
|||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type ReduceInfo struct {
|
||||
nq int64
|
||||
topK int64
|
||||
groupByFieldID int64
|
||||
groupSize int64
|
||||
metricType string
|
||||
}
|
||||
|
||||
func NewReduceInfo(nq int64, topK int64, groupByFieldID int64, groupSize int64, metric string) *ReduceInfo {
|
||||
return &ReduceInfo{nq, topK, groupByFieldID, groupSize, metric}
|
||||
}
|
||||
|
||||
type SearchReduce interface {
|
||||
ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error)
|
||||
ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *reduce.ResultInfo) (*schemapb.SearchResultData, error)
|
||||
}
|
||||
|
||||
type SearchCommonReduce struct{}
|
||||
|
||||
func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error) {
|
||||
func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *reduce.ResultInfo) (*schemapb.SearchResultData, error) {
|
||||
ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData")
|
||||
defer sp.End()
|
||||
log := log.Ctx(ctx)
|
||||
|
||||
if len(searchResultData) == 0 {
|
||||
return &schemapb.SearchResultData{
|
||||
NumQueries: info.nq,
|
||||
TopK: info.topK,
|
||||
NumQueries: info.GetNq(),
|
||||
TopK: info.GetTopK(),
|
||||
FieldsData: make([]*schemapb.FieldData, 0),
|
||||
Scores: make([]float32, 0),
|
||||
Ids: &schemapb.IDs{},
|
||||
|
@ -48,8 +37,8 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
|
|||
}, nil
|
||||
}
|
||||
ret := &schemapb.SearchResultData{
|
||||
NumQueries: info.nq,
|
||||
TopK: info.topK,
|
||||
NumQueries: info.GetNq(),
|
||||
TopK: info.GetTopK(),
|
||||
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
|
||||
Scores: make([]float32, 0),
|
||||
Ids: &schemapb.IDs{},
|
||||
|
@ -59,7 +48,7 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
|
|||
resultOffsets := make([][]int64, len(searchResultData))
|
||||
for i := 0; i < len(searchResultData); i++ {
|
||||
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
|
||||
for j := int64(1); j < info.nq; j++ {
|
||||
for j := int64(1); j < info.GetNq(); j++ {
|
||||
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
|
||||
}
|
||||
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
|
||||
|
@ -68,11 +57,11 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
|
|||
var skipDupCnt int64
|
||||
var retSize int64
|
||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||
for i := int64(0); i < info.nq; i++ {
|
||||
for i := int64(0); i < info.GetNq(); i++ {
|
||||
offsets := make([]int64, len(searchResultData))
|
||||
idSet := make(map[interface{}]struct{})
|
||||
var j int64
|
||||
for j = 0; j < info.topK; {
|
||||
for j = 0; j < info.GetTopK(); {
|
||||
sel := SelectSearchResultData(searchResultData, resultOffsets, offsets, i)
|
||||
if sel == -1 {
|
||||
break
|
||||
|
@ -113,15 +102,15 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
|
|||
|
||||
type SearchGroupByReduce struct{}
|
||||
|
||||
func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error) {
|
||||
func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *reduce.ResultInfo) (*schemapb.SearchResultData, error) {
|
||||
ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData")
|
||||
defer sp.End()
|
||||
log := log.Ctx(ctx)
|
||||
|
||||
if len(searchResultData) == 0 {
|
||||
return &schemapb.SearchResultData{
|
||||
NumQueries: info.nq,
|
||||
TopK: info.topK,
|
||||
NumQueries: info.GetNq(),
|
||||
TopK: info.GetTopK(),
|
||||
FieldsData: make([]*schemapb.FieldData, 0),
|
||||
Scores: make([]float32, 0),
|
||||
Ids: &schemapb.IDs{},
|
||||
|
@ -129,8 +118,8 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
|
|||
}, nil
|
||||
}
|
||||
ret := &schemapb.SearchResultData{
|
||||
NumQueries: info.nq,
|
||||
TopK: info.topK,
|
||||
NumQueries: info.GetNq(),
|
||||
TopK: info.GetTopK(),
|
||||
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
|
||||
Scores: make([]float32, 0),
|
||||
Ids: &schemapb.IDs{},
|
||||
|
@ -140,7 +129,7 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
|
|||
resultOffsets := make([][]int64, len(searchResultData))
|
||||
for i := 0; i < len(searchResultData); i++ {
|
||||
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
|
||||
for j := int64(1); j < info.nq; j++ {
|
||||
for j := int64(1); j < info.GetNq(); j++ {
|
||||
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
|
||||
}
|
||||
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
|
||||
|
@ -149,13 +138,13 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
|
|||
var filteredCount int64
|
||||
var retSize int64
|
||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||
groupSize := info.groupSize
|
||||
groupSize := info.GetGroupSize()
|
||||
if groupSize <= 0 {
|
||||
groupSize = 1
|
||||
}
|
||||
groupBound := info.topK * groupSize
|
||||
groupBound := info.GetTopK() * groupSize
|
||||
|
||||
for i := int64(0); i < info.nq; i++ {
|
||||
for i := int64(0); i < info.GetNq(); i++ {
|
||||
offsets := make([]int64, len(searchResultData))
|
||||
|
||||
idSet := make(map[interface{}]struct{})
|
||||
|
@ -178,7 +167,7 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
|
|||
}
|
||||
|
||||
groupCount := groupByValueMap[groupByVal]
|
||||
if groupCount == 0 && int64(len(groupByValueMap)) >= info.topK {
|
||||
if groupCount == 0 && int64(len(groupByValueMap)) >= info.GetTopK() {
|
||||
// exceed the limit for group count, filter this entity
|
||||
filteredCount++
|
||||
} else if groupCount >= groupSize {
|
||||
|
@ -219,8 +208,8 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func InitSearchReducer(info *ReduceInfo) SearchReduce {
|
||||
if info.groupByFieldID > 0 {
|
||||
func InitSearchReducer(info *reduce.ResultInfo) SearchReduce {
|
||||
if info.GetGroupByFieldId() > 0 {
|
||||
return &SearchGroupByReduce{}
|
||||
}
|
||||
return &SearchCommonReduce{}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
|
@ -28,7 +29,7 @@ func (suite *SearchReduceSuite) TestResult_ReduceSearchResultData() {
|
|||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk}
|
||||
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1)
|
||||
searchReduce := InitSearchReducer(reduceInfo)
|
||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||
suite.Nil(err)
|
||||
|
@ -47,7 +48,7 @@ func (suite *SearchReduceSuite) TestResult_ReduceSearchResultData() {
|
|||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk}
|
||||
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1)
|
||||
searchReduce := InitSearchReducer(reduceInfo)
|
||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||
suite.Nil(err)
|
||||
|
@ -96,7 +97,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
|
|||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101}
|
||||
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1).WithGroupByField(101)
|
||||
searchReduce := InitSearchReducer(reduceInfo)
|
||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||
suite.Nil(err)
|
||||
|
@ -140,7 +141,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
|
|||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101}
|
||||
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1).WithGroupByField(101)
|
||||
searchReduce := InitSearchReducer(reduceInfo)
|
||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||
suite.Nil(err)
|
||||
|
@ -184,7 +185,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
|
|||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101}
|
||||
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1).WithGroupByField(101)
|
||||
searchReduce := InitSearchReducer(reduceInfo)
|
||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||
suite.Nil(err)
|
||||
|
@ -228,7 +229,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
|
|||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101, groupSize: 3}
|
||||
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(3).WithGroupByField(101)
|
||||
searchReduce := InitSearchReducer(reduceInfo)
|
||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||
suite.Nil(err)
|
||||
|
@ -239,7 +240,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
|
|||
|
||||
suite.Run("reduce_group_by_empty_input", func() {
|
||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101, groupSize: 3}
|
||||
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(3).WithGroupByField(101)
|
||||
searchReduce := InitSearchReducer(reduceInfo)
|
||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||
suite.Nil(err)
|
||||
|
|
|
@ -753,69 +753,41 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
|
|||
return resp, nil
|
||||
}
|
||||
|
||||
toReduceResults := make([]*internalpb.SearchResults, len(req.GetDmlChannels()))
|
||||
runningGp, runningCtx := errgroup.WithContext(ctx)
|
||||
|
||||
for i, ch := range req.GetDmlChannels() {
|
||||
ch := ch
|
||||
req := &querypb.SearchRequest{
|
||||
Req: req.Req,
|
||||
DmlChannels: []string{ch},
|
||||
SegmentIDs: req.SegmentIDs,
|
||||
Scope: req.Scope,
|
||||
TotalChannelNum: req.TotalChannelNum,
|
||||
}
|
||||
|
||||
i := i
|
||||
runningGp.Go(func() error {
|
||||
ret, err := node.searchChannel(runningCtx, req, ch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := merr.Error(ret.GetStatus()); err != nil {
|
||||
return err
|
||||
}
|
||||
toReduceResults[i] = ret
|
||||
return nil
|
||||
})
|
||||
if len(req.GetDmlChannels()) != 1 {
|
||||
err := merr.WrapErrParameterInvalid(1, len(req.GetDmlChannels()), "count of channel to be searched should only be 1, wrong code")
|
||||
resp.Status = merr.Status(err)
|
||||
log.Warn("got wrong number of channels to be searched", zap.Error(err))
|
||||
return resp, nil
|
||||
}
|
||||
if err := runningGp.Wait(); err != nil {
|
||||
|
||||
ch := req.GetDmlChannels()[0]
|
||||
channelReq := &querypb.SearchRequest{
|
||||
Req: req.Req,
|
||||
DmlChannels: []string{ch},
|
||||
SegmentIDs: req.SegmentIDs,
|
||||
Scope: req.Scope,
|
||||
TotalChannelNum: req.TotalChannelNum,
|
||||
}
|
||||
ret, err := node.searchChannel(ctx, channelReq, ch)
|
||||
if err != nil {
|
||||
resp.Status = merr.Status(err)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
tr.RecordSpan()
|
||||
var result *internalpb.SearchResults
|
||||
var err2 error
|
||||
if req.GetReq().GetIsAdvanced() {
|
||||
result, err2 = segments.ReduceAdvancedSearchResults(ctx, toReduceResults, req.Req.GetNq())
|
||||
} else {
|
||||
result, err2 = segments.ReduceSearchResults(ctx, toReduceResults, segments.NewReduceInfo(req.Req.GetNq(),
|
||||
req.Req.GetTopk(),
|
||||
req.Req.GetExtraSearchParam().GetGroupByFieldId(),
|
||||
req.Req.GetExtraSearchParam().GetGroupSize(),
|
||||
req.Req.GetMetricType()))
|
||||
}
|
||||
|
||||
if err2 != nil {
|
||||
log.Warn("failed to reduce search results", zap.Error(err2))
|
||||
resp.Status = merr.Status(err2)
|
||||
return resp, nil
|
||||
}
|
||||
result.Status = merr.Success()
|
||||
ret.Status = merr.Success()
|
||||
|
||||
reduceLatency := tr.RecordSpan()
|
||||
metrics.QueryNodeReduceLatency.
|
||||
WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards, metrics.BatchReduce).
|
||||
Observe(float64(reduceLatency.Milliseconds()))
|
||||
|
||||
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(node.GetNodeID(), 10), metrics.SearchLabel).
|
||||
Add(float64(proto.Size(req)))
|
||||
|
||||
if result.GetCostAggregation() != nil {
|
||||
result.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
|
||||
if ret.GetCostAggregation() != nil {
|
||||
ret.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
|
||||
}
|
||||
return result, nil
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// only used for delegator query segments from worker
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
package reduce
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
type ResultInfo struct {
|
||||
nq int64
|
||||
topK int64
|
||||
metricType string
|
||||
pkType schemapb.DataType
|
||||
offset int64
|
||||
groupByFieldId int64
|
||||
groupSize int64
|
||||
isAdvance bool
|
||||
}
|
||||
|
||||
func NewReduceSearchResultInfo(
|
||||
nq int64,
|
||||
topK int64,
|
||||
) *ResultInfo {
|
||||
return &ResultInfo{
|
||||
nq: nq,
|
||||
topK: topK,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ResultInfo) WithMetricType(metricType string) *ResultInfo {
|
||||
r.metricType = metricType
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *ResultInfo) WithPkType(pkType schemapb.DataType) *ResultInfo {
|
||||
r.pkType = pkType
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *ResultInfo) WithOffset(offset int64) *ResultInfo {
|
||||
r.offset = offset
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *ResultInfo) WithGroupByField(groupByField int64) *ResultInfo {
|
||||
r.groupByFieldId = groupByField
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *ResultInfo) WithGroupSize(groupSize int64) *ResultInfo {
|
||||
r.groupSize = groupSize
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *ResultInfo) WithAdvance(advance bool) *ResultInfo {
|
||||
r.isAdvance = advance
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *ResultInfo) GetNq() int64 {
|
||||
return r.nq
|
||||
}
|
||||
|
||||
func (r *ResultInfo) GetTopK() int64 {
|
||||
return r.topK
|
||||
}
|
||||
|
||||
func (r *ResultInfo) GetMetricType() string {
|
||||
return r.metricType
|
||||
}
|
||||
|
||||
func (r *ResultInfo) GetPkType() schemapb.DataType {
|
||||
return r.pkType
|
||||
}
|
||||
|
||||
func (r *ResultInfo) GetOffset() int64 {
|
||||
return r.offset
|
||||
}
|
||||
|
||||
func (r *ResultInfo) GetGroupByFieldId() int64 {
|
||||
return r.groupByFieldId
|
||||
}
|
||||
|
||||
func (r *ResultInfo) GetGroupSize() int64 {
|
||||
return r.groupSize
|
||||
}
|
||||
|
||||
func (r *ResultInfo) GetIsAdvance() bool {
|
||||
return r.isAdvance
|
||||
}
|
||||
|
||||
func (r *ResultInfo) SetMetricType(metricType string) {
|
||||
r.metricType = metricType
|
||||
}
|
Loading…
Reference in New Issue