feat: supporting hybrid search group_by (#35982)

related: #35096

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
Chun Han 2024-09-08 17:09:04 +08:00 committed by GitHub
parent 62f4a6a112
commit e480b103bd
19 changed files with 983 additions and 334 deletions

View File

@ -170,7 +170,8 @@ func TestMeta_ScalarAutoIndex(t *testing.T) {
{
Key: common.IndexTypeKey,
Value: "HYBRID",
}},
},
},
Timestamp: 0,
IsAutoIndex: true,
UserIndexParams: userIndexParams,
@ -205,7 +206,6 @@ func TestMeta_ScalarAutoIndex(t *testing.T) {
assert.Equal(t, newIndexParams[0].Key, common.IndexTypeKey)
assert.Equal(t, newIndexParams[0].Value, "INVERTED")
})
}
func TestMeta_CanCreateIndex(t *testing.T) {

View File

@ -769,7 +769,7 @@ func (s *taskSchedulerSuite) scheduler(handler Handler) {
return nil
})
catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil)
//catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil)
// catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil)
in := mocks.NewMockIndexNodeClient(s.T())
in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(merr.Success(), nil)

View File

@ -63,7 +63,7 @@ func mergeSortMultipleSegments(ctx context.Context,
return nil, err
}
//SegmentDeserializeReaderTest(binlogPaths, t.binlogIO, writer.GetPkID())
// SegmentDeserializeReaderTest(binlogPaths, t.binlogIO, writer.GetPkID())
segmentReaders := make([]*SegmentDeserializeReader, len(binlogs))
for i, s := range binlogs {
var binlogBatchCount int

View File

@ -119,7 +119,6 @@ func (s *PriorityQueueSuite) PriorityQueueMergeSort() {
heap.Push(&pq, next)
}
}
}
func TestNewPriorityQueueSuite(t *testing.T) {

View File

@ -94,11 +94,8 @@ message SubSearchRequest {
int64 topk = 7;
int64 offset = 8;
string metricType = 9;
}
message ExtraSearchParam {
int64 group_by_field_id = 1;
int64 group_size = 2;
int64 group_by_field_id = 10;
int64 group_size = 11;
}
message SearchRequest {
@ -125,7 +122,8 @@ message SearchRequest {
bool is_advanced = 20;
int64 offset = 21;
common.ConsistencyLevel consistency_level = 22;
ExtraSearchParam extra_search_param = 23;
int64 group_by_field_id = 23;
int64 group_size = 24;
}
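
Note on the proto change above: the nested ExtraSearchParam message is removed, and group_by_field_id / group_size now sit directly on SubSearchRequest (fields 10 and 11) and SearchRequest (fields 23 and 24). A minimal sketch, not part of this diff, of populating the flattened fields on the generated Go struct (values are illustrative):

req := &internalpb.SearchRequest{
    Offset:         0,
    GroupByFieldId: 101, // field id of the group-by column
    GroupSize:      2,   // max hits kept per group
}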
message SubSearchResults {

View File

@ -11,7 +11,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
@ -20,54 +20,137 @@ import (
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type reduceSearchResultInfo struct {
subSearchResultData []*schemapb.SearchResultData
nq int64
topK int64
metricType string
pkType schemapb.DataType
offset int64
queryInfo *planpb.QueryInfo
}
func NewReduceSearchResultInfo(
subSearchResultData []*schemapb.SearchResultData,
nq int64,
topK int64,
metricType string,
pkType schemapb.DataType,
offset int64,
queryInfo *planpb.QueryInfo,
) *reduceSearchResultInfo {
return &reduceSearchResultInfo{
subSearchResultData: subSearchResultData,
nq: nq,
topK: topK,
metricType: metricType,
pkType: pkType,
offset: offset,
queryInfo: queryInfo,
}
}
func reduceSearchResult(ctx context.Context, reduceInfo *reduceSearchResultInfo) (*milvuspb.SearchResults, error) {
if reduceInfo.queryInfo.GroupByFieldId > 0 {
func reduceSearchResult(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, reduceInfo *reduce.ResultInfo) (*milvuspb.SearchResults, error) {
if reduceInfo.GetGroupByFieldId() > 0 {
if reduceInfo.GetIsAdvance() {
// for hybrid search with group-by, we cannot reduce the results of a single search path here,
// because the final scores have not been accumulated yet; offset cannot be applied either
return reduceAdvanceGroupBY(ctx,
subSearchResultData, reduceInfo.GetNq(), reduceInfo.GetTopK(), reduceInfo.GetPkType(), reduceInfo.GetMetricType())
}
return reduceSearchResultDataWithGroupBy(ctx,
reduceInfo.subSearchResultData,
reduceInfo.nq,
reduceInfo.topK,
reduceInfo.metricType,
reduceInfo.pkType,
reduceInfo.offset,
reduceInfo.queryInfo.GroupSize)
subSearchResultData,
reduceInfo.GetNq(),
reduceInfo.GetTopK(),
reduceInfo.GetMetricType(),
reduceInfo.GetPkType(),
reduceInfo.GetOffset(),
reduceInfo.GetGroupSize())
}
return reduceSearchResultDataNoGroupBy(ctx,
reduceInfo.subSearchResultData,
reduceInfo.nq,
reduceInfo.topK,
reduceInfo.metricType,
reduceInfo.pkType,
reduceInfo.offset)
subSearchResultData,
reduceInfo.GetNq(),
reduceInfo.GetTopK(),
reduceInfo.GetMetricType(),
reduceInfo.GetPkType(),
reduceInfo.GetOffset())
}
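
For orientation: the positional NewReduceSearchResultInfo constructor removed above is replaced by a builder-style reduce.ResultInfo. A usage sketch matching the call sites later in this diff (variable names are illustrative):

info := reduce.NewReduceSearchResultInfo(nq, topK).
    WithMetricType(metricType).
    WithPkType(schemapb.DataType_Int64).
    WithOffset(offset).
    WithGroupByField(groupByFieldID).
    WithGroupSize(groupSize).
    WithAdvance(isAdvance) // hybrid search defers final score accumulation to the proxy
result, err := reduceSearchResult(ctx, subResults, info)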
func checkResultDatas(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
nq int64, topK int64,
) (int64, int, error) {
var allSearchCount int64
var hitNum int
for i, sData := range subSearchResultData {
pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
log.Ctx(ctx).Debug("subSearchResultData",
zap.Int("result No.", i),
zap.Int64("nq", sData.NumQueries),
zap.Int64("topk", sData.TopK),
zap.Int("length of pks", pkLength),
zap.Int("length of FieldsData", len(sData.FieldsData)))
allSearchCount += sData.GetAllSearchCount()
hitNum += pkLength
if err := checkSearchResultData(sData, nq, topK, pkLength); err != nil {
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
return allSearchCount, hitNum, err
}
}
return allSearchCount, hitNum, nil
}
func reduceAdvanceGroupBY(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
nq int64, topK int64, pkType schemapb.DataType, metricType string,
) (*milvuspb.SearchResults, error) {
log.Ctx(ctx).Debug("reduceAdvanceGroupBY", zap.Int("len(subSearchResultData)", len(subSearchResultData)), zap.Int64("nq", nq))
// for advanced group-by, offset is not applied, so just return when there is only one channel
if len(subSearchResultData) == 1 {
return &milvuspb.SearchResults{
Status: merr.Success(),
Results: subSearchResultData[0],
}, nil
}
ret := &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: nq,
TopK: topK,
Scores: []float32{},
Ids: &schemapb.IDs{},
Topks: []int64{},
},
}
var limit int64
if allSearchCount, hitNum, err := checkResultDatas(ctx, subSearchResultData, nq, topK); err != nil {
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
return ret, err
} else {
ret.GetResults().AllSearchCount = allSearchCount
limit = int64(hitNum)
ret.GetResults().FieldsData = typeutil.PrepareResultFieldData(subSearchResultData[0].GetFieldsData(), limit)
}
if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
return ret, err
}
var (
subSearchNum = len(subSearchResultData)
// for each subSearchResultData, store the start offset of each of the nq queries
subSearchNqOffset = make([][]int64, subSearchNum)
)
for i := 0; i < subSearchNum; i++ {
subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
for j := int64(1); j < nq; j++ {
subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
}
}
// reducing nq * topk results
for nqIdx := int64(0); nqIdx < nq; nqIdx++ {
dataCount := int64(0)
for subIdx := 0; subIdx < subSearchNum; subIdx += 1 {
subData := subSearchResultData[subIdx]
subPks := subData.GetIds()
subScores := subData.GetScores()
subGroupByVals := subData.GetGroupByFieldValue()
nqTopK := subData.Topks[nqIdx]
for i := int64(0); i < nqTopK; i++ {
innerIdx := subSearchNqOffset[subIdx][nqIdx] + i
pk := typeutil.GetPK(subPks, innerIdx)
score := subScores[innerIdx]
groupByVal := typeutil.GetData(subData.GetGroupByFieldValue(), int(innerIdx))
typeutil.AppendPKs(ret.Results.Ids, pk)
ret.Results.Scores = append(ret.Results.Scores, score)
if err := typeutil.AppendGroupByValue(ret.Results, groupByVal, subGroupByVals.GetType()); err != nil {
log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
return ret, err
}
dataCount += 1
}
}
ret.Results.Topks = append(ret.Results.Topks, dataCount)
}
ret.Results.TopK = topK
if !metric.PositivelyRelated(metricType) {
for k := range ret.Results.Scores {
ret.Results.Scores[k] *= -1
}
}
return ret, nil
}
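
The subSearchNqOffset table above is a prefix sum over each sub-result's Topks: hit i of query q lives at offsets[q] + i in the flat Ids/Scores arrays. A self-contained sketch of that bookkeeping with made-up numbers:

// Topks = [3, 2, 4]: query 0 returned 3 hits, query 1 returned 2, query 2 returned 4
topks := []int64{3, 2, 4}
offsets := make([]int64, len(topks))
for q := 1; q < len(topks); q++ {
    offsets[q] = offsets[q-1] + topks[q-1]
}
// offsets == [0, 3, 5]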
type MilvusPKType interface{}
@ -109,37 +192,16 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
Topks: []int64{},
},
}
switch pkType {
case schemapb.DataType_Int64:
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: make([]int64, 0, limit),
},
}
case schemapb.DataType_VarChar:
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: make([]string, 0, limit),
},
}
default:
return nil, errors.New("unsupported pk type")
groupBound := groupSize * limit
if err := setupIdListForSearchResult(ret, pkType, groupBound); err != nil {
return ret, err
}
for i, sData := range subSearchResultData {
pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
log.Ctx(ctx).Debug("subSearchResultData",
zap.Int("result No.", i),
zap.Int64("nq", sData.NumQueries),
zap.Int64("topk", sData.TopK),
zap.Int("length of pks", pkLength),
zap.Int("length of FieldsData", len(sData.FieldsData)))
ret.Results.AllSearchCount += sData.GetAllSearchCount()
if err := checkSearchResultData(sData, nq, topk); err != nil {
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
return ret, err
}
// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil {
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
return ret, err
} else {
ret.GetResults().AllSearchCount = allSearchCount
}
var (
@ -163,7 +225,6 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
var retSize int64
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
groupBound := groupSize * limit
// reducing nq * topk results
for i := int64(0); i < nq; i++ {
@ -298,36 +359,15 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
},
}
switch pkType {
case schemapb.DataType_Int64:
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: make([]int64, 0, limit),
},
}
case schemapb.DataType_VarChar:
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: make([]string, 0, limit),
},
}
default:
return nil, errors.New("unsupported pk type")
if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
return ret, err
}
for i, sData := range subSearchResultData {
pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
log.Ctx(ctx).Debug("subSearchResultData",
zap.Int("result No.", i),
zap.Int64("nq", sData.NumQueries),
zap.Int64("topk", sData.TopK),
zap.Int("length of pks", pkLength),
zap.Int("length of FieldsData", len(sData.FieldsData)))
ret.Results.AllSearchCount += sData.GetAllSearchCount()
if err := checkSearchResultData(sData, nq, topk); err != nil {
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
return ret, err
}
// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil {
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
return ret, err
} else {
ret.GetResults().AllSearchCount = allSearchCount
}
var (
@ -428,23 +468,215 @@ func rankSearchResultData(ctx context.Context,
params *rankParams,
pkType schemapb.DataType,
searchResults []*milvuspb.SearchResults,
groupByFieldID int64,
groupSize int64,
groupScorer func(group *Group) error,
) (*milvuspb.SearchResults, error) {
tr := timerecord.NewTimeRecorder("rankSearchResultData")
if groupByFieldID > 0 {
return rankSearchResultDataByGroup(ctx, nq, params, pkType, searchResults, groupScorer, groupSize)
}
return rankSearchResultDataByPk(ctx, nq, params, pkType, searchResults)
}
func compareKey(keyI interface{}, keyJ interface{}) bool {
switch keyI.(type) {
case int64:
return keyI.(int64) < keyJ.(int64)
case string:
return keyI.(string) < keyJ.(string)
}
return false
}
func GetGroupScorer(scorerType string) (func(group *Group) error, error) {
switch scorerType {
case MaxScorer:
return func(group *Group) error {
group.finalScore = group.maxScore
return nil
}, nil
case SumScorer:
return func(group *Group) error {
group.finalScore = group.sumScore
return nil
}, nil
case AvgScorer:
return func(group *Group) error {
if len(group.idList) == 0 {
return merr.WrapErrParameterInvalid(1, len(group.idList),
"input group for scoring must have at least one id; this indicates a bug in the code")
}
group.finalScore = group.sumScore / float32(len(group.idList))
return nil
}, nil
default:
return nil, merr.WrapErrParameterInvalidMsg("input group scorer type: %s is not supported!", scorerType)
}
}
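
A small usage sketch for the scorers above, assuming the Group type defined just below (values are illustrative):

g := &Group{
    idList:    []interface{}{int64(7), int64(5)},
    scoreList: []float32{0.9, 0.7},
    maxScore:  0.9,
    sumScore:  1.6,
}
avg, _ := GetGroupScorer(AvgScorer)
_ = avg(g) // g.finalScore == 0.8; MaxScorer would give 0.9, SumScorer 1.6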
type Group struct {
idList []interface{}
scoreList []float32
groupVal interface{}
maxScore float32
sumScore float32
finalScore float32
}
func rankSearchResultDataByGroup(ctx context.Context,
nq int64,
params *rankParams,
pkType schemapb.DataType,
searchResults []*milvuspb.SearchResults,
groupScorer func(group *Group) error,
groupSize int64,
) (*milvuspb.SearchResults, error) {
tr := timerecord.NewTimeRecorder("rankSearchResultDataByGroup")
defer func() {
tr.CtxElapse(ctx, "done")
}()
offset := params.offset
limit := params.limit
topk := limit + offset
roundDecimal := params.roundDecimal
log.Ctx(ctx).Debug("rankSearchResultData",
offset, limit, roundDecimal := params.offset, params.limit, params.roundDecimal
// in the context of group-by, offset/limit/topk refer to numbers of groups, not rows
groupTopK := limit + offset
log.Ctx(ctx).Debug("rankSearchResultDataByGroup",
zap.Int("len(searchResults)", len(searchResults)),
zap.Int64("nq", nq),
zap.Int64("offset", offset),
zap.Int64("limit", limit))
ret := &milvuspb.SearchResults{
var ret *milvuspb.SearchResults
if ret = initSearchResults(nq, limit); len(searchResults) == 0 {
return ret, nil
}
totalCount := limit * groupSize
if err := setupIdListForSearchResult(ret, pkType, totalCount); err != nil {
return ret, err
}
type accumulateIDGroupVal struct {
accumulatedScore float32
groupVal interface{}
}
accumulatedScores := make([]map[interface{}]*accumulateIDGroupVal, nq)
for i := int64(0); i < nq; i++ {
accumulatedScores[i] = make(map[interface{}]*accumulateIDGroupVal)
}
groupByDataType := searchResults[0].GetResults().GetGroupByFieldValue().GetType()
for _, result := range searchResults {
scores := result.GetResults().GetScores()
start := 0
// milvus limits the value range of nq and limit,
// so converting nq and topK to int is safe on both 32-bit and 64-bit platforms
for i := 0; i < int(nq); i++ {
realTopK := int(result.GetResults().Topks[i])
for j := start; j < start+realTopK; j++ {
id := typeutil.GetPK(result.GetResults().GetIds(), int64(j))
groupByVal := typeutil.GetData(result.GetResults().GetGroupByFieldValue(), j)
if accumulatedScores[i][id] != nil {
accumulatedScores[i][id].accumulatedScore += scores[j]
} else {
accumulatedScores[i][id] = &accumulateIDGroupVal{accumulatedScore: scores[j], groupVal: groupByVal}
}
}
start += realTopK
}
}
for i := int64(0); i < nq; i++ {
idSet := accumulatedScores[i]
keys := make([]interface{}, 0)
for key := range idSet {
keys = append(keys, key)
}
// sort id by score
big := func(i, j int) bool {
scoreItemI := idSet[keys[i]]
scoreItemJ := idSet[keys[j]]
if scoreItemI.accumulatedScore == scoreItemJ.accumulatedScore {
return compareKey(keys[i], keys[j])
}
return scoreItemI.accumulatedScore > scoreItemJ.accumulatedScore
}
sort.Slice(keys, big)
// separate keys into buckets according to groupVal
buckets := make(map[interface{}]*Group)
for _, key := range keys {
scoreItem := idSet[key]
groupVal := scoreItem.groupVal
if buckets[groupVal] == nil {
buckets[groupVal] = &Group{
idList: make([]interface{}, 0),
scoreList: make([]float32, 0),
groupVal: groupVal,
}
}
if int64(len(buckets[groupVal].idList)) >= groupSize {
// keep at most groupSize results in each group
continue
}
buckets[groupVal].idList = append(buckets[groupVal].idList, key)
buckets[groupVal].scoreList = append(buckets[groupVal].scoreList, scoreItem.accumulatedScore)
if scoreItem.accumulatedScore > buckets[groupVal].maxScore {
buckets[groupVal].maxScore = scoreItem.accumulatedScore
}
buckets[groupVal].sumScore += scoreItem.accumulatedScore
}
if int64(len(buckets)) <= offset {
ret.Results.Topks = append(ret.Results.Topks, 0)
continue
}
groupList := make([]*Group, len(buckets))
idx := 0
for _, group := range buckets {
groupScorer(group)
groupList[idx] = group
idx += 1
}
sort.Slice(groupList, func(i, j int) bool {
if groupList[i].finalScore == groupList[j].finalScore {
if len(groupList[i].idList) == len(groupList[j].idList) {
// if the final scores and group sizes are both equal,
// choose the group with the smaller first key;
// every group is guaranteed to have at least one id in its idList
return compareKey(groupList[i].idList[0], groupList[j].idList[0])
}
// choose the larger group when scores are equal
return len(groupList[i].idList) > len(groupList[j].idList)
}
return groupList[i].finalScore > groupList[j].finalScore
})
if int64(len(groupList)) > groupTopK {
groupList = groupList[:groupTopK]
}
returnedRowNum := 0
for index := int(offset); index < len(groupList); index++ {
group := groupList[index]
for i, score := range group.scoreList {
// idList and scoreList always have the same length
typeutil.AppendPKs(ret.Results.Ids, group.idList[i])
if roundDecimal != -1 {
multiplier := math.Pow(10.0, float64(roundDecimal))
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
}
ret.Results.Scores = append(ret.Results.Scores, score)
typeutil.AppendGroupByValue(ret.Results, group.groupVal, groupByDataType)
}
returnedRowNum += len(group.idList)
}
ret.Results.Topks = append(ret.Results.Topks, int64(returnedRowNum))
}
return ret, nil
}
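
To make the two-level ordering above concrete: per query, ids are first sorted by accumulated score (ties broken by the smaller pk), bucketed by group value keeping at most groupSize ids per bucket, and the buckets are then ranked by finalScore, with ties going to the larger group and then to the smaller first key. A tiny made-up illustration of the group-level tie-breaking:

groups := []*Group{
    {groupVal: "bbb", finalScore: 0.5, idList: []interface{}{"2", "5"}},
    {groupVal: "aaa", finalScore: 0.5, idList: []interface{}{"7"}},
    {groupVal: "www", finalScore: 0.7, idList: []interface{}{"17"}},
}
// after the sort in rankSearchResultDataByGroup:
// www (higher finalScore), bbb (equal score, larger group), aaa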
func initSearchResults(nq int64, limit int64) *milvuspb.SearchResults {
return &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: nq,
@ -455,22 +687,54 @@ func rankSearchResultData(ctx context.Context,
Topks: []int64{},
},
}
}
func setupIdListForSearchResult(searchResult *milvuspb.SearchResults, pkType schemapb.DataType, capacity int64) error {
switch pkType {
case schemapb.DataType_Int64:
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
searchResult.GetResults().Ids.IdField = &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: make([]int64, 0),
Data: make([]int64, 0, capacity),
},
}
case schemapb.DataType_VarChar:
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
searchResult.GetResults().Ids.IdField = &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: make([]string, 0),
Data: make([]string, 0, capacity),
},
}
default:
return nil, errors.New("unsupported pk type")
return errors.New("unsupported pk type")
}
return nil
}
func rankSearchResultDataByPk(ctx context.Context,
nq int64,
params *rankParams,
pkType schemapb.DataType,
searchResults []*milvuspb.SearchResults,
) (*milvuspb.SearchResults, error) {
tr := timerecord.NewTimeRecorder("rankSearchResultDataByPk")
defer func() {
tr.CtxElapse(ctx, "done")
}()
offset, limit, roundDecimal := params.offset, params.limit, params.roundDecimal
topk := limit + offset
log.Ctx(ctx).Debug("rankSearchResultDataByPk",
zap.Int("len(searchResults)", len(searchResults)),
zap.Int64("nq", nq),
zap.Int64("offset", offset),
zap.Int64("limit", limit))
var ret *milvuspb.SearchResults
if ret = initSearchResults(nq, limit); len(searchResults) == 0 {
return ret, nil
}
if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
return ret, err
}
// []map[id]score
@ -503,20 +767,10 @@ func rankSearchResultData(ctx context.Context,
continue
}
compareKeys := func(keyI, keyJ interface{}) bool {
switch keyI.(type) {
case int64:
return keyI.(int64) < keyJ.(int64)
case string:
return keyI.(string) < keyJ.(string)
}
return false
}
// sort id by score
big := func(i, j int) bool {
if idSet[keys[i]] == idSet[keys[j]] {
return compareKeys(keys[i], keys[j])
return compareKey(keys[i], keys[j])
}
return idSet[keys[i]] > idSet[keys[j]]
}

View File

@ -0,0 +1,133 @@
package proxy
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
type SearchReduceUtilTestSuite struct {
suite.Suite
}
func (struts *SearchReduceUtilTestSuite) TestRankByGroup() {
var searchResultData1 *schemapb.SearchResultData
var searchResultData2 *schemapb.SearchResultData
{
groupFieldValue := []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa"}
searchResultData1 = &schemapb.SearchResultData{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: []string{"7", "5", "4", "2", "3", "6", "1", "9", "8"},
},
},
},
Topks: []int64{9},
Scores: []float32{0.6, 0.53, 0.52, 0.43, 0.41, 0.33, 0.30, 0.27, 0.22},
GroupByFieldValue: getFieldData("string", int64(101), schemapb.DataType_VarChar, groupFieldValue, 1),
}
}
{
groupFieldValue := []string{"www", "aaa", "ccc", "www", "www", "ccc", "aaa", "ccc", "aaa"}
searchResultData2 = &schemapb.SearchResultData{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: []string{"17", "15", "14", "12", "13", "16", "11", "19", "18"},
},
},
},
Topks: []int64{9},
Scores: []float32{0.7, 0.43, 0.32, 0.32, 0.31, 0.31, 0.30, 0.30, 0.30},
GroupByFieldValue: getFieldData("string", int64(101), schemapb.DataType_VarChar, groupFieldValue, 1),
}
}
searchResults := []*milvuspb.SearchResults{
{Results: searchResultData1},
{Results: searchResultData2},
}
nq := int64(1)
limit := int64(3)
offset := int64(0)
roundDecimal := int64(1)
groupSize := int64(3)
groupByFieldId := int64(101)
rankParams := &rankParams{limit: limit, offset: offset, roundDecimal: roundDecimal}
{
// test for sum group scorer
scorerType := "sum"
groupScorer, _ := GetGroupScorer(scorerType)
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
struts.NoError(err)
struts.Equal([]string{"5", "2", "3", "17", "12", "13", "7", "15", "1"}, rankedRes.GetResults().GetIds().GetStrId().Data)
struts.Equal([]float32{0.5, 0.4, 0.4, 0.7, 0.3, 0.3, 0.6, 0.4, 0.3}, rankedRes.GetResults().GetScores())
struts.Equal([]string{"bbb", "bbb", "bbb", "www", "www", "www", "aaa", "aaa", "aaa"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
}
{
// test for max group scorer
scorerType := "max"
groupScorer, _ := GetGroupScorer(scorerType)
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
struts.NoError(err)
struts.Equal([]string{"17", "12", "13", "7", "15", "1", "5", "2", "3"}, rankedRes.GetResults().GetIds().GetStrId().Data)
struts.Equal([]float32{0.7, 0.3, 0.3, 0.6, 0.4, 0.3, 0.5, 0.4, 0.4}, rankedRes.GetResults().GetScores())
struts.Equal([]string{"www", "www", "www", "aaa", "aaa", "aaa", "bbb", "bbb", "bbb"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
}
{
// test for avg group scorer
scorerType := "avg"
groupScorer, _ := GetGroupScorer(scorerType)
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
struts.NoError(err)
struts.Equal([]string{"5", "2", "3", "17", "12", "13", "7", "15", "1"}, rankedRes.GetResults().GetIds().GetStrId().Data)
struts.Equal([]float32{0.5, 0.4, 0.4, 0.7, 0.3, 0.3, 0.6, 0.4, 0.3}, rankedRes.GetResults().GetScores())
struts.Equal([]string{"bbb", "bbb", "bbb", "www", "www", "www", "aaa", "aaa", "aaa"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
}
{
// test for offset for ranking group
scorerType := "avg"
groupScorer, _ := GetGroupScorer(scorerType)
rankParams.offset = 2
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
struts.NoError(err)
struts.Equal([]string{"7", "15", "1", "4", "6", "14"}, rankedRes.GetResults().GetIds().GetStrId().Data)
struts.Equal([]float32{0.6, 0.4, 0.3, 0.5, 0.3, 0.3}, rankedRes.GetResults().GetScores())
struts.Equal([]string{"aaa", "aaa", "aaa", "ccc", "ccc", "ccc"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
}
{
// test for offset exceeding the count of final groups
scorerType := "avg"
groupScorer, _ := GetGroupScorer(scorerType)
rankParams.offset = 4
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
struts.NoError(err)
struts.Equal([]string{}, rankedRes.GetResults().GetIds().GetStrId().Data)
struts.Equal([]float32{}, rankedRes.GetResults().GetScores())
}
{
// test for invalid group scorer
scorerType := "xxx"
groupScorer, err := GetGroupScorer(scorerType)
struts.Error(err)
struts.Nil(groupScorer)
}
}
func TestSearchReduceUtilTestSuite(t *testing.T) {
suite.Run(t, new(SearchReduceUtilTestSuite))
}

View File

@ -310,13 +310,15 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro
}
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
searchParams := make([]*commonpb.KeyValuePair, len(req.GetRankParams()))
copy(searchParams, req.GetRankParams())
ret := &milvuspb.SearchRequest{
Base: req.GetBase(),
DbName: req.GetDbName(),
CollectionName: req.GetCollectionName(),
PartitionNames: req.GetPartitionNames(),
OutputFields: req.GetOutputFields(),
SearchParams: req.GetRankParams(),
SearchParams: searchParams,
TravelTimestamp: req.GetTravelTimestamp(),
GuaranteeTimestamp: req.GetGuaranteeTimestamp(),
Nq: 0,
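
The copy introduced above matters because assigning req.GetRankParams() directly would make SearchParams alias the hybrid request's own slice, so later in-place edits to SearchParams would silently mutate the original request. A minimal illustration of the hazard (note the copy is shallow: the KeyValuePair pointers are still shared):

orig := []*commonpb.KeyValuePair{{Key: "limit", Value: "10"}}
alias := orig // shares the backing array
safe := make([]*commonpb.KeyValuePair, len(orig))
copy(safe, orig)
alias[0] = &commonpb.KeyValuePair{Key: "limit", Value: "5"} // visible through orig[0]
// safe[0] still points at the original {"limit", "10"} pair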

View File

@ -42,6 +42,12 @@ import (
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
const (
SumScorer string = "sum"
MaxScorer string = "max"
AvgScorer string = "avg"
)
const (
IgnoreGrowingKey = "ignore_growing"
ReduceStopForBestKey = "reduce_stop_for_best"
@ -49,6 +55,7 @@ const (
GroupByFieldKey = "group_by_field"
GroupSizeKey = "group_size"
GroupStrictSize = "group_strict_size"
RankGroupScorer = "rank_group_scorer"
AnnsFieldKey = "anns_field"
TopKKey = "topk"
NQKey = "nq"
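
With the new RankGroupScorer key, a hybrid request's rank parameters could be assembled as below; this sketch uses only constants shown in this diff, and per initAdvancedSearchRequest the scorer falls back to MaxScorer ("max") when the key is absent ("color" is an illustrative field name):

params := []*commonpb.KeyValuePair{
    {Key: TopKKey, Value: "3"},
    {Key: GroupByFieldKey, Value: "color"},
    {Key: GroupSizeKey, Value: "2"},
    {Key: RankGroupScorer, Value: AvgScorer}, // "avg"
}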

View File

@ -21,6 +21,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/exprutil"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
@ -76,8 +77,9 @@ type searchTask struct {
queryInfos []*planpb.QueryInfo
relatedDataSize int64
reScorers []reScorer
rankParams *rankParams
reScorers []reScorer
rankParams *rankParams
groupScorer func(group *Group) error
}
func (t *searchTask) CanSkipAllocTimestamp() bool {
@ -339,10 +341,9 @@ func setQueryInfoIfMvEnable(queryInfo *planpb.QueryInfo, t *searchTask, plan *pl
func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init advanced search request")
defer sp.End()
t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]()
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
// fetch search_growing from search param
t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
@ -351,9 +352,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
if err != nil {
return err
}
if queryInfo.GetGroupByFieldId() != -1 {
return errors.New("not support search_group_by operation in the hybrid search")
}
internalSubReq := &internalpb.SubSearchRequest{
Dsl: subReq.GetDsl(),
PlaceholderGroup: subReq.GetPlaceholderGroup(),
@ -364,6 +363,8 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
Topk: queryInfo.GetTopk(),
Offset: offset,
MetricType: queryInfo.GetMetricType(),
GroupByFieldId: queryInfo.GetGroupByFieldId(),
GroupSize: queryInfo.GetGroupSize(),
}
// set PartitionIDs for sub search
@ -403,6 +404,11 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
zap.Stringer("plan", plan)) // may be very large if large term passed.
}
if len(t.queryInfos) > 0 {
t.SearchRequest.GroupByFieldId = t.queryInfos[0].GetGroupByFieldId()
t.SearchRequest.GroupSize = t.queryInfos[0].GetGroupSize()
}
// used for requery
if t.partitionKeyMode {
t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect()
@ -413,6 +419,18 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err))
return err
}
// set up the groupScorer for hybrid search + group-by
groupScorerStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankGroupScorer, t.request.GetSearchParams())
if err != nil {
groupScorerStr = MaxScorer
}
groupScorer, err := GetGroupScorer(groupScorerStr)
if err != nil {
return err
}
t.groupScorer = groupScorer
return nil
}
@ -461,7 +479,8 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
t.SearchRequest.MetricType = queryInfo.GetMetricType()
t.queryInfos = append(t.queryInfos, queryInfo)
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
t.SearchRequest.ExtraSearchParam = &internalpb.ExtraSearchParam{GroupByFieldId: queryInfo.GroupByFieldId, GroupSize: queryInfo.GroupSize}
t.SearchRequest.GroupByFieldId = queryInfo.GroupByFieldId
t.SearchRequest.GroupSize = queryInfo.GroupSize
log.Debug("proxy init search request",
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
zap.Stringer("plan", plan)) // may be very large if large term passed.
@ -554,7 +573,7 @@ func (t *searchTask) Execute(ctx context.Context) error {
return nil
}
func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, queryInfo *planpb.QueryInfo) (*milvuspb.SearchResults, error) {
func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, queryInfo *planpb.QueryInfo, isAdvance bool) (*milvuspb.SearchResults, error) {
metricType := ""
if len(toReduceResults) >= 1 {
metricType = toReduceResults[0].GetMetricType()
@ -585,8 +604,8 @@ func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*inter
return nil, err
}
var result *milvuspb.SearchResults
result, err = reduceSearchResult(ctx, NewReduceSearchResultInfo(validSearchResults, nq, topK,
metricType, primaryFieldSchema.DataType, offset, queryInfo))
result, err = reduceSearchResult(ctx, validSearchResults, reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metricType).WithPkType(primaryFieldSchema.GetDataType()).
WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()).WithAdvance(isAdvance))
if err != nil {
log.Warn("failed to reduce search results", zap.Error(err))
return nil, err
@ -647,7 +666,6 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
multipleInternalResults[reqIndex] = append(multipleInternalResults[reqIndex], internalResults)
}
}
multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs()))
for index, internalResults := range multipleInternalResults {
subReq := t.SearchRequest.GetSubReqs()[index]
@ -656,7 +674,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
if len(internalResults) >= 1 {
metricType = internalResults[0].GetMetricType()
}
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), t.queryInfos[index])
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), t.queryInfos[index], true)
if err != nil {
return err
}
@ -667,13 +685,16 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
t.result, err = rankSearchResultData(ctx, t.SearchRequest.GetNq(),
t.rankParams,
primaryFieldSchema.GetDataType(),
multipleMilvusResults)
multipleMilvusResults,
t.SearchRequest.GetGroupByFieldId(),
t.SearchRequest.GetGroupSize(),
t.groupScorer)
if err != nil {
log.Warn("rank search result failed", zap.Error(err))
return err
}
} else {
t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), t.queryInfos[0])
t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), t.queryInfos[0], false)
if err != nil {
return err
}
@ -914,7 +935,7 @@ func decodeSearchResults(ctx context.Context, searchResults []*internalpb.Search
return results, nil
}
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64) error {
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64, pkHitNum int) error {
if data.NumQueries != nq {
return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
}
@ -922,7 +943,6 @@ func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64
return fmt.Errorf("search result's topk(%d) mis-match with %d", data.TopK, topk)
}
pkHitNum := typeutil.GetSizeOfIDs(data.GetIds())
if len(data.Scores) != pkHitNum {
return fmt.Errorf("search result's score length invalid, score length=%d, expectedLength=%d",
len(data.Scores), pkHitNum)

View File

@ -39,6 +39,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
@ -1247,7 +1248,8 @@ func Test_checkSearchResultData(t *testing.T) {
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
err := checkSearchResultData(test.args.data, test.args.nq, test.args.topk)
pkLength := typeutil.GetSizeOfIDs(test.args.data.GetIds())
err := checkSearchResultData(test.args.data, test.args.nq, test.args.topk, pkLength)
if test.wantErr {
assert.Error(t, err)
@ -1522,8 +1524,9 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
reduced, err := reduceSearchResult(context.TODO(),
NewReduceSearchResultInfo(results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset, queryInfo))
reduced, err := reduceSearchResult(context.TODO(), results,
reduce.NewReduceSearchResultInfo(nq, topk).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).
WithOffset(test.offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
assert.NoError(t, err)
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
assert.Equal(t, []int64{test.limit, test.limit}, reduced.GetResults().GetTopks())
@ -1574,8 +1577,9 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
}
for _, test := range lessThanLimitTests {
t.Run(test.description, func(t *testing.T) {
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topk,
metric.L2, schemapb.DataType_Int64, test.offset, queryInfo))
reduced, err := reduceSearchResult(context.TODO(), results,
reduce.NewReduceSearchResultInfo(nq, topk).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithOffset(test.offset).
WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
assert.NoError(t, err)
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
assert.Equal(t, []int64{test.outLimit, test.outLimit}, reduced.GetResults().GetTopks())
@ -1603,9 +1607,8 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
GroupByFieldId: -1,
}
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(
results, nq, topk, metric.L2, schemapb.DataType_Int64, 0, queryInfo))
reduced, err := reduceSearchResult(context.TODO(), results,
reduce.NewReduceSearchResultInfo(nq, topk).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
assert.NoError(t, err)
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetIntId().GetData())
assert.Equal(t, []int64{5, 5}, reduced.GetResults().GetTopks())
@ -1633,9 +1636,8 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
queryInfo := &planpb.QueryInfo{
GroupByFieldId: -1,
}
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results,
nq, topk, metric.L2, schemapb.DataType_VarChar, 0, queryInfo))
reduced, err := reduceSearchResult(context.TODO(), results,
reduce.NewReduceSearchResultInfo(nq, topk).WithMetricType(metric.L2).WithPkType(schemapb.DataType_VarChar).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
assert.NoError(t, err)
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetStrId().GetData())
@ -1708,8 +1710,8 @@ func TestTaskSearch_reduceGroupBySearchResultData(t *testing.T) {
GroupByFieldId: 1,
GroupSize: 1,
}
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topK, metric.L2,
schemapb.DataType_Int64, 0, queryInfo))
reduced, err := reduceSearchResult(context.TODO(), results,
reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
resultScores := reduced.GetResults().GetScores()
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
@ -1768,8 +1770,8 @@ func TestTaskSearch_reduceGroupBySearchResultDataWithOffset(t *testing.T) {
GroupByFieldId: 1,
GroupSize: 1,
}
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, limit+offset, metric.L2,
schemapb.DataType_Int64, offset, queryInfo))
reduced, err := reduceSearchResult(context.TODO(), results,
reduce.NewReduceSearchResultInfo(nq, limit+offset).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
resultScores := reduced.GetResults().GetScores()
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
@ -1842,8 +1844,9 @@ func TestTaskSearch_reduceGroupBySearchWithGroupSizeMoreThanOne(t *testing.T) {
GroupByFieldId: 1,
GroupSize: 2,
}
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topK, metric.L2,
schemapb.DataType_Int64, 0, queryInfo))
reduced, err := reduceSearchResult(context.TODO(), results,
reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
resultScores := reduced.GetResults().GetScores()
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
@ -1855,6 +1858,188 @@ func TestTaskSearch_reduceGroupBySearchWithGroupSizeMoreThanOne(t *testing.T) {
}
}
func TestTaskSearch_reduceAdvanceSearchGroupBy(t *testing.T) {
groupByField := int64(101)
nq := int64(1)
subSearchResultData := make([]*schemapb.SearchResultData, 0)
topK := int64(3)
{
scores := []float32{0.9, 0.7, 0.65, 0.55, 0.52, 0.51, 0.5, 0.45, 0.43}
ids := []int64{7, 5, 6, 11, 22, 14, 31, 23, 37}
tops := []int64{9}
groupFieldValue := []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa"}
groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1)
result1 := &schemapb.SearchResultData{
Scores: scores,
TopK: topK,
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
},
NumQueries: nq,
Topks: tops,
GroupByFieldValue: groupByVals,
}
subSearchResultData = append(subSearchResultData, result1)
}
{
scores := []float32{0.83, 0.72, 0.72, 0.65, 0.63, 0.55, 0.52, 0.51, 0.48}
ids := []int64{17, 15, 16, 21, 32, 24, 41, 33, 27}
tops := []int64{9}
groupFieldValue := []string{"xxx", "bbb", "ddd", "bbb", "bbb", "ddd", "xxx", "ddd", "xxx"}
groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1)
result2 := &schemapb.SearchResultData{
TopK: topK,
Scores: scores,
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
},
Topks: tops,
NumQueries: nq,
GroupByFieldValue: groupByVals,
}
subSearchResultData = append(subSearchResultData, result2)
}
groupSize := int64(3)
reducedRes, err := reduceSearchResult(context.Background(), subSearchResultData,
reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.IP).WithPkType(schemapb.DataType_Int64).WithGroupByField(groupByField).WithGroupSize(groupSize).WithAdvance(true))
assert.NoError(t, err)
// reduce_advance_groupby only merges results from different delegators without reducing any results
assert.Equal(t, 18, len(reducedRes.GetResults().Ids.GetIntId().Data))
assert.Equal(t, 18, len(reducedRes.GetResults().GetScores()))
assert.Equal(t, 18, len(reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data))
assert.Equal(t, topK, reducedRes.GetResults().GetTopK())
assert.Equal(t, []int64{18}, reducedRes.GetResults().GetTopks())
assert.Equal(t, []int64{7, 5, 6, 11, 22, 14, 31, 23, 37, 17, 15, 16, 21, 32, 24, 41, 33, 27}, reducedRes.GetResults().Ids.GetIntId().Data)
assert.Equal(t, []float32{0.9, 0.7, 0.65, 0.55, 0.52, 0.51, 0.5, 0.45, 0.43, 0.83, 0.72, 0.72, 0.65, 0.63, 0.55, 0.52, 0.51, 0.48}, reducedRes.GetResults().GetScores())
assert.Equal(t, []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa", "xxx", "bbb", "ddd", "bbb", "bbb", "ddd", "xxx", "ddd", "xxx"}, reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
}
func TestTaskSearch_reduceAdvanceSearchGroupByShortCut(t *testing.T) {
groupByField := int64(101)
nq := int64(1)
subSearchResultData := make([]*schemapb.SearchResultData, 0)
topK := int64(3)
{
scores := []float32{0.9, 0.7, 0.65, 0.55, 0.52, 0.51, 0.5, 0.45, 0.43}
ids := []int64{7, 5, 6, 11, 22, 14, 31, 23, 37}
tops := []int64{9}
groupFieldValue := []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa"}
groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1)
result1 := &schemapb.SearchResultData{
Scores: scores,
TopK: topK,
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
},
NumQueries: nq,
Topks: tops,
GroupByFieldValue: groupByVals,
}
subSearchResultData = append(subSearchResultData, result1)
}
groupSize := int64(3)
reducedRes, err := reduceSearchResult(context.Background(), subSearchResultData,
reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(groupByField).WithGroupSize(groupSize).WithAdvance(true))
assert.NoError(t, err)
// reduce_advance_groupby only merges results from different delegators without reducing any results
assert.Equal(t, 9, len(reducedRes.GetResults().Ids.GetIntId().Data))
assert.Equal(t, 9, len(reducedRes.GetResults().GetScores()))
assert.Equal(t, 9, len(reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data))
assert.Equal(t, topK, reducedRes.GetResults().GetTopK())
assert.Equal(t, []int64{9}, reducedRes.GetResults().GetTopks())
assert.Equal(t, []int64{7, 5, 6, 11, 22, 14, 31, 23, 37}, reducedRes.GetResults().Ids.GetIntId().Data)
assert.Equal(t, []float32{0.9, 0.7, 0.65, 0.55, 0.52, 0.51, 0.5, 0.45, 0.43}, reducedRes.GetResults().GetScores())
assert.Equal(t, []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa"}, reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
}
func TestTaskSearch_reduceAdvanceSearchGroupByMultipleNq(t *testing.T) {
groupByField := int64(101)
nq := int64(2)
subSearchResultData := make([]*schemapb.SearchResultData, 0)
topK := int64(2)
groupSize := int64(2)
{
scores := []float32{0.9, 0.7, 0.65, 0.55, 0.51, 0.5, 0.45, 0.43}
ids := []int64{7, 5, 6, 11, 14, 31, 23, 37}
tops := []int64{4, 4}
groupFieldValue := []string{"ccc", "bbb", "ccc", "bbb", "aaa", "xxx", "xxx", "aaa"}
groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1)
result1 := &schemapb.SearchResultData{
Scores: scores,
TopK: topK,
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
},
NumQueries: nq,
Topks: tops,
GroupByFieldValue: groupByVals,
}
subSearchResultData = append(subSearchResultData, result1)
}
{
scores := []float32{0.83, 0.72, 0.72, 0.65, 0.63, 0.55, 0.52, 0.51}
ids := []int64{17, 15, 16, 21, 32, 24, 41, 33}
tops := []int64{4, 4}
groupFieldValue := []string{"ddd", "bbb", "ddd", "bbb", "rrr", "sss", "rrr", "sss"}
groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1)
result2 := &schemapb.SearchResultData{
TopK: topK,
Scores: scores,
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
},
Topks: tops,
NumQueries: nq,
GroupByFieldValue: groupByVals,
}
subSearchResultData = append(subSearchResultData, result2)
}
reducedRes, err := reduceSearchResult(context.Background(), subSearchResultData,
reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.IP).WithPkType(schemapb.DataType_Int64).WithGroupByField(groupByField).WithGroupSize(groupSize).WithAdvance(true))
assert.NoError(t, err)
// reduce_advance_groupby only merges results from different delegators without reducing any results
assert.Equal(t, 16, len(reducedRes.GetResults().Ids.GetIntId().Data))
assert.Equal(t, 16, len(reducedRes.GetResults().GetScores()))
assert.Equal(t, 16, len(reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data))
assert.Equal(t, topK, reducedRes.GetResults().GetTopK())
assert.Equal(t, []int64{8, 8}, reducedRes.GetResults().GetTopks())
assert.Equal(t, []int64{7, 5, 6, 11, 17, 15, 16, 21, 14, 31, 23, 37, 32, 24, 41, 33}, reducedRes.GetResults().Ids.GetIntId().Data)
assert.Equal(t, []float32{0.9, 0.7, 0.65, 0.55, 0.83, 0.72, 0.72, 0.65, 0.51, 0.5, 0.45, 0.43, 0.63, 0.55, 0.52, 0.51}, reducedRes.GetResults().GetScores())
assert.Equal(t, []string{"ccc", "bbb", "ccc", "bbb", "ddd", "bbb", "ddd", "bbb", "aaa", "xxx", "xxx", "aaa", "rrr", "sss", "rrr", "sss"}, reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
}
func TestSearchTask_ErrExecute(t *testing.T) {
var (
err error

View File

@ -43,6 +43,7 @@ import (
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/internal/util/streamrpc"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
@ -332,6 +333,8 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
IgnoreGrowing: req.GetReq().GetIgnoreGrowing(),
Username: req.GetReq().GetUsername(),
IsAdvanced: false,
GroupByFieldId: subReq.GetGroupByFieldId(),
GroupSize: subReq.GetGroupSize(),
}
future := conc.Go(func() (*internalpb.SearchResults, error) {
searchReq := &querypb.SearchRequest{
@ -350,14 +353,12 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
return nil, err
}
return segments.ReduceSearchResults(ctx,
return segments.ReduceSearchOnQueryNode(ctx,
results,
segments.NewReduceInfo(searchReq.Req.GetNq(),
searchReq.Req.GetTopk(),
searchReq.Req.GetExtraSearchParam().GetGroupByFieldId(),
searchReq.Req.GetExtraSearchParam().GetGroupSize(),
searchReq.Req.GetMetricType()),
)
reduce.NewReduceSearchResultInfo(searchReq.GetReq().GetNq(),
searchReq.GetReq().GetTopk()).WithMetricType(searchReq.GetReq().GetMetricType()).
WithGroupByField(searchReq.GetReq().GetGroupByFieldId()).
WithGroupSize(searchReq.GetReq().GetGroupSize()))
})
futures[index] = future
}
@ -376,12 +377,7 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
}
results[i] = result
}
var ret *internalpb.SearchResults
ret, err = segments.MergeToAdvancedResults(ctx, results)
if err != nil {
return nil, err
}
return []*internalpb.SearchResults{ret}, nil
return results, nil
}
return sd.search(ctx, req, sealed, growing)
}

View File

@ -32,6 +32,7 @@ import (
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tasks"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/internal/util/streamrpc"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
@ -384,16 +385,10 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
req.GetSegmentIDs(),
))
var resp *internalpb.SearchResults
if req.GetReq().GetIsAdvanced() {
resp, err = segments.ReduceAdvancedSearchResults(ctx, results, req.Req.GetNq())
} else {
resp, err = segments.ReduceSearchResults(ctx, results, segments.NewReduceInfo(req.Req.GetNq(),
req.Req.GetTopk(),
req.Req.GetExtraSearchParam().GetGroupByFieldId(),
req.Req.GetExtraSearchParam().GetGroupSize(),
req.Req.GetMetricType()))
}
resp, err := segments.ReduceSearchOnQueryNode(ctx, results,
reduce.NewReduceSearchResultInfo(req.GetReq().GetNq(),
req.GetReq().GetTopk()).WithMetricType(req.GetReq().GetMetricType()).WithGroupByField(req.GetReq().GetGroupByFieldId()).
WithGroupSize(req.GetReq().GetGroupSize()).WithAdvance(req.GetReq().GetIsAdvanced()))
if err != nil {
return nil, err
}

View File

@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/util/reduce"
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
@ -42,7 +43,14 @@ var _ typeutil.ResultWithID = &internalpb.RetrieveResults{}
var _ typeutil.ResultWithID = &segcorepb.RetrieveResults{}
func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResults, info *ReduceInfo) (*internalpb.SearchResults, error) {
func ReduceSearchOnQueryNode(ctx context.Context, results []*internalpb.SearchResults, info *reduce.ResultInfo) (*internalpb.SearchResults, error) {
if info.GetIsAdvance() {
return ReduceAdvancedSearchResults(ctx, results)
}
return ReduceSearchResults(ctx, results, info)
}
func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResults, info *reduce.ResultInfo) (*internalpb.SearchResults, error) {
results = lo.Filter(results, func(result *internalpb.SearchResults, _ int) bool {
return result != nil && result.GetSlicedBlob() != nil
})
@ -60,8 +68,8 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
channelsMvcc[ch] = ts
}
// don't let the new SearchResults.MetricType be empty, even though req.MetricType may be empty
if info.metricType == "" {
info.metricType = r.MetricType
if info.GetMetricType() == "" {
info.SetMetricType(r.MetricType)
}
}
log := log.Ctx(ctx)
@ -86,7 +94,7 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
log.Warn("shard leader reduce errors", zap.Error(err))
return nil, err
}
searchResults, err := EncodeSearchResultData(ctx, reducedResultData, info.nq, info.topK, info.metricType)
searchResults, err := EncodeSearchResultData(ctx, reducedResultData, info.GetNq(), info.GetTopK(), info.GetMetricType())
if err != nil {
log.Warn("shard leader encode search result errors", zap.Error(err))
return nil, err
@ -115,7 +123,7 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
return searchResults, nil
}
func ReduceAdvancedSearchResults(ctx context.Context, results []*internalpb.SearchResults, nq int64) (*internalpb.SearchResults, error) {
func ReduceAdvancedSearchResults(ctx context.Context, results []*internalpb.SearchResults) (*internalpb.SearchResults, error) {
_, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceAdvancedSearchResults")
defer sp.End()
@ -129,53 +137,14 @@ func ReduceAdvancedSearchResults(ctx context.Context, results []*internalpb.Sear
IsAdvanced: true,
}
for _, result := range results {
relatedDataSize += result.GetCostAggregation().GetTotalRelatedDataSize()
for ch, ts := range result.GetChannelsMvcc() {
channelsMvcc[ch] = ts
}
if !result.GetIsAdvanced() {
continue
}
// we just append here, no need to split subResult and reduce
// defer this reduce to proxy
searchResults.SubResults = append(searchResults.SubResults, result.GetSubResults()...)
searchResults.NumQueries = result.GetNumQueries()
}
searchResults.ChannelsMvcc = channelsMvcc
requestCosts := lo.FilterMap(results, func(result *internalpb.SearchResults, _ int) (*internalpb.CostAggregation, bool) {
if paramtable.Get().QueryNodeCfg.EnableWorkerSQCostMetrics.GetAsBool() {
return result.GetCostAggregation(), true
}
if result.GetBase().GetSourceID() == paramtable.GetNodeID() {
return result.GetCostAggregation(), true
}
return nil, false
})
searchResults.CostAggregation = mergeRequestCost(requestCosts)
if searchResults.CostAggregation == nil {
searchResults.CostAggregation = &internalpb.CostAggregation{}
}
searchResults.CostAggregation.TotalRelatedDataSize = relatedDataSize
return searchResults, nil
}
func MergeToAdvancedResults(ctx context.Context, results []*internalpb.SearchResults) (*internalpb.SearchResults, error) {
searchResults := &internalpb.SearchResults{
IsAdvanced: true,
}
channelsMvcc := make(map[string]uint64)
relatedDataSize := int64(0)
for index, result := range results {
relatedDataSize += result.GetCostAggregation().GetTotalRelatedDataSize()
for ch, ts := range result.GetChannelsMvcc() {
channelsMvcc[ch] = ts
}
searchResults.NumQueries = result.GetNumQueries()
// we just append here, no need to split subResult and reduce
// defer this reduce to proxy
// defer this reduction to proxy
subResult := &internalpb.SubSearchResults{
MetricType: result.GetMetricType(),
NumQueries: result.GetNumQueries(),
@ -185,7 +154,6 @@ func MergeToAdvancedResults(ctx context.Context, results []*internalpb.SearchRes
SlicedOffset: result.GetSlicedOffset(),
ReqIndex: int64(index),
}
searchResults.NumQueries = result.GetNumQueries()
searchResults.SubResults = append(searchResults.SubResults, subResult)
}
searchResults.ChannelsMvcc = channelsMvcc

View File

@ -29,7 +29,9 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -886,6 +888,42 @@ func (suite *ResultSuite) TestSort() {
}, result.FieldsData[9].GetScalars().GetArrayData().GetData())
}
func (suite *ResultSuite) TestReduceSearchOnQueryNode() {
results := make([]*internalpb.SearchResults, 0)
metricType := metric.IP
nq := int64(1)
topK := int64(1)
mockBlob := []byte{65, 66, 67, 65, 66, 67}
{
subRes1 := &internalpb.SearchResults{
MetricType: metricType,
NumQueries: nq,
TopK: topK,
SlicedBlob: mockBlob,
}
results = append(results, subRes1)
}
{
subRes2 := &internalpb.SearchResults{
MetricType: metricType,
NumQueries: nq,
TopK: topK,
SlicedBlob: mockBlob,
}
results = append(results, subRes2)
}
reducedRes, err := ReduceSearchOnQueryNode(context.Background(), results, reduce.NewReduceSearchResultInfo(nq, topK).
WithMetricType(metricType).WithPkType(schemapb.DataType_Int8).WithAdvance(true))
suite.NoError(err)
suite.Equal(2, len(reducedRes.GetSubResults()))
subRes1 := reducedRes.GetSubResults()[0]
suite.Equal(metricType, subRes1.GetMetricType())
suite.Equal(nq, subRes1.GetNumQueries())
suite.Equal(topK, subRes1.GetTopK())
suite.Equal(mockBlob, subRes1.GetSlicedBlob())
}
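Because the reduce info is built WithAdvance(true), the two inputs are carried through as separate sub-results rather than merged into one blob, which is what the length-2 assertion above verifies.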
func TestResult_MergeRequestCost(t *testing.T) {
costs := []*internalpb.CostAggregation{
{

View File

@ -8,39 +8,28 @@ import (
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type ReduceInfo struct {
nq int64
topK int64
groupByFieldID int64
groupSize int64
metricType string
}
func NewReduceInfo(nq int64, topK int64, groupByFieldID int64, groupSize int64, metric string) *ReduceInfo {
return &ReduceInfo{nq, topK, groupByFieldID, groupSize, metric}
}
type SearchReduce interface {
ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error)
ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *reduce.ResultInfo) (*schemapb.SearchResultData, error)
}
type SearchCommonReduce struct{}
func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error) {
func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *reduce.ResultInfo) (*schemapb.SearchResultData, error) {
ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData")
defer sp.End()
log := log.Ctx(ctx)
if len(searchResultData) == 0 {
return &schemapb.SearchResultData{
NumQueries: info.nq,
TopK: info.topK,
NumQueries: info.GetNq(),
TopK: info.GetTopK(),
FieldsData: make([]*schemapb.FieldData, 0),
Scores: make([]float32, 0),
Ids: &schemapb.IDs{},
@ -48,8 +37,8 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
}, nil
}
ret := &schemapb.SearchResultData{
NumQueries: info.nq,
TopK: info.topK,
NumQueries: info.GetNq(),
TopK: info.GetTopK(),
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
Scores: make([]float32, 0),
Ids: &schemapb.IDs{},
@ -59,7 +48,7 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
resultOffsets := make([][]int64, len(searchResultData))
for i := 0; i < len(searchResultData); i++ {
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
for j := int64(1); j < info.nq; j++ {
for j := int64(1); j < info.GetNq(); j++ {
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
}
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
@ -68,11 +57,11 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
var skipDupCnt int64
var retSize int64
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
for i := int64(0); i < info.nq; i++ {
for i := int64(0); i < info.GetNq(); i++ {
offsets := make([]int64, len(searchResultData))
idSet := make(map[interface{}]struct{})
var j int64
for j = 0; j < info.topK; {
for j = 0; j < info.GetTopK(); {
sel := SelectSearchResultData(searchResultData, resultOffsets, offsets, i)
if sel == -1 {
break
@ -113,15 +102,15 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
type SearchGroupByReduce struct{}
func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error) {
func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *reduce.ResultInfo) (*schemapb.SearchResultData, error) {
ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData")
defer sp.End()
log := log.Ctx(ctx)
if len(searchResultData) == 0 {
return &schemapb.SearchResultData{
NumQueries: info.nq,
TopK: info.topK,
NumQueries: info.GetNq(),
TopK: info.GetTopK(),
FieldsData: make([]*schemapb.FieldData, 0),
Scores: make([]float32, 0),
Ids: &schemapb.IDs{},
@ -129,8 +118,8 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
}, nil
}
ret := &schemapb.SearchResultData{
NumQueries: info.nq,
TopK: info.topK,
NumQueries: info.GetNq(),
TopK: info.GetTopK(),
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
Scores: make([]float32, 0),
Ids: &schemapb.IDs{},
@ -140,7 +129,7 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
resultOffsets := make([][]int64, len(searchResultData))
for i := 0; i < len(searchResultData); i++ {
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
for j := int64(1); j < info.nq; j++ {
for j := int64(1); j < info.GetNq(); j++ {
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
}
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
@ -149,13 +138,13 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
var filteredCount int64
var retSize int64
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
groupSize := info.groupSize
groupSize := info.GetGroupSize()
if groupSize <= 0 {
groupSize = 1
}
groupBound := info.topK * groupSize
groupBound := info.GetTopK() * groupSize
for i := int64(0); i < info.nq; i++ {
for i := int64(0); i < info.GetNq(); i++ {
offsets := make([]int64, len(searchResultData))
idSet := make(map[interface{}]struct{})
@ -178,7 +167,7 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
}
groupCount := groupByValueMap[groupByVal]
if groupCount == 0 && int64(len(groupByValueMap)) >= info.topK {
if groupCount == 0 && int64(len(groupByValueMap)) >= info.GetTopK() {
// the group count limit has been reached, filter out this entity
filteredCount++
} else if groupCount >= groupSize {
@ -219,8 +208,8 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
return ret, nil
}
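Worked example of the bound above: with info.GetTopK() = 3 and group size left unset, groupSize defaults to 1 and groupBound = 3, matching the pre-group-size behavior; with groupSize = 2, groupBound = 6, and an entity is filtered once its group already holds 2 hits, or once 3 distinct groups exist and it would open a new one.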
func InitSearchReducer(info *ReduceInfo) SearchReduce {
if info.groupByFieldID > 0 {
func InitSearchReducer(info *reduce.ResultInfo) SearchReduce {
if info.GetGroupByFieldId() > 0 {
return &SearchGroupByReduce{}
}
return &SearchCommonReduce{}

View File

@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
@ -28,7 +29,7 @@ func (suite *SearchReduceSuite) TestResult_ReduceSearchResultData() {
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
reduceInfo := &ReduceInfo{nq: nq, topK: topk}
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1)
searchReduce := InitSearchReducer(reduceInfo)
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
suite.Nil(err)
@ -47,7 +48,7 @@ func (suite *SearchReduceSuite) TestResult_ReduceSearchResultData() {
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
reduceInfo := &ReduceInfo{nq: nq, topK: topk}
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1)
searchReduce := InitSearchReducer(reduceInfo)
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
suite.Nil(err)
@ -96,7 +97,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101}
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1).WithGroupByField(101)
searchReduce := InitSearchReducer(reduceInfo)
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
suite.Nil(err)
@ -140,7 +141,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101}
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1).WithGroupByField(101)
searchReduce := InitSearchReducer(reduceInfo)
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
suite.Nil(err)
@ -184,7 +185,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101}
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1).WithGroupByField(101)
searchReduce := InitSearchReducer(reduceInfo)
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
suite.Nil(err)
@ -228,7 +229,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101, groupSize: 3}
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(3).WithGroupByField(101)
searchReduce := InitSearchReducer(reduceInfo)
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
suite.Nil(err)
@ -239,7 +240,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
suite.Run("reduce_group_by_empty_input", func() {
dataArray := make([]*schemapb.SearchResultData, 0)
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101, groupSize: 3}
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(3).WithGroupByField(101)
searchReduce := InitSearchReducer(reduceInfo)
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
suite.Nil(err)

View File

@ -753,69 +753,41 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
return resp, nil
}
toReduceResults := make([]*internalpb.SearchResults, len(req.GetDmlChannels()))
runningGp, runningCtx := errgroup.WithContext(ctx)
for i, ch := range req.GetDmlChannels() {
ch := ch
req := &querypb.SearchRequest{
Req: req.Req,
DmlChannels: []string{ch},
SegmentIDs: req.SegmentIDs,
Scope: req.Scope,
TotalChannelNum: req.TotalChannelNum,
}
i := i
runningGp.Go(func() error {
ret, err := node.searchChannel(runningCtx, req, ch)
if err != nil {
return err
}
if err := merr.Error(ret.GetStatus()); err != nil {
return err
}
toReduceResults[i] = ret
return nil
})
if len(req.GetDmlChannels()) != 1 {
err := merr.WrapErrParameterInvalid(1, len(req.GetDmlChannels()), "exactly one DML channel must be searched per request")
resp.Status = merr.Status(err)
log.Warn("unexpected number of DML channels in search request", zap.Error(err))
return resp, nil
}
if err := runningGp.Wait(); err != nil {
ch := req.GetDmlChannels()[0]
channelReq := &querypb.SearchRequest{
Req: req.Req,
DmlChannels: []string{ch},
SegmentIDs: req.SegmentIDs,
Scope: req.Scope,
TotalChannelNum: req.TotalChannelNum,
}
ret, err := node.searchChannel(ctx, channelReq, ch)
if err != nil {
resp.Status = merr.Status(err)
return resp, nil
}
tr.RecordSpan()
var result *internalpb.SearchResults
var err2 error
if req.GetReq().GetIsAdvanced() {
result, err2 = segments.ReduceAdvancedSearchResults(ctx, toReduceResults, req.Req.GetNq())
} else {
result, err2 = segments.ReduceSearchResults(ctx, toReduceResults, segments.NewReduceInfo(req.Req.GetNq(),
req.Req.GetTopk(),
req.Req.GetExtraSearchParam().GetGroupByFieldId(),
req.Req.GetExtraSearchParam().GetGroupSize(),
req.Req.GetMetricType()))
}
if err2 != nil {
log.Warn("failed to reduce search results", zap.Error(err2))
resp.Status = merr.Status(err2)
return resp, nil
}
result.Status = merr.Success()
ret.Status = merr.Success()
reduceLatency := tr.RecordSpan()
metrics.QueryNodeReduceLatency.
WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards, metrics.BatchReduce).
Observe(float64(reduceLatency.Milliseconds()))
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(node.GetNodeID(), 10), metrics.SearchLabel).
Add(float64(proto.Size(req)))
if result.GetCostAggregation() != nil {
result.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
if ret.GetCostAggregation() != nil {
ret.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
}
return result, nil
return ret, nil
}
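With the fan-out removed, a request that still carries several DML channels is rejected up front; an illustrative call (baseReq assumed in scope):

// Illustrative: two channels in one request now yield an invalid-parameter
// status instead of a per-channel errgroup fan-out.
resp, _ := node.Search(ctx, &querypb.SearchRequest{
	Req:         baseReq,
	DmlChannels: []string{"dml-ch-0", "dml-ch-1"},
})
// resp.GetStatus() now carries the wrapped parameter-invalid error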
// only used for delegator query segments from worker

View File

@ -0,0 +1,92 @@
package reduce
import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
type ResultInfo struct {
nq int64
topK int64
metricType string
pkType schemapb.DataType
offset int64
groupByFieldId int64
groupSize int64
isAdvance bool
}
func NewReduceSearchResultInfo(
nq int64,
topK int64,
) *ResultInfo {
return &ResultInfo{
nq: nq,
topK: topK,
}
}
func (r *ResultInfo) WithMetricType(metricType string) *ResultInfo {
r.metricType = metricType
return r
}
func (r *ResultInfo) WithPkType(pkType schemapb.DataType) *ResultInfo {
r.pkType = pkType
return r
}
func (r *ResultInfo) WithOffset(offset int64) *ResultInfo {
r.offset = offset
return r
}
func (r *ResultInfo) WithGroupByField(groupByField int64) *ResultInfo {
r.groupByFieldId = groupByField
return r
}
func (r *ResultInfo) WithGroupSize(groupSize int64) *ResultInfo {
r.groupSize = groupSize
return r
}
func (r *ResultInfo) WithAdvance(advance bool) *ResultInfo {
r.isAdvance = advance
return r
}
func (r *ResultInfo) GetNq() int64 {
return r.nq
}
func (r *ResultInfo) GetTopK() int64 {
return r.topK
}
func (r *ResultInfo) GetMetricType() string {
return r.metricType
}
func (r *ResultInfo) GetPkType() schemapb.DataType {
return r.pkType
}
func (r *ResultInfo) GetOffset() int64 {
return r.offset
}
func (r *ResultInfo) GetGroupByFieldId() int64 {
return r.groupByFieldId
}
func (r *ResultInfo) GetGroupSize() int64 {
return r.groupSize
}
func (r *ResultInfo) GetIsAdvance() bool {
return r.isAdvance
}
func (r *ResultInfo) SetMetricType(metricType string) {
r.metricType = metricType
}
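A typical construction chain, matching the builder methods above (values illustrative):

info := reduce.NewReduceSearchResultInfo(2, 10).
	WithMetricType(metric.IP).
	WithPkType(schemapb.DataType_Int64).
	WithOffset(0).
	WithGroupByField(101).
	WithGroupSize(2).
	WithAdvance(false)
// the metric type can also be patched in later, e.g. from the first result:
info.SetMetricType(metric.IP)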