mirror of https://github.com/milvus-io/milvus.git
fix: add checks for bm25 k1 and b in sparse index checker (#36907)
issue: https://github.com/milvus-io/milvus/issues/36883, https://github.com/milvus-io/milvus/issues/35853 Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>pull/35970/head
parent
e5948bd039
commit
51f13ba7cf
|
@ -45,6 +45,9 @@ const (
|
|||
// Sparse Index Param
|
||||
SparseDropRatioBuild = "drop_ratio_build"
|
||||
|
||||
BM25K1 = "bm25_k1"
|
||||
BM25B = "bm25_b"
|
||||
|
||||
MaxBitmapCardinalityLimit = 1000
|
||||
)
|
||||
|
||||
|
|
|
@ -29,6 +29,22 @@ func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error
|
|||
}
|
||||
}
|
||||
|
||||
bm25K1Str, exist := params[BM25K1]
|
||||
if exist {
|
||||
bm25K1, err := strconv.ParseFloat(bm25K1Str, 64)
|
||||
if err != nil || bm25K1 < 0 || bm25K1 > 3 {
|
||||
return fmt.Errorf("invalid bm25_k1: %s, must be in range [0, 3]", bm25K1Str)
|
||||
}
|
||||
}
|
||||
|
||||
bm25BStr, exist := params[BM25B]
|
||||
if exist {
|
||||
bm25B, err := strconv.ParseFloat(bm25BStr, 64)
|
||||
if err != nil || bm25B < 0 || bm25B > 1 {
|
||||
return fmt.Errorf("invalid bm25_b: %s, must be in range [0, 1]", bm25BStr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func Test_sparseFloatVectorBaseChecker_StaticCheck(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
Metric: "IP",
|
||||
}
|
||||
|
||||
invalidParams := map[string]string{
|
||||
Metric: "L2",
|
||||
}
|
||||
|
||||
c := newSparseFloatVectorBaseChecker()
|
||||
|
||||
t.Run("valid metric", func(t *testing.T) {
|
||||
err := c.StaticCheck(validParams)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid metric", func(t *testing.T) {
|
||||
err := c.StaticCheck(invalidParams)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_sparseFloatVectorBaseChecker_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
SparseDropRatioBuild: "0.5",
|
||||
BM25K1: "1.5",
|
||||
BM25B: "0.5",
|
||||
}
|
||||
|
||||
invalidDropRatio := map[string]string{
|
||||
SparseDropRatioBuild: "1.5",
|
||||
}
|
||||
|
||||
invalidBM25K1 := map[string]string{
|
||||
BM25K1: "3.5",
|
||||
}
|
||||
|
||||
invalidBM25B := map[string]string{
|
||||
BM25B: "1.5",
|
||||
}
|
||||
|
||||
c := newSparseFloatVectorBaseChecker()
|
||||
|
||||
t.Run("valid params", func(t *testing.T) {
|
||||
err := c.CheckTrain(validParams)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid drop ratio", func(t *testing.T) {
|
||||
err := c.CheckTrain(invalidDropRatio)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid BM25K1", func(t *testing.T) {
|
||||
err := c.CheckTrain(invalidBM25K1)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid BM25B", func(t *testing.T) {
|
||||
err := c.CheckTrain(invalidBM25B)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_sparseFloatVectorBaseChecker_CheckValidDataType(t *testing.T) {
|
||||
c := newSparseFloatVectorBaseChecker()
|
||||
|
||||
t.Run("valid data type", func(t *testing.T) {
|
||||
field := &schemapb.FieldSchema{DataType: schemapb.DataType_SparseFloatVector}
|
||||
err := c.CheckValidDataType(field)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid data type", func(t *testing.T) {
|
||||
field := &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}
|
||||
err := c.CheckValidDataType(field)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
Loading…
Reference in New Issue