fix: Add more checks to rank params (#29950)

issue: #29840 #29867
/kind bug

Signed-off-by: xige-16 <xi.ge@zilliz.com>

Signed-off-by: xige-16 <xi.ge@zilliz.com>
pull/29928/head
xige-16 2024-01-17 20:28:58 +08:00 committed by GitHub
parent fa7cf587b0
commit 91aa81b4d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 90 additions and 2 deletions

View File

@ -109,10 +109,19 @@ func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValue
switch rankTypeMap[rankTypeStr] {
case rrfRankType:
k, ok := params[RRFParamsKey].(float64)
_, ok := params[RRFParamsKey]
if !ok {
return nil, errors.New(RRFParamsKey + " not found in rank_params")
}
var k float64
if reflect.ValueOf(params[RRFParamsKey]).CanFloat() {
k = reflect.ValueOf(params[RRFParamsKey]).Float()
} else {
return nil, errors.New("The type of rank param k should be float")
}
if k <= 0 || k >= maxRRFParamsValue {
return nil, errors.New("The rank params k should be in range (0, 16384)")
}
log.Debug("rrf params", zap.Float64("k", k))
for i := range reqs {
res[i] = &rrfScorer{
@ -131,7 +140,16 @@ func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValue
case reflect.Slice:
rs := reflect.ValueOf(params[WeightsParamsKey])
for i := 0; i < rs.Len(); i++ {
weights = append(weights, float32(rs.Index(i).Interface().(float64)))
v := rs.Index(i).Elem()
if v.CanFloat() {
weight := v.Float()
if weight < 0 || weight > 1 {
return nil, errors.New("rank param weight should be in range [0, 1]")
}
weights = append(weights, float32(weight))
} else {
return nil, errors.New("The type of rank param weight should be float")
}
}
default:
return nil, errors.New("The weights param should be an array")

View File

@ -18,6 +18,45 @@ func TestRescorer(t *testing.T) {
assert.Equal(t, rrfRankType, rescorers[0].scorerType())
})
t.Run("rrf without param", func(t *testing.T) {
params := make(map[string]float64)
b, err := json.Marshal(params)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "rrf"},
{Key: RankParamsKey, Value: string(b)},
}
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
assert.Error(t, err)
assert.Contains(t, err.Error(), "k not found in rank_params")
})
t.Run("rrf param out of range", func(t *testing.T) {
params := make(map[string]float64)
params[RRFParamsKey] = -1
b, err := json.Marshal(params)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "rrf"},
{Key: RankParamsKey, Value: string(b)},
}
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
assert.Error(t, err)
params[RRFParamsKey] = maxRRFParamsValue + 1
b, err = json.Marshal(params)
assert.NoError(t, err)
rankParams = []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "rrf"},
{Key: RankParamsKey, Value: string(b)},
}
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
assert.Error(t, err)
})
t.Run("rrf", func(t *testing.T) {
params := make(map[string]float64)
params[RRFParamsKey] = 61
@ -35,6 +74,36 @@ func TestRescorer(t *testing.T) {
assert.Equal(t, float32(61), rescorers[0].(*rrfScorer).k)
})
t.Run("weights without param", func(t *testing.T) {
params := make(map[string][]float64)
b, err := json.Marshal(params)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "weighted"},
{Key: RankParamsKey, Value: string(b)},
}
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
assert.Error(t, err)
assert.Contains(t, err.Error(), "not found in rank_params")
})
t.Run("weights out of range", func(t *testing.T) {
weights := []float64{1.2, 2.3}
params := make(map[string][]float64)
params[WeightsParamsKey] = weights
b, err := json.Marshal(params)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "weighted"},
{Key: RankParamsKey, Value: string(b)},
}
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
assert.Error(t, err)
assert.Contains(t, err.Error(), "rank param weight should be in range [0, 1]")
})
t.Run("weights", func(t *testing.T) {
weights := []float64{0.5, 0.2}
params := make(map[string][]float64)

View File

@ -73,6 +73,7 @@ const (
InvertedIndexType = "INVERTED"
defaultRRFParamsValue = 60
maxRRFParamsValue = 16384
)
var logger = log.L().WithOptions(zap.Fields(zap.String("role", typeutil.ProxyRole)))