mirror of https://github.com/milvus-io/milvus.git
fix: reorder hybridsearch result for L2 metric (#30739)
#30694 Signed-off-by: luzhang <luzhang@zilliz.com> Co-authored-by: luzhang <luzhang@zilliz.com>pull/30787/head
parent
ece9d273a7
commit
a0531b72aa
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
|
@ -395,8 +396,10 @@ func (t *hybridSearchTask) PostExecute(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
metricType := ""
|
||||
t.queryChannelsTs = make(map[string]uint64)
|
||||
for _, r := range t.resultBuf.Collect() {
|
||||
metricType = r.GetResults()[0].GetMetricType()
|
||||
for ch, ts := range r.GetChannelsMvcc() {
|
||||
t.queryChannelsTs[ch] = ts
|
||||
}
|
||||
|
@ -416,6 +419,7 @@ func (t *hybridSearchTask) PostExecute(ctx context.Context) error {
|
|||
t.result, err = rankSearchResultData(ctx, 1,
|
||||
t.rankParams,
|
||||
primaryFieldSchema.GetDataType(),
|
||||
metricType,
|
||||
t.multipleRecallResults.Collect())
|
||||
if err != nil {
|
||||
log.Warn("rank search result failed", zap.Error(err))
|
||||
|
@ -468,6 +472,7 @@ func rankSearchResultData(ctx context.Context,
|
|||
nq int64,
|
||||
params *rankParams,
|
||||
pkType schemapb.DataType,
|
||||
metricType string,
|
||||
searchResults []*milvuspb.SearchResults,
|
||||
) (*milvuspb.SearchResults, error) {
|
||||
tr := timerecord.NewTimeRecorder("rankSearchResultData")
|
||||
|
@ -483,7 +488,8 @@ func rankSearchResultData(ctx context.Context,
|
|||
zap.Int("len(searchResults)", len(searchResults)),
|
||||
zap.Int64("nq", nq),
|
||||
zap.Int64("offset", offset),
|
||||
zap.Int64("limit", limit))
|
||||
zap.Int64("limit", limit),
|
||||
zap.String("metric type", metricType))
|
||||
|
||||
ret := &milvuspb.SearchResults{
|
||||
Status: merr.Success(),
|
||||
|
@ -546,9 +552,18 @@ func rankSearchResultData(ctx context.Context,
|
|||
}
|
||||
|
||||
// sort id by score
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
return idSet[keys[i]] >= idSet[keys[j]]
|
||||
})
|
||||
var less func(i, j int) bool
|
||||
if metric.PositivelyRelated(metricType) {
|
||||
less = func(i, j int) bool {
|
||||
return idSet[keys[i]] > idSet[keys[j]]
|
||||
}
|
||||
} else {
|
||||
less = func(i, j int) bool {
|
||||
return idSet[keys[i]] < idSet[keys[j]]
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(keys, less)
|
||||
|
||||
if int64(len(keys)) > topk {
|
||||
keys = keys[:topk]
|
||||
|
|
Loading…
Reference in New Issue