enhance: [2.5] weighted reranker to allow skip score normalization (#40905)

issue: https://github.com/milvus-io/milvus/issues/40836
pr: https://github.com/milvus-io/milvus/pull/40903

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
pull/41071/head
Buqian Zheng 2025-04-02 17:30:22 +08:00 committed by GitHub
parent a906466d8f
commit c9a354d436
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 39 additions and 4 deletions

View File

@ -77,12 +77,18 @@ func (rs *rrfScorer) scorerType() rankType {
type weightedScorer struct {
baseScorer
weight float32
weight float32
normScore bool
}
type activateFunc func(float32) float32
func (ws *weightedScorer) getActivateFunc() activateFunc {
if !ws.normScore {
return func(distance float32) float32 {
return distance
}
}
mUpper := strings.ToUpper(ws.getMetricType())
isCosine := mUpper == strings.ToUpper(metric.COSINE)
isIP := mUpper == strings.ToUpper(metric.IP)
@ -190,6 +196,11 @@ func NewReScorers(ctx context.Context, reqCnt int, rankParams []*commonpb.KeyVal
if _, ok := params[WeightsParamsKey]; !ok {
return nil, errors.New(WeightsParamsKey + " not found in rank_params")
}
// normalize scores by default
normScore := true
if _, ok := params[NormScoreKey]; ok {
normScore = params[NormScoreKey].(bool)
}
weights := make([]float32, 0)
switch reflect.TypeOf(params[WeightsParamsKey]).Kind() {
case reflect.Slice:
@ -210,7 +221,7 @@ func NewReScorers(ctx context.Context, reqCnt int, rankParams []*commonpb.KeyVal
return nil, errors.New("The weights param should be an array")
}
log.Debug("weights params", zap.Any("weights", weights))
log.Debug("weights params", zap.Any("weights", weights), zap.Bool("norm_score", normScore))
if reqCnt != len(weights) {
return nil, merr.WrapErrParameterInvalid(fmt.Sprint(reqCnt), fmt.Sprint(len(weights)), "the length of weights param mismatch with ann search requests")
}
@ -219,7 +230,8 @@ func NewReScorers(ctx context.Context, reqCnt int, rankParams []*commonpb.KeyVal
baseScorer: baseScorer{
scorerName: "weighted",
},
weight: weights[i],
weight: weights[i],
normScore: normScore,
}
}
default:

View File

@ -104,9 +104,29 @@ func TestRescorer(t *testing.T) {
assert.Contains(t, err.Error(), "rank param weight should be in range [0, 1]")
})
t.Run("weights with norm_score false", func(t *testing.T) {
weights := []float64{0.5, 0.2}
params := make(map[string]interface{})
params[WeightsParamsKey] = weights
params[NormScoreKey] = false
b, err := json.Marshal(params)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "weighted"},
{Key: RankParamsKey, Value: string(b)},
}
rescorers, err := NewReScorers(context.TODO(), 2, rankParams)
assert.NoError(t, err)
assert.Equal(t, 2, len(rescorers))
assert.Equal(t, weightedRankType, rescorers[0].scorerType())
assert.Equal(t, float32(weights[0]), rescorers[0].(*weightedScorer).weight)
assert.False(t, rescorers[0].(*weightedScorer).normScore)
})
t.Run("weights", func(t *testing.T) {
weights := []float64{0.5, 0.2}
params := make(map[string][]float64)
params := make(map[string]interface{})
params[WeightsParamsKey] = weights
b, err := json.Marshal(params)
assert.NoError(t, err)
@ -120,5 +140,7 @@ func TestRescorer(t *testing.T) {
assert.Equal(t, 2, len(rescorers))
assert.Equal(t, weightedRankType, rescorers[0].scorerType())
assert.Equal(t, float32(weights[0]), rescorers[0].(*weightedScorer).weight)
// normalize scores by default
assert.True(t, rescorers[0].(*weightedScorer).normScore)
})
}

View File

@ -120,6 +120,7 @@ const (
RankParamsKey = "params"
RRFParamsKey = "k"
WeightsParamsKey = "weights"
NormScoreKey = "norm_score"
)
type task interface {