mirror of https://github.com/milvus-io/milvus.git
enhance: [2.5] weighted reranker to allow skip score normalization (#40905)
issue: https://github.com/milvus-io/milvus/issues/40836 pr: https://github.com/milvus-io/milvus/pull/40903 Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>pull/41071/head
parent
a906466d8f
commit
c9a354d436
|
@ -77,12 +77,18 @@ func (rs *rrfScorer) scorerType() rankType {
|
|||
|
||||
type weightedScorer struct {
|
||||
baseScorer
|
||||
weight float32
|
||||
weight float32
|
||||
normScore bool
|
||||
}
|
||||
|
||||
type activateFunc func(float32) float32
|
||||
|
||||
func (ws *weightedScorer) getActivateFunc() activateFunc {
|
||||
if !ws.normScore {
|
||||
return func(distance float32) float32 {
|
||||
return distance
|
||||
}
|
||||
}
|
||||
mUpper := strings.ToUpper(ws.getMetricType())
|
||||
isCosine := mUpper == strings.ToUpper(metric.COSINE)
|
||||
isIP := mUpper == strings.ToUpper(metric.IP)
|
||||
|
@ -190,6 +196,11 @@ func NewReScorers(ctx context.Context, reqCnt int, rankParams []*commonpb.KeyVal
|
|||
if _, ok := params[WeightsParamsKey]; !ok {
|
||||
return nil, errors.New(WeightsParamsKey + " not found in rank_params")
|
||||
}
|
||||
// normalize scores by default
|
||||
normScore := true
|
||||
if _, ok := params[NormScoreKey]; ok {
|
||||
normScore = params[NormScoreKey].(bool)
|
||||
}
|
||||
weights := make([]float32, 0)
|
||||
switch reflect.TypeOf(params[WeightsParamsKey]).Kind() {
|
||||
case reflect.Slice:
|
||||
|
@ -210,7 +221,7 @@ func NewReScorers(ctx context.Context, reqCnt int, rankParams []*commonpb.KeyVal
|
|||
return nil, errors.New("The weights param should be an array")
|
||||
}
|
||||
|
||||
log.Debug("weights params", zap.Any("weights", weights))
|
||||
log.Debug("weights params", zap.Any("weights", weights), zap.Bool("norm_score", normScore))
|
||||
if reqCnt != len(weights) {
|
||||
return nil, merr.WrapErrParameterInvalid(fmt.Sprint(reqCnt), fmt.Sprint(len(weights)), "the length of weights param mismatch with ann search requests")
|
||||
}
|
||||
|
@ -219,7 +230,8 @@ func NewReScorers(ctx context.Context, reqCnt int, rankParams []*commonpb.KeyVal
|
|||
baseScorer: baseScorer{
|
||||
scorerName: "weighted",
|
||||
},
|
||||
weight: weights[i],
|
||||
weight: weights[i],
|
||||
normScore: normScore,
|
||||
}
|
||||
}
|
||||
default:
|
||||
|
|
|
@ -104,9 +104,29 @@ func TestRescorer(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), "rank param weight should be in range [0, 1]")
|
||||
})
|
||||
|
||||
t.Run("weights with norm_score false", func(t *testing.T) {
|
||||
weights := []float64{0.5, 0.2}
|
||||
params := make(map[string]interface{})
|
||||
params[WeightsParamsKey] = weights
|
||||
params[NormScoreKey] = false
|
||||
b, err := json.Marshal(params)
|
||||
assert.NoError(t, err)
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
{Key: RankTypeKey, Value: "weighted"},
|
||||
{Key: RankParamsKey, Value: string(b)},
|
||||
}
|
||||
|
||||
rescorers, err := NewReScorers(context.TODO(), 2, rankParams)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(rescorers))
|
||||
assert.Equal(t, weightedRankType, rescorers[0].scorerType())
|
||||
assert.Equal(t, float32(weights[0]), rescorers[0].(*weightedScorer).weight)
|
||||
assert.False(t, rescorers[0].(*weightedScorer).normScore)
|
||||
})
|
||||
|
||||
t.Run("weights", func(t *testing.T) {
|
||||
weights := []float64{0.5, 0.2}
|
||||
params := make(map[string][]float64)
|
||||
params := make(map[string]interface{})
|
||||
params[WeightsParamsKey] = weights
|
||||
b, err := json.Marshal(params)
|
||||
assert.NoError(t, err)
|
||||
|
@ -120,5 +140,7 @@ func TestRescorer(t *testing.T) {
|
|||
assert.Equal(t, 2, len(rescorers))
|
||||
assert.Equal(t, weightedRankType, rescorers[0].scorerType())
|
||||
assert.Equal(t, float32(weights[0]), rescorers[0].(*weightedScorer).weight)
|
||||
// normalize scores by default
|
||||
assert.True(t, rescorers[0].(*weightedScorer).normScore)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -120,6 +120,7 @@ const (
|
|||
RankParamsKey = "params"
|
||||
RRFParamsKey = "k"
|
||||
WeightsParamsKey = "weights"
|
||||
NormScoreKey = "norm_score"
|
||||
)
|
||||
|
||||
type task interface {
|
||||
|
|
Loading…
Reference in New Issue