fix: incorrect reduce result for group-by with offset (#30828) (#30882)

related: #30828
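
The bug: the old reducer applied offset by advancing the merge cursors past the
first offset rows before any grouping took place, so the skipped rows silently
consumed entries from groups that should have been returned. This change splits
the reduce into a group-by path and a plain path; the group-by path now skips
the first offset distinct group values during the merge instead. A minimal
sketch of the intended semantics (illustrative only; skipOffsetGroups is not
part of this patch, and hits are assumed pre-merged in descending score order):

package main

import "fmt"

// skipOffsetGroups returns the indices of the first `limit` hits that remain
// after skipping the first `offset` distinct group values.
func skipOffsetGroups(groupVals []int64, offset, limit int) []int {
	seen := map[int64]struct{}{}
	var kept []int
	for i, g := range groupVals {
		if _, dup := seen[g]; dup {
			continue // a hit from an already-seen group is dropped
		}
		seen[g] = struct{}{}
		if len(seen) <= offset {
			continue // still inside the skipped offset groups
		}
		kept = append(kept, i)
		if len(kept) == limit {
			break
		}
	}
	return kept
}

func main() {
	groups := []int64{10, 9, 8, 7, 6, 5, 4, 3, 2, 1}
	fmt.Println(skipOffsetGroups(groups, 5, 5)) // [5 6 7 8 9]
}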

Signed-off-by: MrPresent-Han <chun.han@zilliz.com>
MrPresent-Han 2024-03-06 16:47:00 +08:00 committed by GitHub
parent fd17a5f050
commit d0eeea4b44
4 changed files with 469 additions and 176 deletions


@@ -0,0 +1,381 @@
package proxy
import (
	"context"
	"fmt"

	"github.com/cockroachdb/errors"
	"go.uber.org/zap"

	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/proto/planpb"
	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/pkg/util/metric"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
	"github.com/milvus-io/milvus/pkg/util/timerecord"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type reduceSearchResultInfo struct {
subSearchResultData []*schemapb.SearchResultData
nq int64
topK int64
metricType string
pkType schemapb.DataType
offset int64
queryInfo *planpb.QueryInfo
}
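// NewReduceSearchResultInfo bundles everything the reduce phase needs into a
// single parameter struct, which keeps call sites stable as reduce options
// (such as queryInfo) are added.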
func NewReduceSearchResultInfo(
subSearchResultData []*schemapb.SearchResultData,
nq int64,
topK int64,
metricType string,
pkType schemapb.DataType,
offset int64,
queryInfo *planpb.QueryInfo,
) *reduceSearchResultInfo {
return &reduceSearchResultInfo{
subSearchResultData: subSearchResultData,
nq: nq,
topK: topK,
metricType: metricType,
pkType: pkType,
offset: offset,
queryInfo: queryInfo,
}
}
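// reduceSearchResult dispatches on the query plan: a GroupByFieldId > 0 selects
// the group-by aware reducer, which must count offset in distinct groups;
// otherwise the plain top-k reducer is used.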
func reduceSearchResult(ctx context.Context, reduceInfo *reduceSearchResultInfo) (*milvuspb.SearchResults, error) {
if reduceInfo.queryInfo.GroupByFieldId > 0 {
return reduceSearchResultDataWithGroupBy(ctx,
reduceInfo.subSearchResultData,
reduceInfo.nq,
reduceInfo.topK,
reduceInfo.metricType,
reduceInfo.pkType,
reduceInfo.offset)
}
return reduceSearchResultDataNoGroupBy(ctx,
reduceInfo.subSearchResultData,
reduceInfo.nq,
reduceInfo.topK,
reduceInfo.metricType,
reduceInfo.pkType,
reduceInfo.offset)
}
func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64) (*milvuspb.SearchResults, error) {
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
defer func() {
tr.CtxElapse(ctx, "done")
}()
limit := topk - offset
log.Ctx(ctx).Debug("reduceSearchResultData",
zap.Int("len(subSearchResultData)", len(subSearchResultData)),
zap.Int64("nq", nq),
zap.Int64("offset", offset),
zap.Int64("limit", limit),
zap.String("metricType", metricType))
ret := &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: nq,
TopK: topk,
FieldsData: typeutil.PrepareResultFieldData(subSearchResultData[0].GetFieldsData(), limit),
Scores: []float32{},
Ids: &schemapb.IDs{},
Topks: []int64{},
},
}
switch pkType {
case schemapb.DataType_Int64:
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: make([]int64, 0, limit),
},
}
case schemapb.DataType_VarChar:
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: make([]string, 0, limit),
},
}
default:
return nil, errors.New("unsupported pk type")
}
for i, sData := range subSearchResultData {
pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
log.Ctx(ctx).Debug("subSearchResultData",
zap.Int("result No.", i),
zap.Int64("nq", sData.NumQueries),
zap.Int64("topk", sData.TopK),
zap.Int("length of pks", pkLength),
zap.Int("length of FieldsData", len(sData.FieldsData)))
if err := checkSearchResultData(sData, nq, topk); err != nil {
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
return ret, err
}
// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
}
var (
subSearchNum = len(subSearchResultData)
// for results of each subSearchResultData, storing the start offset of each query of nq queries
subSearchNqOffset = make([][]int64, subSearchNum)
)
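	// Example: if a sub-result's per-query topks are [2, 3, 1], its nq offsets
	// become [0, 2, 5]; the hits of query j occupy the half-open range
	// [offset[j], offset[j]+topks[j]) of that sub-result's flat Ids/Scores.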
for i := 0; i < subSearchNum; i++ {
subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
for j := int64(1); j < nq; j++ {
subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
}
}
var (
skipDupCnt int64
realTopK int64 = -1
)
var retSize int64
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
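	// retSize accumulates the payload size appended so far; the reduce aborts
	// once it exceeds the quota-configured maxOutputSize to avoid OOM.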
// reducing nq * topk results
for i := int64(0); i < nq; i++ {
var (
// cursor of current data of each subSearch for merging the j-th data of TopK.
// sum(cursors) == j
cursors = make([]int64, subSearchNum)
j int64
idSet = make(map[interface{}]struct{})
groupByValSet = make(map[interface{}]struct{})
)
// keep limit results
for j = 0; j < limit; {
// From all the sub-query result sets of the i-th query vector,
// find the sub-query result set index of the score j-th data,
// and the index of the data in schemapb.SearchResultData
subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
if subSearchIdx == -1 {
break
}
subSearchRes := subSearchResultData[subSearchIdx]
id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
score := subSearchRes.Scores[resultDataIdx]
groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))
if groupByVal == nil {
return nil, errors.New("get nil groupByVal from subSearchRes, wrong states, as milvus doesn't support nil value," +
"there must be sth wrong on queryNode side")
}
// remove duplicates
if _, ok := idSet[id]; !ok {
_, groupByValExist := groupByValSet[groupByVal]
if !groupByValExist {
groupByValSet[groupByVal] = struct{}{}
					if int64(len(groupByValSet)) <= offset {
						// skip the first offset distinct groups
						cursors[subSearchIdx]++
						continue
					}
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
typeutil.AppendPKs(ret.Results.Ids, id)
ret.Results.Scores = append(ret.Results.Scores, score)
idSet[id] = struct{}{}
if err := typeutil.AppendGroupByValue(ret.Results, groupByVal, subSearchRes.GetGroupByFieldValue().GetType()); err != nil {
log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
return ret, err
}
j++
} else {
						// skip entity with the same group-by value
skipDupCnt++
}
} else {
// skip entity with same id
skipDupCnt++
}
cursors[subSearchIdx]++
}
if realTopK != -1 && realTopK != j {
log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
// return nil, errors.New("the length (topk) between all result of query is different")
}
realTopK = j
ret.Results.Topks = append(ret.Results.Topks, realTopK)
// limit search result to avoid oom
if retSize > maxOutputSize {
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
}
}
log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
if skipDupCnt > 0 {
log.Ctx(ctx).Info("skip duplicated search result", zap.Int64("count", skipDupCnt))
}
ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
if !metric.PositivelyRelated(metricType) {
for k := range ret.Results.Scores {
ret.Results.Scores[k] *= -1
}
}
return ret, nil
}
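// reduceSearchResultDataNoGroupBy differs from the group-by variant in how it
// applies offset: without grouping each hit is exactly one row, so the first
// offset merged hits can simply be discarded before collecting limit results.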
func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64) (*milvuspb.SearchResults, error) {
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
defer func() {
tr.CtxElapse(ctx, "done")
}()
limit := topk - offset
log.Ctx(ctx).Debug("reduceSearchResultData",
zap.Int("len(subSearchResultData)", len(subSearchResultData)),
zap.Int64("nq", nq),
zap.Int64("offset", offset),
zap.Int64("limit", limit),
zap.String("metricType", metricType))
ret := &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: nq,
TopK: topk,
FieldsData: typeutil.PrepareResultFieldData(subSearchResultData[0].GetFieldsData(), limit),
Scores: []float32{},
Ids: &schemapb.IDs{},
Topks: []int64{},
},
}
switch pkType {
case schemapb.DataType_Int64:
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: make([]int64, 0, limit),
},
}
case schemapb.DataType_VarChar:
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: make([]string, 0, limit),
},
}
default:
return nil, errors.New("unsupported pk type")
}
for i, sData := range subSearchResultData {
pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
log.Ctx(ctx).Debug("subSearchResultData",
zap.Int("result No.", i),
zap.Int64("nq", sData.NumQueries),
zap.Int64("topk", sData.TopK),
zap.Int("length of pks", pkLength),
zap.Int("length of FieldsData", len(sData.FieldsData)))
if err := checkSearchResultData(sData, nq, topk); err != nil {
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
return ret, err
}
// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
}
var (
subSearchNum = len(subSearchResultData)
// for results of each subSearchResultData, storing the start offset of each query of nq queries
subSearchNqOffset = make([][]int64, subSearchNum)
)
for i := 0; i < subSearchNum; i++ {
subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
for j := int64(1); j < nq; j++ {
subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
}
}
var (
skipDupCnt int64
realTopK int64 = -1
)
var retSize int64
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
// reducing nq * topk results
for i := int64(0); i < nq; i++ {
var (
// cursor of current data of each subSearch for merging the j-th data of TopK.
// sum(cursors) == j
cursors = make([]int64, subSearchNum)
j int64
idSet = make(map[interface{}]struct{})
)
// skip offset results
for k := int64(0); k < offset; k++ {
subSearchIdx, _ := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
if subSearchIdx == -1 {
break
}
cursors[subSearchIdx]++
}
// keep limit results
for j = 0; j < limit; {
// From all the sub-query result sets of the i-th query vector,
// find the sub-query result set index of the score j-th data,
// and the index of the data in schemapb.SearchResultData
subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
if subSearchIdx == -1 {
break
}
id := typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)
score := subSearchResultData[subSearchIdx].Scores[resultDataIdx]
			// remove duplicates
if _, ok := idSet[id]; !ok {
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
typeutil.AppendPKs(ret.Results.Ids, id)
ret.Results.Scores = append(ret.Results.Scores, score)
idSet[id] = struct{}{}
j++
} else {
// skip entity with same id
skipDupCnt++
}
cursors[subSearchIdx]++
}
if realTopK != -1 && realTopK != j {
log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
// return nil, errors.New("the length (topk) between all result of query is different")
}
realTopK = j
ret.Results.Topks = append(ret.Results.Topks, realTopK)
// limit search result to avoid oom
if retSize > maxOutputSize {
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
}
}
log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
if skipDupCnt > 0 {
		log.Ctx(ctx).Info("skip duplicated search result", zap.Int64("count", skipDupCnt))
}
ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
if !metric.PositivelyRelated(metricType) {
for k := range ret.Results.Scores {
ret.Results.Scores[k] *= -1
}
}
return ret, nil
}
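
For reference, the proxy's PostExecute (see the hunk below) now drives this
entry point roughly as follows (sketch; names as in the hunk):

	t.result, err = reduceSearchResult(ctx, NewReduceSearchResultInfo(
		validSearchResults, Nq, Topk, MetricType,
		primaryFieldSchema.DataType, t.offset, t.queryInfo))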


@@ -118,6 +118,7 @@ func initSearchRequest(ctx context.Context, t *searchTask) error {
t.SearchRequest.Topk = queryInfo.GetTopk()
t.SearchRequest.MetricType = queryInfo.GetMetricType()
t.queryInfo = queryInfo
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk)


@@ -70,6 +70,7 @@ type searchTask struct {
node types.ProxyComponent
lb LBPolicy
queryChannelsTs map[string]Timestamp
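	// queryInfo is captured when the search plan is built so that PostExecute
	// can pick the matching reduce path (group-by vs. plain top-k).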
queryInfo *planpb.QueryInfo
}
func getPartitionIDs(ctx context.Context, dbName string, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
@@ -443,7 +444,8 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
return err
}
t.result, err = reduceSearchResultData(ctx, validSearchResults, Nq, Topk, MetricType, primaryFieldSchema.DataType, t.offset)
t.result, err = reduceSearchResult(ctx, NewReduceSearchResultInfo(validSearchResults, Nq, Topk,
MetricType, primaryFieldSchema.DataType, t.offset, t.queryInfo))
if err != nil {
log.Warn("failed to reduce search results", zap.Error(err))
return err
@@ -751,173 +753,6 @@ func selectHighestScoreIndex(subSearchResultData []*schemapb.SearchResultData, s
return subSearchIdx, resultDataIdx
}
func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64) (*milvuspb.SearchResults, error) {
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
defer func() {
tr.CtxElapse(ctx, "done")
}()
limit := topk - offset
log.Ctx(ctx).Debug("reduceSearchResultData",
zap.Int("len(subSearchResultData)", len(subSearchResultData)),
zap.Int64("nq", nq),
zap.Int64("offset", offset),
zap.Int64("limit", limit),
zap.String("metricType", metricType))
ret := &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: nq,
TopK: topk,
FieldsData: typeutil.PrepareResultFieldData(subSearchResultData[0].GetFieldsData(), limit),
Scores: []float32{},
Ids: &schemapb.IDs{},
Topks: []int64{},
},
}
switch pkType {
case schemapb.DataType_Int64:
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: make([]int64, 0, limit),
},
}
case schemapb.DataType_VarChar:
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: make([]string, 0, limit),
},
}
default:
return nil, errors.New("unsupported pk type")
}
for i, sData := range subSearchResultData {
pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
log.Ctx(ctx).Debug("subSearchResultData",
zap.Int("result No.", i),
zap.Int64("nq", sData.NumQueries),
zap.Int64("topk", sData.TopK),
zap.Int("length of pks", pkLength),
zap.Int("length of FieldsData", len(sData.FieldsData)))
if err := checkSearchResultData(sData, nq, topk); err != nil {
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
return ret, err
}
// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
}
var (
subSearchNum = len(subSearchResultData)
// for results of each subSearchResultData, storing the start offset of each query of nq queries
subSearchNqOffset = make([][]int64, subSearchNum)
)
for i := 0; i < subSearchNum; i++ {
subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
for j := int64(1); j < nq; j++ {
subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
}
}
var (
skipDupCnt int64
realTopK int64 = -1
)
var retSize int64
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
// reducing nq * topk results
for i := int64(0); i < nq; i++ {
var (
// cursor of current data of each subSearch for merging the j-th data of TopK.
// sum(cursors) == j
cursors = make([]int64, subSearchNum)
j int64
idSet = make(map[interface{}]struct{})
groupByValSet = make(map[interface{}]struct{})
)
// skip offset results
for k := int64(0); k < offset; k++ {
subSearchIdx, _ := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
if subSearchIdx == -1 {
break
}
cursors[subSearchIdx]++
}
// keep limit results
for j = 0; j < limit; {
// From all the sub-query result sets of the i-th query vector,
// find the sub-query result set index of the score j-th data,
// and the index of the data in schemapb.SearchResultData
subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
if subSearchIdx == -1 {
break
}
subSearchRes := subSearchResultData[subSearchIdx]
id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
score := subSearchRes.Scores[resultDataIdx]
groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))
// remove duplicates
if _, ok := idSet[id]; !ok {
groupByValExist := false
if groupByVal != nil {
_, groupByValExist = groupByValSet[groupByVal]
}
if !groupByValExist {
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
typeutil.AppendPKs(ret.Results.Ids, id)
ret.Results.Scores = append(ret.Results.Scores, score)
idSet[id] = struct{}{}
if groupByVal != nil {
groupByValSet[groupByVal] = struct{}{}
if err := typeutil.AppendGroupByValue(ret.Results, groupByVal, subSearchRes.GetGroupByFieldValue().GetType()); err != nil {
log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
return ret, err
}
}
j++
}
} else {
// skip entity with same id
skipDupCnt++
}
cursors[subSearchIdx]++
}
if realTopK != -1 && realTopK != j {
log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
// return nil, errors.New("the length (topk) between all result of query is different")
}
realTopK = j
ret.Results.Topks = append(ret.Results.Topks, realTopK)
// limit search result to avoid oom
if retSize > maxOutputSize {
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
}
}
log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
if skipDupCnt > 0 {
log.Info("skip duplicated search result", zap.Int64("count", skipDupCnt))
}
ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
if !metric.PositivelyRelated(metricType) {
for k := range ret.Results.Scores {
ret.Results.Scores[k] *= -1
}
}
return ret, nil
}
type rangeSearchParams struct {
radius float64
rangeFilter float64


@@ -35,6 +35,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
@@ -1526,9 +1527,13 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
results = append(results, r)
}
queryInfo := &planpb.QueryInfo{
GroupByFieldId: -1,
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset)
reduced, err := reduceSearchResult(context.TODO(),
NewReduceSearchResultInfo(results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset, queryInfo))
assert.NoError(t, err)
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
assert.Equal(t, []int64{test.limit, test.limit}, reduced.GetResults().GetTopks())
@@ -1577,10 +1582,10 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
[]int64{},
},
}
for _, test := range lessThanLimitTests {
t.Run(test.description, func(t *testing.T) {
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset)
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topk,
metric.L2, schemapb.DataType_Int64, test.offset, queryInfo))
assert.NoError(t, err)
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
assert.Equal(t, []int64{test.outLimit, test.outLimit}, reduced.GetResults().GetTopks())
@@ -1604,7 +1609,12 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
results = append(results, r)
}
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_Int64, 0)
queryInfo := &planpb.QueryInfo{
GroupByFieldId: -1,
}
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(
results, nq, topk, metric.L2, schemapb.DataType_Int64, 0, queryInfo))
assert.NoError(t, err)
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetIntId().GetData())
@@ -1630,8 +1640,12 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
results = append(results, r)
}
queryInfo := &planpb.QueryInfo{
GroupByFieldId: -1,
}
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_VarChar, 0)
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results,
nq, topk, metric.L2, schemapb.DataType_VarChar, 0, queryInfo))
assert.NoError(t, err)
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetStrId().GetData())
@@ -1700,8 +1714,11 @@ func TestTaskSearch_reduceGroupBySearchResultData(t *testing.T) {
}
results = append(results, result)
}
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topK, metric.L2, schemapb.DataType_Int64, 0)
queryInfo := &planpb.QueryInfo{
GroupByFieldId: 1,
}
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topK, metric.L2,
schemapb.DataType_Int64, 0, queryInfo))
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
resultScores := reduced.GetResults().GetScores()
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
@@ -1713,6 +1730,63 @@ func TestTaskSearch_reduceGroupBySearchResultData(t *testing.T) {
}
}
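// The two sub-results below merge into scores 10..1 with ids and group values
// 1..10. offset=5 skips the five best groups (ids 1-5), leaving ids 6-10; L2
// is negatively related to score, so the reducer negates the final scores,
// yielding -5..-1.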
func TestTaskSearch_reduceGroupBySearchResultDataWithOffset(t *testing.T) {
var (
nq int64 = 1
limit int64 = 5
offset int64 = 5
)
ids := [][]int64{
{1, 3, 5, 7, 9},
{2, 4, 6, 8, 10},
}
scores := [][]float32{
{10, 8, 6, 4, 2},
{9, 7, 5, 3, 1},
}
groupByValuesArr := [][]int64{
{1, 3, 5, 7, 9},
{2, 4, 6, 8, 10},
}
expectedIDs := []int64{6, 7, 8, 9, 10}
expectedScores := []float32{-5, -4, -3, -2, -1}
expectedGroupByValues := []int64{6, 7, 8, 9, 10}
var results []*schemapb.SearchResultData
for j := range ids {
result := getSearchResultData(nq, limit+offset)
result.Ids.IdField = &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: ids[j]}}
result.Scores = scores[j]
result.Topks = []int64{limit}
result.GroupByFieldValue = &schemapb.FieldData{
Type: schemapb.DataType_Int64,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: groupByValuesArr[j],
},
},
},
},
}
results = append(results, result)
}
queryInfo := &planpb.QueryInfo{
GroupByFieldId: 1,
}
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, limit+offset, metric.L2,
schemapb.DataType_Int64, offset, queryInfo))
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
resultScores := reduced.GetResults().GetScores()
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
assert.EqualValues(t, expectedIDs, resultIDs)
assert.EqualValues(t, expectedScores, resultScores)
assert.EqualValues(t, expectedGroupByValues, resultGroupByValues)
assert.NoError(t, err)
}
func TestSearchTask_ErrExecute(t *testing.T) {
var (
err error
@@ -2367,7 +2441,9 @@ func TestSearchTask_Requery(t *testing.T) {
qt.resultBuf.Insert(&internalpb.SearchResults{
SlicedBlob: bytes,
})
qt.queryInfo = &planpb.QueryInfo{
GroupByFieldId: -1,
}
err = qt.PostExecute(ctx)
t.Logf("err = %s", err)
assert.Error(t, err)