mirror of https://github.com/milvus-io/milvus.git
feat: added more checks for function creation check (#36766)
issue: https://github.com/milvus-io/milvus/issues/35853 * BM25 Function now takes no params, k1, b should be passed via index params * support BM25 full text search when metric type is not present in search request * add more strict validation with functions at collection creation time Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>pull/36790/head
parent
16b533cbf0
commit
383350c120
|
@ -136,24 +136,11 @@ func (cit *createIndexTask) parseFunctionParamsToIndex(indexParamsMap map[string
|
|||
}
|
||||
|
||||
switch cit.functionSchema.GetType() {
|
||||
case schemapb.FunctionType_Unknown:
|
||||
return fmt.Errorf("unknown function type encountered")
|
||||
|
||||
case schemapb.FunctionType_BM25:
|
||||
for _, kv := range cit.functionSchema.GetParams() {
|
||||
switch kv.GetKey() {
|
||||
case "bm25_k1":
|
||||
if _, ok := indexParamsMap["bm25_k1"]; !ok {
|
||||
indexParamsMap["bm25_k1"] = kv.GetValue()
|
||||
}
|
||||
case "bm25_b":
|
||||
if _, ok := indexParamsMap["bm25_b"]; !ok {
|
||||
indexParamsMap["bm25_b"] = kv.GetValue()
|
||||
}
|
||||
case "bm25_avgdl":
|
||||
if _, ok := indexParamsMap["bm25_avgdl"]; !ok {
|
||||
indexParamsMap["bm25_avgdl"] = kv.GetValue()
|
||||
}
|
||||
}
|
||||
}
|
||||
// set default avgdl
|
||||
// set default BM25 params if not provided in index params
|
||||
if _, ok := indexParamsMap["bm25_k1"]; !ok {
|
||||
indexParamsMap["bm25_k1"] = "1.2"
|
||||
}
|
||||
|
@ -165,8 +152,15 @@ func (cit *createIndexTask) parseFunctionParamsToIndex(indexParamsMap map[string
|
|||
if _, ok := indexParamsMap["bm25_avgdl"]; !ok {
|
||||
indexParamsMap["bm25_avgdl"] = "100"
|
||||
}
|
||||
|
||||
if metricType, ok := indexParamsMap["metric_type"]; !ok {
|
||||
indexParamsMap["metric_type"] = "BM25"
|
||||
} else if metricType != "BM25" {
|
||||
return fmt.Errorf("index metric type of BM25 function output field must be BM25, got %s", metricType)
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("parse unknown type function params to index")
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -192,11 +186,6 @@ func (cit *createIndexTask) parseIndexParams() error {
|
|||
}
|
||||
}
|
||||
|
||||
// fill index param for bm25 function
|
||||
if err := cit.parseFunctionParamsToIndex(indexParamsMap); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ValidateAutoIndexMmapConfig(isVecIndex, indexParamsMap); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -342,6 +331,11 @@ func (cit *createIndexTask) parseIndexParams() error {
|
|||
}
|
||||
}
|
||||
|
||||
// fill index param for Functions
|
||||
if err := cit.parseFunctionParamsToIndex(indexParamsMap); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
indexType, exist := indexParamsMap[common.IndexTypeKey]
|
||||
if !exist {
|
||||
return fmt.Errorf("IndexType not specified")
|
||||
|
|
|
@ -620,10 +620,14 @@ func validateFunction(coll *schemapb.CollectionSchema) error {
|
|||
})
|
||||
usedOutputField := typeutil.NewSet[string]()
|
||||
usedFunctionName := typeutil.NewSet[string]()
|
||||
// validate function
|
||||
|
||||
for _, function := range coll.GetFunctions() {
|
||||
if err := checkFunctionBasicParams(function); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if usedFunctionName.Contain(function.GetName()) {
|
||||
return fmt.Errorf("duplicate function name %s", function.GetName())
|
||||
return fmt.Errorf("duplicate function name: %s", function.GetName())
|
||||
}
|
||||
|
||||
usedFunctionName.Insert(function.GetName())
|
||||
|
@ -631,13 +635,15 @@ func validateFunction(coll *schemapb.CollectionSchema) error {
|
|||
for _, name := range function.GetInputFieldNames() {
|
||||
inputField, ok := nameMap[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("function input field not found %s", function.InputFieldNames)
|
||||
return fmt.Errorf("function input field not found: %s", name)
|
||||
}
|
||||
if inputField.GetNullable() {
|
||||
return fmt.Errorf("function input field cannot be nullable: function %s, field %s", function.GetName(), inputField.GetName())
|
||||
}
|
||||
inputFields = append(inputFields, inputField)
|
||||
}
|
||||
|
||||
err := checkFunctionInputField(function, inputFields)
|
||||
if err != nil {
|
||||
if err := checkFunctionInputField(function, inputFields); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -645,12 +651,25 @@ func validateFunction(coll *schemapb.CollectionSchema) error {
|
|||
for i, name := range function.GetOutputFieldNames() {
|
||||
outputField, ok := nameMap[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("function output field not found %s", function.InputFieldNames)
|
||||
return fmt.Errorf("function output field not found: %s", name)
|
||||
}
|
||||
|
||||
if outputField.GetIsPrimaryKey() {
|
||||
return fmt.Errorf("function output field cannot be primary key: function %s, field %s", function.GetName(), outputField.GetName())
|
||||
}
|
||||
|
||||
if outputField.GetIsPartitionKey() || outputField.GetIsClusteringKey() {
|
||||
return fmt.Errorf("function output field cannot be partition key or clustering key: function %s, field %s", function.GetName(), outputField.GetName())
|
||||
}
|
||||
|
||||
if outputField.GetNullable() {
|
||||
return fmt.Errorf("function output field cannot be nullable: function %s, field %s", function.GetName(), outputField.GetName())
|
||||
}
|
||||
|
||||
outputField.IsFunctionOutput = true
|
||||
outputFields[i] = outputField
|
||||
if usedOutputField.Contain(name) {
|
||||
return fmt.Errorf("duplicate function output %s", name)
|
||||
return fmt.Errorf("duplicate function output field: function %s, field %s", function.GetName(), name)
|
||||
}
|
||||
usedOutputField.Insert(name)
|
||||
}
|
||||
|
@ -658,10 +677,6 @@ func validateFunction(coll *schemapb.CollectionSchema) error {
|
|||
if err := checkFunctionOutputField(function, outputFields); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := checkFunctionParams(function); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -670,19 +685,11 @@ func checkFunctionOutputField(function *schemapb.FunctionSchema, fields []*schem
|
|||
switch function.GetType() {
|
||||
case schemapb.FunctionType_BM25:
|
||||
if len(fields) != 1 {
|
||||
return fmt.Errorf("bm25 only need 1 output field, but now %d", len(fields))
|
||||
return fmt.Errorf("BM25 function only need 1 output field, but got %d", len(fields))
|
||||
}
|
||||
|
||||
if !typeutil.IsSparseFloatVectorType(fields[0].GetDataType()) {
|
||||
return fmt.Errorf("bm25 only need sparse embedding output field, but now %s", fields[0].DataType.String())
|
||||
}
|
||||
|
||||
if fields[0].GetIsPrimaryKey() {
|
||||
return fmt.Errorf("bm25 output field can't be primary key")
|
||||
}
|
||||
|
||||
if fields[0].GetIsPartitionKey() || fields[0].GetIsClusteringKey() {
|
||||
return fmt.Errorf("bm25 output field can't be partition key or cluster key field")
|
||||
return fmt.Errorf("BM25 function output field must be a SparseFloatVector field, but got %s", fields[0].DataType.String())
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("check output field for unknown function type")
|
||||
|
@ -694,12 +701,12 @@ func checkFunctionInputField(function *schemapb.FunctionSchema, fields []*schema
|
|||
switch function.GetType() {
|
||||
case schemapb.FunctionType_BM25:
|
||||
if len(fields) != 1 || fields[0].DataType != schemapb.DataType_VarChar {
|
||||
return fmt.Errorf("only one VARCHAR input field is allowed for a BM25 Function, got %d field with type %s",
|
||||
return fmt.Errorf("BM25 function input field must be a VARCHAR field, got %d field with type %s",
|
||||
len(fields), fields[0].DataType.String())
|
||||
}
|
||||
h := typeutil.CreateFieldSchemaHelper(fields[0])
|
||||
if !h.EnableTokenizer() {
|
||||
return fmt.Errorf("BM25 input field must set enable_tokenizer to true")
|
||||
return fmt.Errorf("BM25 function input field must set enable_tokenizer to true")
|
||||
}
|
||||
|
||||
default:
|
||||
|
@ -708,46 +715,40 @@ func checkFunctionInputField(function *schemapb.FunctionSchema, fields []*schema
|
|||
return nil
|
||||
}
|
||||
|
||||
func checkFunctionParams(function *schemapb.FunctionSchema) error {
|
||||
func checkFunctionBasicParams(function *schemapb.FunctionSchema) error {
|
||||
if function.GetName() == "" {
|
||||
return fmt.Errorf("function name cannot be empty")
|
||||
}
|
||||
if len(function.GetInputFieldNames()) == 0 {
|
||||
return fmt.Errorf("function input field names cannot be empty, function: %s", function.GetName())
|
||||
}
|
||||
if len(function.GetOutputFieldNames()) == 0 {
|
||||
return fmt.Errorf("function output field names cannot be empty, function: %s", function.GetName())
|
||||
}
|
||||
for _, input := range function.GetInputFieldNames() {
|
||||
if input == "" {
|
||||
return fmt.Errorf("function input field name cannot be empty string, function: %s", function.GetName())
|
||||
}
|
||||
// if input occurs more than once, error
|
||||
if lo.Count(function.GetInputFieldNames(), input) > 1 {
|
||||
return fmt.Errorf("each function input field should be used exactly once in the same function, function: %s, input field: %s", function.GetName(), input)
|
||||
}
|
||||
}
|
||||
for _, output := range function.GetOutputFieldNames() {
|
||||
if output == "" {
|
||||
return fmt.Errorf("function output field name cannot be empty string, function: %s", function.GetName())
|
||||
}
|
||||
if lo.Count(function.GetInputFieldNames(), output) > 0 {
|
||||
return fmt.Errorf("a single field cannot be both input and output in the same function, function: %s, field: %s", function.GetName(), output)
|
||||
}
|
||||
if lo.Count(function.GetOutputFieldNames(), output) > 1 {
|
||||
return fmt.Errorf("each function output field should be used exactly once in the same function, function: %s, output field: %s", function.GetName(), output)
|
||||
}
|
||||
}
|
||||
switch function.GetType() {
|
||||
case schemapb.FunctionType_BM25:
|
||||
for _, kv := range function.GetParams() {
|
||||
switch kv.GetKey() {
|
||||
case "bm25_k1":
|
||||
k1, err := strconv.ParseFloat(kv.GetValue(), 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse bm25_k1 value, %w", err)
|
||||
}
|
||||
|
||||
if k1 < 0 || k1 > 3 {
|
||||
return fmt.Errorf("bm25_k1 must in [0,3] but now %f", k1)
|
||||
}
|
||||
|
||||
case "bm25_b":
|
||||
b, err := strconv.ParseFloat(kv.GetValue(), 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse bm25_b value, %w", err)
|
||||
}
|
||||
|
||||
if b < 0 || b > 1 {
|
||||
return fmt.Errorf("bm25_b must in [0,1] but now %f", b)
|
||||
}
|
||||
|
||||
case "bm25_avgdl":
|
||||
avgdl, err := strconv.ParseFloat(kv.GetValue(), 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse bm25_avgdl value, %w", err)
|
||||
}
|
||||
|
||||
if avgdl <= 0 {
|
||||
return fmt.Errorf("bm25_avgdl must large than zero but now %f", avgdl)
|
||||
}
|
||||
|
||||
case "tokenizer_params":
|
||||
// TODO ADD tokenizer check
|
||||
default:
|
||||
return fmt.Errorf("invalid function params, key: %s, value:%s", kv.GetKey(), kv.GetValue())
|
||||
}
|
||||
if len(function.GetParams()) != 0 {
|
||||
return fmt.Errorf("BM25 function accepts no params")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("check function params with unknown function type")
|
||||
|
|
|
@ -2637,3 +2637,420 @@ func TestValidateLoadFieldsList(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFunction(t *testing.T) {
|
||||
t.Run("Valid function schema", func(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_tokenizer", Value: "true"}}},
|
||||
{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "bm25_func",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input_field"},
|
||||
OutputFieldNames: []string{"output_field"},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := validateFunction(schema)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Invalid function schema - duplicate function names", func(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_tokenizer", Value: "true"}}},
|
||||
{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "bm25_func",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input_field"},
|
||||
OutputFieldNames: []string{"output_field"},
|
||||
},
|
||||
{
|
||||
Name: "bm25_func",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input_field"},
|
||||
OutputFieldNames: []string{"output_field"},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := validateFunction(schema)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "duplicate function name")
|
||||
})
|
||||
|
||||
t.Run("Invalid function schema - input field not found", func(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "bm25_func",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"non_existent_field"},
|
||||
OutputFieldNames: []string{"output_field"},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := validateFunction(schema)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "input field not found")
|
||||
})
|
||||
|
||||
t.Run("Invalid function schema - output field not found", func(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_tokenizer", Value: "true"}}},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "bm25_func",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input_field"},
|
||||
OutputFieldNames: []string{"non_existent_field"},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := validateFunction(schema)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "output field not found")
|
||||
})
|
||||
|
||||
t.Run("Invalid function schema - nullable input field", func(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_tokenizer", Value: "true"}}, Nullable: true},
|
||||
{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "bm25_func",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input_field"},
|
||||
OutputFieldNames: []string{"output_field"},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := validateFunction(schema)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "function input field cannot be nullable")
|
||||
})
|
||||
|
||||
t.Run("Invalid function schema - output field is primary key", func(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_tokenizer", Value: "true"}}},
|
||||
{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector, IsPrimaryKey: true},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "bm25_func",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input_field"},
|
||||
OutputFieldNames: []string{"output_field"},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := validateFunction(schema)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "function output field cannot be primary key")
|
||||
})
|
||||
|
||||
t.Run("Invalid function schema - output field is partition key", func(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_tokenizer", Value: "true"}}},
|
||||
{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector, IsPartitionKey: true},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "bm25_func",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input_field"},
|
||||
OutputFieldNames: []string{"output_field"},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := validateFunction(schema)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "function output field cannot be partition key or clustering key")
|
||||
})
|
||||
|
||||
t.Run("Invalid function schema - output field is clustering key", func(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_tokenizer", Value: "true"}}},
|
||||
{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector, IsClusteringKey: true},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "bm25_func",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input_field"},
|
||||
OutputFieldNames: []string{"output_field"},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := validateFunction(schema)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "function output field cannot be partition key or clustering key")
|
||||
})
|
||||
|
||||
t.Run("Invalid function schema - nullable output field", func(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_tokenizer", Value: "true"}}},
|
||||
{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector, Nullable: true},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "bm25_func",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input_field"},
|
||||
OutputFieldNames: []string{"output_field"},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := validateFunction(schema)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "function output field cannot be nullable")
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateFunctionInputField(t *testing.T) {
|
||||
t.Run("Valid BM25 function input", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
}
|
||||
fields := []*schemapb.FieldSchema{
|
||||
{
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{{Key: "enable_tokenizer", Value: "true"}},
|
||||
},
|
||||
}
|
||||
err := checkFunctionInputField(function, fields)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Invalid BM25 function input - wrong data type", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
}
|
||||
fields := []*schemapb.FieldSchema{
|
||||
{
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
}
|
||||
err := checkFunctionInputField(function, fields)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Invalid BM25 function input - tokenizer not enabled", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
}
|
||||
fields := []*schemapb.FieldSchema{
|
||||
{
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{{Key: "enable_tokenizer", Value: "false"}},
|
||||
},
|
||||
}
|
||||
err := checkFunctionInputField(function, fields)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Invalid BM25 function input - multiple fields", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
}
|
||||
fields := []*schemapb.FieldSchema{
|
||||
{
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{{Key: "enable_tokenizer", Value: "true"}},
|
||||
},
|
||||
{
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{{Key: "enable_tokenizer", Value: "true"}},
|
||||
},
|
||||
}
|
||||
err := checkFunctionInputField(function, fields)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Unknown function type", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Type: schemapb.FunctionType_Unknown,
|
||||
}
|
||||
fields := []*schemapb.FieldSchema{
|
||||
{
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
},
|
||||
}
|
||||
err := checkFunctionInputField(function, fields)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateFunctionOutputField(t *testing.T) {
|
||||
t.Run("Valid BM25 function output", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
}
|
||||
fields := []*schemapb.FieldSchema{
|
||||
{
|
||||
DataType: schemapb.DataType_SparseFloatVector,
|
||||
},
|
||||
}
|
||||
err := checkFunctionOutputField(function, fields)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Invalid BM25 function output - wrong data type", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
}
|
||||
fields := []*schemapb.FieldSchema{
|
||||
{
|
||||
DataType: schemapb.DataType_Float,
|
||||
},
|
||||
}
|
||||
err := checkFunctionOutputField(function, fields)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Invalid BM25 function output - multiple fields", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
}
|
||||
fields := []*schemapb.FieldSchema{
|
||||
{
|
||||
DataType: schemapb.DataType_SparseFloatVector,
|
||||
},
|
||||
{
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
},
|
||||
}
|
||||
err := checkFunctionOutputField(function, fields)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Unknown function type", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Type: schemapb.FunctionType_Unknown,
|
||||
}
|
||||
fields := []*schemapb.FieldSchema{
|
||||
{
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
},
|
||||
}
|
||||
err := checkFunctionOutputField(function, fields)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateFunctionBasicParams(t *testing.T) {
|
||||
t.Run("Valid function", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Name: "validFunction",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input1", "input2"},
|
||||
OutputFieldNames: []string{"output1"},
|
||||
}
|
||||
err := checkFunctionBasicParams(function)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Empty function name", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Name: "",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input1"},
|
||||
OutputFieldNames: []string{"output1"},
|
||||
}
|
||||
err := checkFunctionBasicParams(function)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Empty input field names", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Name: "emptyInputs",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{},
|
||||
OutputFieldNames: []string{"output1"},
|
||||
}
|
||||
err := checkFunctionBasicParams(function)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Empty output field names", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Name: "emptyOutputs",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input1"},
|
||||
OutputFieldNames: []string{},
|
||||
}
|
||||
err := checkFunctionBasicParams(function)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Empty input field name", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Name: "emptyInputName",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input1", ""},
|
||||
OutputFieldNames: []string{"output1"},
|
||||
}
|
||||
err := checkFunctionBasicParams(function)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Duplicate input field names", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Name: "duplicateInputs",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input1", "input1"},
|
||||
OutputFieldNames: []string{"output1"},
|
||||
}
|
||||
err := checkFunctionBasicParams(function)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Empty output field name", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Name: "emptyOutputName",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input1"},
|
||||
OutputFieldNames: []string{"output1", ""},
|
||||
}
|
||||
err := checkFunctionBasicParams(function)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Input field used as output", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Name: "inputAsOutput",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"field1", "field2"},
|
||||
OutputFieldNames: []string{"field1"},
|
||||
}
|
||||
err := checkFunctionBasicParams(function)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Duplicate output field names", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Name: "duplicateOutputs",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input1"},
|
||||
OutputFieldNames: []string{"output1", "output1"},
|
||||
}
|
||||
err := checkFunctionBasicParams(function)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -143,7 +143,7 @@ type shardDelegator struct {
|
|||
|
||||
// fieldId -> functionRunner map for search function field
|
||||
functionRunners map[UniqueID]function.FunctionRunner
|
||||
hasBM25Field bool
|
||||
isBM25Field map[UniqueID]bool
|
||||
}
|
||||
|
||||
// getLogger returns the zap logger with pre-defined shard attributes.
|
||||
|
@ -245,7 +245,7 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest
|
|||
}
|
||||
|
||||
// build idf for bm25 search
|
||||
if req.GetReq().GetMetricType() == metric.BM25 {
|
||||
if req.GetReq().GetMetricType() == metric.BM25 || (req.GetReq().GetMetricType() == metric.EMPTY && sd.isBM25Field[req.GetReq().GetFieldId()]) {
|
||||
avgdl, err := sd.buildBM25IDF(req.GetReq())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -908,6 +908,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
|
|||
partitionStats: make(map[UniqueID]*storage.PartitionStatsSnapshot),
|
||||
excludedSegments: excludedSegments,
|
||||
functionRunners: make(map[int64]function.FunctionRunner),
|
||||
isBM25Field: make(map[int64]bool),
|
||||
}
|
||||
|
||||
for _, tf := range collection.Schema().GetFunctions() {
|
||||
|
@ -918,7 +919,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
|
|||
}
|
||||
sd.functionRunners[tf.OutputFieldIds[0]] = functionRunner
|
||||
if tf.GetType() == schemapb.FunctionType_BM25 {
|
||||
sd.hasBM25Field = true
|
||||
sd.isBM25Field[tf.OutputFieldIds[0]] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -385,7 +385,7 @@ func (sd *shardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.Segm
|
|||
|
||||
for _, segment := range loaded {
|
||||
sd.pkOracle.Register(segment, paramtable.GetNodeID())
|
||||
if sd.hasBM25Field {
|
||||
if len(sd.isBM25Field) > 0 {
|
||||
sd.idfOracle.Register(segment.ID(), segment.GetBM25Stats(), segments.SegmentTypeGrowing)
|
||||
}
|
||||
}
|
||||
|
@ -485,7 +485,7 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg
|
|||
})
|
||||
|
||||
var bm25Stats *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats]
|
||||
if sd.hasBM25Field {
|
||||
if len(sd.isBM25Field) > 0 {
|
||||
bm25Stats, err = sd.loader.LoadBM25Stats(ctx, req.GetCollectionID(), infos...)
|
||||
if err != nil {
|
||||
log.Warn("failed to load bm25 stats for segment", zap.Error(err))
|
||||
|
|
|
@ -68,14 +68,17 @@ func NewInsertDataWithCap(schema *schemapb.CollectionSchema, cap int, withFuncti
|
|||
|
||||
for _, field := range schema.Fields {
|
||||
if field.IsPrimaryKey && field.GetNullable() {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("primary key field not support nullable")
|
||||
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("primary key field should not be nullable (field: %s)", field.Name))
|
||||
}
|
||||
if field.IsPartitionKey && field.GetNullable() {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("partition key field not support nullable")
|
||||
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("partition key field should not be nullable (field: %s)", field.Name))
|
||||
}
|
||||
if field.IsFunctionOutput {
|
||||
if field.IsPrimaryKey || field.IsPartitionKey {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("function output field should not be primary key or partition key")
|
||||
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("function output field should not be primary key or partition key (field: %s)", field.Name))
|
||||
}
|
||||
if field.GetNullable() {
|
||||
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("function output field should not be nullable (field: %s)", field.Name))
|
||||
}
|
||||
if !withFunctionOutput {
|
||||
continue
|
||||
|
|
|
@ -38,4 +38,6 @@ const (
|
|||
SUPERSTRUCTURE MetricType = "SUPERSTRUCTURE"
|
||||
|
||||
BM25 MetricType = "BM25"
|
||||
|
||||
EMPTY MetricType = ""
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue