fix: reorder hybridsearch result for L2 metric (#30739)

#30694

Signed-off-by: luzhang <luzhang@zilliz.com>
Co-authored-by: luzhang <luzhang@zilliz.com>
pull/30787/head
zhagnlu 2024-02-26 14:18:55 +08:00 committed by GitHub
parent ece9d273a7
commit a0531b72aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 19 additions and 4 deletions

View File

@ -21,6 +21,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/typeutil"
@ -395,8 +396,10 @@ func (t *hybridSearchTask) PostExecute(ctx context.Context) error {
return err
}
metricType := ""
t.queryChannelsTs = make(map[string]uint64)
for _, r := range t.resultBuf.Collect() {
metricType = r.GetResults()[0].GetMetricType()
for ch, ts := range r.GetChannelsMvcc() {
t.queryChannelsTs[ch] = ts
}
@ -416,6 +419,7 @@ func (t *hybridSearchTask) PostExecute(ctx context.Context) error {
t.result, err = rankSearchResultData(ctx, 1,
t.rankParams,
primaryFieldSchema.GetDataType(),
metricType,
t.multipleRecallResults.Collect())
if err != nil {
log.Warn("rank search result failed", zap.Error(err))
@ -468,6 +472,7 @@ func rankSearchResultData(ctx context.Context,
nq int64,
params *rankParams,
pkType schemapb.DataType,
metricType string,
searchResults []*milvuspb.SearchResults,
) (*milvuspb.SearchResults, error) {
tr := timerecord.NewTimeRecorder("rankSearchResultData")
@ -483,7 +488,8 @@ func rankSearchResultData(ctx context.Context,
zap.Int("len(searchResults)", len(searchResults)),
zap.Int64("nq", nq),
zap.Int64("offset", offset),
zap.Int64("limit", limit))
zap.Int64("limit", limit),
zap.String("metric type", metricType))
ret := &milvuspb.SearchResults{
Status: merr.Success(),
@ -546,9 +552,18 @@ func rankSearchResultData(ctx context.Context,
}
// sort id by score
sort.Slice(keys, func(i, j int) bool {
return idSet[keys[i]] >= idSet[keys[j]]
})
var less func(i, j int) bool
if metric.PositivelyRelated(metricType) {
less = func(i, j int) bool {
return idSet[keys[i]] > idSet[keys[j]]
}
} else {
less = func(i, j int) bool {
return idSet[keys[i]] < idSet[keys[j]]
}
}
sort.Slice(keys, less)
if int64(len(keys)) > topk {
keys = keys[:topk]