feat: add more validation checks for function creation (#36766)

issue: https://github.com/milvus-io/milvus/issues/35853

* BM25 Function now takes no params; k1 and b should be passed via index
params
* support BM25 full text search when metric type is not present in
search request
* add stricter validation of functions at collection creation time

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
pull/36790/head
Buqian Zheng 2024-10-13 17:43:22 +08:00 committed by GitHub
parent 16b533cbf0
commit 383350c120
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 510 additions and 92 deletions

View File

@ -136,24 +136,11 @@ func (cit *createIndexTask) parseFunctionParamsToIndex(indexParamsMap map[string
}
switch cit.functionSchema.GetType() {
case schemapb.FunctionType_Unknown:
return fmt.Errorf("unknown function type encountered")
case schemapb.FunctionType_BM25:
for _, kv := range cit.functionSchema.GetParams() {
switch kv.GetKey() {
case "bm25_k1":
if _, ok := indexParamsMap["bm25_k1"]; !ok {
indexParamsMap["bm25_k1"] = kv.GetValue()
}
case "bm25_b":
if _, ok := indexParamsMap["bm25_b"]; !ok {
indexParamsMap["bm25_b"] = kv.GetValue()
}
case "bm25_avgdl":
if _, ok := indexParamsMap["bm25_avgdl"]; !ok {
indexParamsMap["bm25_avgdl"] = kv.GetValue()
}
}
}
// set default avgdl
// set default BM25 params if not provided in index params
if _, ok := indexParamsMap["bm25_k1"]; !ok {
indexParamsMap["bm25_k1"] = "1.2"
}
@ -165,8 +152,15 @@ func (cit *createIndexTask) parseFunctionParamsToIndex(indexParamsMap map[string
if _, ok := indexParamsMap["bm25_avgdl"]; !ok {
indexParamsMap["bm25_avgdl"] = "100"
}
if metricType, ok := indexParamsMap["metric_type"]; !ok {
indexParamsMap["metric_type"] = "BM25"
} else if metricType != "BM25" {
return fmt.Errorf("index metric type of BM25 function output field must be BM25, got %s", metricType)
}
default:
return fmt.Errorf("parse unknown type function params to index")
return nil
}
return nil
@ -192,11 +186,6 @@ func (cit *createIndexTask) parseIndexParams() error {
}
}
// fill index param for bm25 function
if err := cit.parseFunctionParamsToIndex(indexParamsMap); err != nil {
return err
}
if err := ValidateAutoIndexMmapConfig(isVecIndex, indexParamsMap); err != nil {
return err
}
@ -342,6 +331,11 @@ func (cit *createIndexTask) parseIndexParams() error {
}
}
// fill index param for Functions
if err := cit.parseFunctionParamsToIndex(indexParamsMap); err != nil {
return err
}
indexType, exist := indexParamsMap[common.IndexTypeKey]
if !exist {
return fmt.Errorf("IndexType not specified")

View File

@ -620,10 +620,14 @@ func validateFunction(coll *schemapb.CollectionSchema) error {
})
usedOutputField := typeutil.NewSet[string]()
usedFunctionName := typeutil.NewSet[string]()
// validate function
for _, function := range coll.GetFunctions() {
if err := checkFunctionBasicParams(function); err != nil {
return err
}
if usedFunctionName.Contain(function.GetName()) {
return fmt.Errorf("duplicate function name %s", function.GetName())
return fmt.Errorf("duplicate function name: %s", function.GetName())
}
usedFunctionName.Insert(function.GetName())
@ -631,13 +635,15 @@ func validateFunction(coll *schemapb.CollectionSchema) error {
for _, name := range function.GetInputFieldNames() {
inputField, ok := nameMap[name]
if !ok {
return fmt.Errorf("function input field not found %s", function.InputFieldNames)
return fmt.Errorf("function input field not found: %s", name)
}
if inputField.GetNullable() {
return fmt.Errorf("function input field cannot be nullable: function %s, field %s", function.GetName(), inputField.GetName())
}
inputFields = append(inputFields, inputField)
}
err := checkFunctionInputField(function, inputFields)
if err != nil {
if err := checkFunctionInputField(function, inputFields); err != nil {
return err
}
@ -645,12 +651,25 @@ func validateFunction(coll *schemapb.CollectionSchema) error {
for i, name := range function.GetOutputFieldNames() {
outputField, ok := nameMap[name]
if !ok {
return fmt.Errorf("function output field not found %s", function.InputFieldNames)
return fmt.Errorf("function output field not found: %s", name)
}
if outputField.GetIsPrimaryKey() {
return fmt.Errorf("function output field cannot be primary key: function %s, field %s", function.GetName(), outputField.GetName())
}
if outputField.GetIsPartitionKey() || outputField.GetIsClusteringKey() {
return fmt.Errorf("function output field cannot be partition key or clustering key: function %s, field %s", function.GetName(), outputField.GetName())
}
if outputField.GetNullable() {
return fmt.Errorf("function output field cannot be nullable: function %s, field %s", function.GetName(), outputField.GetName())
}
outputField.IsFunctionOutput = true
outputFields[i] = outputField
if usedOutputField.Contain(name) {
return fmt.Errorf("duplicate function output %s", name)
return fmt.Errorf("duplicate function output field: function %s, field %s", function.GetName(), name)
}
usedOutputField.Insert(name)
}
@ -658,10 +677,6 @@ func validateFunction(coll *schemapb.CollectionSchema) error {
if err := checkFunctionOutputField(function, outputFields); err != nil {
return err
}
if err := checkFunctionParams(function); err != nil {
return err
}
}
return nil
}
@ -670,19 +685,11 @@ func checkFunctionOutputField(function *schemapb.FunctionSchema, fields []*schem
switch function.GetType() {
case schemapb.FunctionType_BM25:
if len(fields) != 1 {
return fmt.Errorf("bm25 only need 1 output field, but now %d", len(fields))
return fmt.Errorf("BM25 function only need 1 output field, but got %d", len(fields))
}
if !typeutil.IsSparseFloatVectorType(fields[0].GetDataType()) {
return fmt.Errorf("bm25 only need sparse embedding output field, but now %s", fields[0].DataType.String())
}
if fields[0].GetIsPrimaryKey() {
return fmt.Errorf("bm25 output field can't be primary key")
}
if fields[0].GetIsPartitionKey() || fields[0].GetIsClusteringKey() {
return fmt.Errorf("bm25 output field can't be partition key or cluster key field")
return fmt.Errorf("BM25 function output field must be a SparseFloatVector field, but got %s", fields[0].DataType.String())
}
default:
return fmt.Errorf("check output field for unknown function type")
@ -694,12 +701,12 @@ func checkFunctionInputField(function *schemapb.FunctionSchema, fields []*schema
switch function.GetType() {
case schemapb.FunctionType_BM25:
if len(fields) != 1 || fields[0].DataType != schemapb.DataType_VarChar {
return fmt.Errorf("only one VARCHAR input field is allowed for a BM25 Function, got %d field with type %s",
return fmt.Errorf("BM25 function input field must be a VARCHAR field, got %d field with type %s",
len(fields), fields[0].DataType.String())
}
h := typeutil.CreateFieldSchemaHelper(fields[0])
if !h.EnableTokenizer() {
return fmt.Errorf("BM25 input field must set enable_tokenizer to true")
return fmt.Errorf("BM25 function input field must set enable_tokenizer to true")
}
default:
@ -708,46 +715,40 @@ func checkFunctionInputField(function *schemapb.FunctionSchema, fields []*schema
return nil
}
func checkFunctionParams(function *schemapb.FunctionSchema) error {
func checkFunctionBasicParams(function *schemapb.FunctionSchema) error {
if function.GetName() == "" {
return fmt.Errorf("function name cannot be empty")
}
if len(function.GetInputFieldNames()) == 0 {
return fmt.Errorf("function input field names cannot be empty, function: %s", function.GetName())
}
if len(function.GetOutputFieldNames()) == 0 {
return fmt.Errorf("function output field names cannot be empty, function: %s", function.GetName())
}
for _, input := range function.GetInputFieldNames() {
if input == "" {
return fmt.Errorf("function input field name cannot be empty string, function: %s", function.GetName())
}
// if input occurs more than once, error
if lo.Count(function.GetInputFieldNames(), input) > 1 {
return fmt.Errorf("each function input field should be used exactly once in the same function, function: %s, input field: %s", function.GetName(), input)
}
}
for _, output := range function.GetOutputFieldNames() {
if output == "" {
return fmt.Errorf("function output field name cannot be empty string, function: %s", function.GetName())
}
if lo.Count(function.GetInputFieldNames(), output) > 0 {
return fmt.Errorf("a single field cannot be both input and output in the same function, function: %s, field: %s", function.GetName(), output)
}
if lo.Count(function.GetOutputFieldNames(), output) > 1 {
return fmt.Errorf("each function output field should be used exactly once in the same function, function: %s, output field: %s", function.GetName(), output)
}
}
switch function.GetType() {
case schemapb.FunctionType_BM25:
for _, kv := range function.GetParams() {
switch kv.GetKey() {
case "bm25_k1":
k1, err := strconv.ParseFloat(kv.GetValue(), 64)
if err != nil {
return fmt.Errorf("failed to parse bm25_k1 value, %w", err)
}
if k1 < 0 || k1 > 3 {
return fmt.Errorf("bm25_k1 must in [0,3] but now %f", k1)
}
case "bm25_b":
b, err := strconv.ParseFloat(kv.GetValue(), 64)
if err != nil {
return fmt.Errorf("failed to parse bm25_b value, %w", err)
}
if b < 0 || b > 1 {
return fmt.Errorf("bm25_b must in [0,1] but now %f", b)
}
case "bm25_avgdl":
avgdl, err := strconv.ParseFloat(kv.GetValue(), 64)
if err != nil {
return fmt.Errorf("failed to parse bm25_avgdl value, %w", err)
}
if avgdl <= 0 {
return fmt.Errorf("bm25_avgdl must large than zero but now %f", avgdl)
}
case "tokenizer_params":
// TODO ADD tokenizer check
default:
return fmt.Errorf("invalid function params, key: %s, value:%s", kv.GetKey(), kv.GetValue())
}
if len(function.GetParams()) != 0 {
return fmt.Errorf("BM25 function accepts no params")
}
default:
return fmt.Errorf("check function params with unknown function type")

View File

@ -2637,3 +2637,420 @@ func TestValidateLoadFieldsList(t *testing.T) {
})
}
}
// TestValidateFunction runs validateFunction against one well-formed BM25
// schema and against a series of schemas that each break a single rule,
// checking the error text reported for every violation.
func TestValidateFunction(t *testing.T) {
	// Builders return fresh values on every call so that no field or function
	// schema is shared between table entries.
	tokenizedInput := func() *schemapb.FieldSchema {
		return &schemapb.FieldSchema{
			Name:       "input_field",
			DataType:   schemapb.DataType_VarChar,
			TypeParams: []*commonpb.KeyValuePair{{Key: "enable_tokenizer", Value: "true"}},
		}
	}
	sparseOutput := func() *schemapb.FieldSchema {
		return &schemapb.FieldSchema{
			Name:     "output_field",
			DataType: schemapb.DataType_SparseFloatVector,
		}
	}
	bm25 := func(input, output string) *schemapb.FunctionSchema {
		return &schemapb.FunctionSchema{
			Name:             "bm25_func",
			Type:             schemapb.FunctionType_BM25,
			InputFieldNames:  []string{input},
			OutputFieldNames: []string{output},
		}
	}

	cases := []struct {
		name      string
		fields    []*schemapb.FieldSchema
		functions []*schemapb.FunctionSchema
		// wantErr is the substring expected in the error; empty means the
		// schema must validate without error.
		wantErr string
	}{
		{
			name:      "Valid function schema",
			fields:    []*schemapb.FieldSchema{tokenizedInput(), sparseOutput()},
			functions: []*schemapb.FunctionSchema{bm25("input_field", "output_field")},
		},
		{
			name:   "Invalid function schema - duplicate function names",
			fields: []*schemapb.FieldSchema{tokenizedInput(), sparseOutput()},
			functions: []*schemapb.FunctionSchema{
				bm25("input_field", "output_field"),
				bm25("input_field", "output_field"),
			},
			wantErr: "duplicate function name",
		},
		{
			name:      "Invalid function schema - input field not found",
			fields:    []*schemapb.FieldSchema{sparseOutput()},
			functions: []*schemapb.FunctionSchema{bm25("non_existent_field", "output_field")},
			wantErr:   "input field not found",
		},
		{
			name:      "Invalid function schema - output field not found",
			fields:    []*schemapb.FieldSchema{tokenizedInput()},
			functions: []*schemapb.FunctionSchema{bm25("input_field", "non_existent_field")},
			wantErr:   "output field not found",
		},
		{
			name: "Invalid function schema - nullable input field",
			fields: []*schemapb.FieldSchema{
				{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_tokenizer", Value: "true"}}, Nullable: true},
				sparseOutput(),
			},
			functions: []*schemapb.FunctionSchema{bm25("input_field", "output_field")},
			wantErr:   "function input field cannot be nullable",
		},
		{
			name: "Invalid function schema - output field is primary key",
			fields: []*schemapb.FieldSchema{
				tokenizedInput(),
				{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector, IsPrimaryKey: true},
			},
			functions: []*schemapb.FunctionSchema{bm25("input_field", "output_field")},
			wantErr:   "function output field cannot be primary key",
		},
		{
			name: "Invalid function schema - output field is partition key",
			fields: []*schemapb.FieldSchema{
				tokenizedInput(),
				{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector, IsPartitionKey: true},
			},
			functions: []*schemapb.FunctionSchema{bm25("input_field", "output_field")},
			wantErr:   "function output field cannot be partition key or clustering key",
		},
		{
			name: "Invalid function schema - output field is clustering key",
			fields: []*schemapb.FieldSchema{
				tokenizedInput(),
				{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector, IsClusteringKey: true},
			},
			functions: []*schemapb.FunctionSchema{bm25("input_field", "output_field")},
			wantErr:   "function output field cannot be partition key or clustering key",
		},
		{
			name: "Invalid function schema - nullable output field",
			fields: []*schemapb.FieldSchema{
				tokenizedInput(),
				{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector, Nullable: true},
			},
			functions: []*schemapb.FunctionSchema{bm25("input_field", "output_field")},
			wantErr:   "function output field cannot be nullable",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			schema := &schemapb.CollectionSchema{
				Fields:    tc.fields,
				Functions: tc.functions,
			}
			err := validateFunction(schema)
			if tc.wantErr == "" {
				assert.NoError(t, err)
				return
			}
			assert.Error(t, err)
			assert.Contains(t, err.Error(), tc.wantErr)
		})
	}
}
// TestValidateFunctionInputField checks which input-field combinations
// checkFunctionInputField accepts for a BM25 function, and that a function of
// unknown type is rejected.
func TestValidateFunctionInputField(t *testing.T) {
	// varChar builds a VARCHAR field whose enable_tokenizer type param is set
	// to the given value.
	varChar := func(enableTokenizer string) *schemapb.FieldSchema {
		return &schemapb.FieldSchema{
			DataType:   schemapb.DataType_VarChar,
			TypeParams: []*commonpb.KeyValuePair{{Key: "enable_tokenizer", Value: enableTokenizer}},
		}
	}

	cases := []struct {
		name    string
		fnType  schemapb.FunctionType
		fields  []*schemapb.FieldSchema
		wantErr bool
	}{
		{
			name:   "Valid BM25 function input",
			fnType: schemapb.FunctionType_BM25,
			fields: []*schemapb.FieldSchema{varChar("true")},
		},
		{
			name:    "Invalid BM25 function input - wrong data type",
			fnType:  schemapb.FunctionType_BM25,
			fields:  []*schemapb.FieldSchema{{DataType: schemapb.DataType_Int64}},
			wantErr: true,
		},
		{
			name:    "Invalid BM25 function input - tokenizer not enabled",
			fnType:  schemapb.FunctionType_BM25,
			fields:  []*schemapb.FieldSchema{varChar("false")},
			wantErr: true,
		},
		{
			name:    "Invalid BM25 function input - multiple fields",
			fnType:  schemapb.FunctionType_BM25,
			fields:  []*schemapb.FieldSchema{varChar("true"), varChar("true")},
			wantErr: true,
		},
		{
			name:    "Unknown function type",
			fnType:  schemapb.FunctionType_Unknown,
			fields:  []*schemapb.FieldSchema{{DataType: schemapb.DataType_VarChar}},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			fn := &schemapb.FunctionSchema{Type: tc.fnType}
			err := checkFunctionInputField(fn, tc.fields)
			if tc.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValidateFunctionOutputField checks which output-field combinations
// checkFunctionOutputField accepts for a BM25 function, and that a function
// of unknown type is rejected.
func TestValidateFunctionOutputField(t *testing.T) {
	cases := []struct {
		name   string
		fnType schemapb.FunctionType
		// dataTypes lists the data type of each output field, in order.
		dataTypes []schemapb.DataType
		wantErr   bool
	}{
		{
			name:      "Valid BM25 function output",
			fnType:    schemapb.FunctionType_BM25,
			dataTypes: []schemapb.DataType{schemapb.DataType_SparseFloatVector},
		},
		{
			name:      "Invalid BM25 function output - wrong data type",
			fnType:    schemapb.FunctionType_BM25,
			dataTypes: []schemapb.DataType{schemapb.DataType_Float},
			wantErr:   true,
		},
		{
			name:      "Invalid BM25 function output - multiple fields",
			fnType:    schemapb.FunctionType_BM25,
			dataTypes: []schemapb.DataType{schemapb.DataType_SparseFloatVector, schemapb.DataType_FloatVector},
			wantErr:   true,
		},
		{
			name:      "Unknown function type",
			fnType:    schemapb.FunctionType_Unknown,
			dataTypes: []schemapb.DataType{schemapb.DataType_FloatVector},
			wantErr:   true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			outputs := make([]*schemapb.FieldSchema, 0, len(tc.dataTypes))
			for _, dt := range tc.dataTypes {
				outputs = append(outputs, &schemapb.FieldSchema{DataType: dt})
			}
			fn := &schemapb.FunctionSchema{Type: tc.fnType}

			err := checkFunctionOutputField(fn, outputs)
			if tc.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValidateFunctionBasicParams feeds checkFunctionBasicParams one
// acceptable BM25 function definition and a set of definitions with an empty
// name, empty or blank field-name lists, duplicated field names, or a field
// used as both input and output, expecting an error for each of the latter.
func TestValidateFunctionBasicParams(t *testing.T) {
	cases := []struct {
		name    string
		fnName  string
		inputs  []string
		outputs []string
		wantErr bool
	}{
		{
			name:    "Valid function",
			fnName:  "validFunction",
			inputs:  []string{"input1", "input2"},
			outputs: []string{"output1"},
		},
		{
			name:    "Empty function name",
			fnName:  "",
			inputs:  []string{"input1"},
			outputs: []string{"output1"},
			wantErr: true,
		},
		{
			name:    "Empty input field names",
			fnName:  "emptyInputs",
			inputs:  []string{},
			outputs: []string{"output1"},
			wantErr: true,
		},
		{
			name:    "Empty output field names",
			fnName:  "emptyOutputs",
			inputs:  []string{"input1"},
			outputs: []string{},
			wantErr: true,
		},
		{
			name:    "Empty input field name",
			fnName:  "emptyInputName",
			inputs:  []string{"input1", ""},
			outputs: []string{"output1"},
			wantErr: true,
		},
		{
			name:    "Duplicate input field names",
			fnName:  "duplicateInputs",
			inputs:  []string{"input1", "input1"},
			outputs: []string{"output1"},
			wantErr: true,
		},
		{
			name:    "Empty output field name",
			fnName:  "emptyOutputName",
			inputs:  []string{"input1"},
			outputs: []string{"output1", ""},
			wantErr: true,
		},
		{
			name:    "Input field used as output",
			fnName:  "inputAsOutput",
			inputs:  []string{"field1", "field2"},
			outputs: []string{"field1"},
			wantErr: true,
		},
		{
			name:    "Duplicate output field names",
			fnName:  "duplicateOutputs",
			inputs:  []string{"input1"},
			outputs: []string{"output1", "output1"},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			fn := &schemapb.FunctionSchema{
				Name:             tc.fnName,
				Type:             schemapb.FunctionType_BM25,
				InputFieldNames:  tc.inputs,
				OutputFieldNames: tc.outputs,
			}
			err := checkFunctionBasicParams(fn)
			if tc.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

View File

@ -143,7 +143,7 @@ type shardDelegator struct {
// fieldId -> functionRunner map for search function field
functionRunners map[UniqueID]function.FunctionRunner
hasBM25Field bool
isBM25Field map[UniqueID]bool
}
// getLogger returns the zap logger with pre-defined shard attributes.
@ -245,7 +245,7 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest
}
// build idf for bm25 search
if req.GetReq().GetMetricType() == metric.BM25 {
if req.GetReq().GetMetricType() == metric.BM25 || (req.GetReq().GetMetricType() == metric.EMPTY && sd.isBM25Field[req.GetReq().GetFieldId()]) {
avgdl, err := sd.buildBM25IDF(req.GetReq())
if err != nil {
return nil, err
@ -908,6 +908,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
partitionStats: make(map[UniqueID]*storage.PartitionStatsSnapshot),
excludedSegments: excludedSegments,
functionRunners: make(map[int64]function.FunctionRunner),
isBM25Field: make(map[int64]bool),
}
for _, tf := range collection.Schema().GetFunctions() {
@ -918,7 +919,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
}
sd.functionRunners[tf.OutputFieldIds[0]] = functionRunner
if tf.GetType() == schemapb.FunctionType_BM25 {
sd.hasBM25Field = true
sd.isBM25Field[tf.OutputFieldIds[0]] = true
}
}
}

View File

@ -385,7 +385,7 @@ func (sd *shardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.Segm
for _, segment := range loaded {
sd.pkOracle.Register(segment, paramtable.GetNodeID())
if sd.hasBM25Field {
if len(sd.isBM25Field) > 0 {
sd.idfOracle.Register(segment.ID(), segment.GetBM25Stats(), segments.SegmentTypeGrowing)
}
}
@ -485,7 +485,7 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg
})
var bm25Stats *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats]
if sd.hasBM25Field {
if len(sd.isBM25Field) > 0 {
bm25Stats, err = sd.loader.LoadBM25Stats(ctx, req.GetCollectionID(), infos...)
if err != nil {
log.Warn("failed to load bm25 stats for segment", zap.Error(err))

View File

@ -68,14 +68,17 @@ func NewInsertDataWithCap(schema *schemapb.CollectionSchema, cap int, withFuncti
for _, field := range schema.Fields {
if field.IsPrimaryKey && field.GetNullable() {
return nil, merr.WrapErrParameterInvalidMsg("primary key field not support nullable")
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("primary key field should not be nullable (field: %s)", field.Name))
}
if field.IsPartitionKey && field.GetNullable() {
return nil, merr.WrapErrParameterInvalidMsg("partition key field not support nullable")
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("partition key field should not be nullable (field: %s)", field.Name))
}
if field.IsFunctionOutput {
if field.IsPrimaryKey || field.IsPartitionKey {
return nil, merr.WrapErrParameterInvalidMsg("function output field should not be primary key or partition key")
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("function output field should not be primary key or partition key (field: %s)", field.Name))
}
if field.GetNullable() {
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("function output field should not be nullable (field: %s)", field.Name))
}
if !withFunctionOutput {
continue

View File

@ -38,4 +38,6 @@ const (
SUPERSTRUCTURE MetricType = "SUPERSTRUCTURE"
BM25 MetricType = "BM25"
EMPTY MetricType = ""
)