fix: add checks for bm25 k1 and b in sparse index checker (#36907)

issue: https://github.com/milvus-io/milvus/issues/36883,
https://github.com/milvus-io/milvus/issues/35853

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
pull/35970/head
Buqian Zheng 2024-10-16 19:43:24 +08:00 committed by GitHub
parent e5948bd039
commit 51f13ba7cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 108 additions and 0 deletions

View File

@ -45,6 +45,9 @@ const (
// Sparse Index Param
SparseDropRatioBuild = "drop_ratio_build"
BM25K1 = "bm25_k1"
BM25B = "bm25_b"
MaxBitmapCardinalityLimit = 1000
)

View File

@ -29,6 +29,22 @@ func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error
}
}
bm25K1Str, exist := params[BM25K1]
if exist {
bm25K1, err := strconv.ParseFloat(bm25K1Str, 64)
if err != nil || bm25K1 < 0 || bm25K1 > 3 {
return fmt.Errorf("invalid bm25_k1: %s, must be in range [0, 3]", bm25K1Str)
}
}
bm25BStr, exist := params[BM25B]
if exist {
bm25B, err := strconv.ParseFloat(bm25BStr, 64)
if err != nil || bm25B < 0 || bm25B > 1 {
return fmt.Errorf("invalid bm25_b: %s, must be in range [0, 1]", bm25BStr)
}
}
return nil
}

View File

@ -0,0 +1,89 @@
package indexparamcheck
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func Test_sparseFloatVectorBaseChecker_StaticCheck(t *testing.T) {
validParams := map[string]string{
Metric: "IP",
}
invalidParams := map[string]string{
Metric: "L2",
}
c := newSparseFloatVectorBaseChecker()
t.Run("valid metric", func(t *testing.T) {
err := c.StaticCheck(validParams)
assert.NoError(t, err)
})
t.Run("invalid metric", func(t *testing.T) {
err := c.StaticCheck(invalidParams)
assert.Error(t, err)
})
}
func Test_sparseFloatVectorBaseChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
SparseDropRatioBuild: "0.5",
BM25K1: "1.5",
BM25B: "0.5",
}
invalidDropRatio := map[string]string{
SparseDropRatioBuild: "1.5",
}
invalidBM25K1 := map[string]string{
BM25K1: "3.5",
}
invalidBM25B := map[string]string{
BM25B: "1.5",
}
c := newSparseFloatVectorBaseChecker()
t.Run("valid params", func(t *testing.T) {
err := c.CheckTrain(validParams)
assert.NoError(t, err)
})
t.Run("invalid drop ratio", func(t *testing.T) {
err := c.CheckTrain(invalidDropRatio)
assert.Error(t, err)
})
t.Run("invalid BM25K1", func(t *testing.T) {
err := c.CheckTrain(invalidBM25K1)
assert.Error(t, err)
})
t.Run("invalid BM25B", func(t *testing.T) {
err := c.CheckTrain(invalidBM25B)
assert.Error(t, err)
})
}
// Test_sparseFloatVectorBaseChecker_CheckValidDataType asserts that
// CheckValidDataType returns no error for a SparseFloatVector field and
// returns an error for a FloatVector field.
func Test_sparseFloatVectorBaseChecker_CheckValidDataType(t *testing.T) {
	checker := newSparseFloatVectorBaseChecker()

	t.Run("valid data type", func(t *testing.T) {
		sparseField := &schemapb.FieldSchema{DataType: schemapb.DataType_SparseFloatVector}
		assert.NoError(t, checker.CheckValidDataType(sparseField))
	})

	t.Run("invalid data type", func(t *testing.T) {
		denseField := &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}
		assert.Error(t, checker.CheckValidDataType(denseField))
	})
}