mirror of https://github.com/milvus-io/milvus.git
fix: Add more checks to rank params (#29950)
issue: #29840 #29867 /kind bug Signed-off-by: xige-16 <xi.ge@zilliz.com> Signed-off-by: xige-16 <xi.ge@zilliz.com>pull/29928/head
parent
fa7cf587b0
commit
91aa81b4d7
|
@ -109,10 +109,19 @@ func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValue
|
|||
|
||||
switch rankTypeMap[rankTypeStr] {
|
||||
case rrfRankType:
|
||||
k, ok := params[RRFParamsKey].(float64)
|
||||
_, ok := params[RRFParamsKey]
|
||||
if !ok {
|
||||
return nil, errors.New(RRFParamsKey + " not found in rank_params")
|
||||
}
|
||||
var k float64
|
||||
if reflect.ValueOf(params[RRFParamsKey]).CanFloat() {
|
||||
k = reflect.ValueOf(params[RRFParamsKey]).Float()
|
||||
} else {
|
||||
return nil, errors.New("The type of rank param k should be float")
|
||||
}
|
||||
if k <= 0 || k >= maxRRFParamsValue {
|
||||
return nil, errors.New("The rank params k should be in range (0, 16384)")
|
||||
}
|
||||
log.Debug("rrf params", zap.Float64("k", k))
|
||||
for i := range reqs {
|
||||
res[i] = &rrfScorer{
|
||||
|
@ -131,7 +140,16 @@ func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValue
|
|||
case reflect.Slice:
|
||||
rs := reflect.ValueOf(params[WeightsParamsKey])
|
||||
for i := 0; i < rs.Len(); i++ {
|
||||
weights = append(weights, float32(rs.Index(i).Interface().(float64)))
|
||||
v := rs.Index(i).Elem()
|
||||
if v.CanFloat() {
|
||||
weight := v.Float()
|
||||
if weight < 0 || weight > 1 {
|
||||
return nil, errors.New("rank param weight should be in range [0, 1]")
|
||||
}
|
||||
weights = append(weights, float32(weight))
|
||||
} else {
|
||||
return nil, errors.New("The type of rank param weight should be float")
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("The weights param should be an array")
|
||||
|
|
|
@ -18,6 +18,45 @@ func TestRescorer(t *testing.T) {
|
|||
assert.Equal(t, rrfRankType, rescorers[0].scorerType())
|
||||
})
|
||||
|
||||
t.Run("rrf without param", func(t *testing.T) {
|
||||
params := make(map[string]float64)
|
||||
b, err := json.Marshal(params)
|
||||
assert.NoError(t, err)
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
{Key: RankTypeKey, Value: "rrf"},
|
||||
{Key: RankParamsKey, Value: string(b)},
|
||||
}
|
||||
|
||||
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "k not found in rank_params")
|
||||
})
|
||||
|
||||
t.Run("rrf param out of range", func(t *testing.T) {
|
||||
params := make(map[string]float64)
|
||||
params[RRFParamsKey] = -1
|
||||
b, err := json.Marshal(params)
|
||||
assert.NoError(t, err)
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
{Key: RankTypeKey, Value: "rrf"},
|
||||
{Key: RankParamsKey, Value: string(b)},
|
||||
}
|
||||
|
||||
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
||||
assert.Error(t, err)
|
||||
|
||||
params[RRFParamsKey] = maxRRFParamsValue + 1
|
||||
b, err = json.Marshal(params)
|
||||
assert.NoError(t, err)
|
||||
rankParams = []*commonpb.KeyValuePair{
|
||||
{Key: RankTypeKey, Value: "rrf"},
|
||||
{Key: RankParamsKey, Value: string(b)},
|
||||
}
|
||||
|
||||
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("rrf", func(t *testing.T) {
|
||||
params := make(map[string]float64)
|
||||
params[RRFParamsKey] = 61
|
||||
|
@ -35,6 +74,36 @@ func TestRescorer(t *testing.T) {
|
|||
assert.Equal(t, float32(61), rescorers[0].(*rrfScorer).k)
|
||||
})
|
||||
|
||||
t.Run("weights without param", func(t *testing.T) {
|
||||
params := make(map[string][]float64)
|
||||
b, err := json.Marshal(params)
|
||||
assert.NoError(t, err)
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
{Key: RankTypeKey, Value: "weighted"},
|
||||
{Key: RankParamsKey, Value: string(b)},
|
||||
}
|
||||
|
||||
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found in rank_params")
|
||||
})
|
||||
|
||||
t.Run("weights out of range", func(t *testing.T) {
|
||||
weights := []float64{1.2, 2.3}
|
||||
params := make(map[string][]float64)
|
||||
params[WeightsParamsKey] = weights
|
||||
b, err := json.Marshal(params)
|
||||
assert.NoError(t, err)
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
{Key: RankTypeKey, Value: "weighted"},
|
||||
{Key: RankParamsKey, Value: string(b)},
|
||||
}
|
||||
|
||||
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "rank param weight should be in range [0, 1]")
|
||||
})
|
||||
|
||||
t.Run("weights", func(t *testing.T) {
|
||||
weights := []float64{0.5, 0.2}
|
||||
params := make(map[string][]float64)
|
||||
|
|
|
@ -73,6 +73,7 @@ const (
|
|||
InvertedIndexType = "INVERTED"
|
||||
|
||||
defaultRRFParamsValue = 60
|
||||
maxRRFParamsValue = 16384
|
||||
)
|
||||
|
||||
var logger = log.L().WithOptions(zap.Fields(zap.String("role", typeutil.ProxyRole)))
|
||||
|
|
Loading…
Reference in New Issue