Add length check when insert and upsert (#24759)

Signed-off-by: lixinguo <xinguo.li@zilliz.com>
Co-authored-by: lixinguo <xinguo.li@zilliz.com>
pull/24906/head
smellthemoon 2023-06-15 10:24:38 +08:00 committed by GitHub
parent 56380ea75b
commit db31e88a73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 61 additions and 19 deletions

View File

@ -200,7 +200,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
}
}
if err := newValidateUtil(withNANCheck(), withOverflowCheck()).
if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck()).
Validate(it.insertMsg.GetFieldsData(), schema, it.insertMsg.NRows()); err != nil {
return err
}

View File

@ -207,7 +207,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
}
}
if err := newValidateUtil(withNANCheck(), withOverflowCheck()).
if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck()).
Validate(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema, it.upsertMsg.InsertMsg.NRows()); err != nil {
return err
}

View File

@ -8,6 +8,7 @@ import (
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/parameterutil.go"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
"go.uber.org/zap"
@ -51,15 +52,6 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, schema *schemapb.Col
return err
}
err = v.fillWithDefaultValue(data, helper, numRows)
if err != nil {
return err
}
if err := v.checkAligned(data, helper, numRows); err != nil {
return err
}
for _, field := range data {
fieldSchema, err := helper.GetFieldFromName(field.GetFieldName())
if err != nil {
@ -91,6 +83,15 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, schema *schemapb.Col
}
}
err = v.fillWithDefaultValue(data, helper, numRows)
if err != nil {
return err
}
if err := v.checkAligned(data, helper, numRows); err != nil {
return err
}
return nil
}
@ -255,13 +256,13 @@ func (v *validateUtil) checkBinaryVectorFieldData(field *schemapb.FieldData, fie
func (v *validateUtil) checkVarCharFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
strArr := field.GetScalars().GetStringData().GetData()
if strArr == nil {
if strArr == nil && fieldSchema.GetDefaultValue() == nil {
msg := fmt.Sprintf("varchar field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need string array", "got nil", msg)
}
if v.checkMaxLen {
maxLength, err := GetMaxLength(fieldSchema)
maxLength, err := parameterutil.GetMaxLength(fieldSchema)
if err != nil {
return err
}
@ -291,7 +292,7 @@ func (v *validateUtil) checkIntegerFieldData(field *schemapb.FieldData, fieldSch
}
data := field.GetScalars().GetIntData().GetData()
if data == nil {
if data == nil && fieldSchema.GetDefaultValue() == nil {
msg := fmt.Sprintf("field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need int array", "got nil", msg)
}

View File

@ -35,6 +35,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/parameterutil.go"
"github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/samber/lo"
"go.uber.org/zap"
@ -79,7 +80,7 @@ func (t *createCollectionTask) validate() error {
return nil
}
func defaultValueTypeMatch(schema *schemapb.CollectionSchema) error {
func checkDefaultValue(schema *schemapb.CollectionSchema) error {
for _, fieldSchema := range schema.Fields {
if fieldSchema.GetDefaultValue() != nil {
switch fieldSchema.GetDefaultValue().Data.(type) {
@ -120,6 +121,15 @@ func defaultValueTypeMatch(schema *schemapb.CollectionSchema) error {
if fieldSchema.GetDataType() != schemapb.DataType_VarChar {
return merr.WrapErrParameterInvalid("DataType_VarChar", "not match", "default value type mismatches field schema type")
}
maxLength, err := parameterutil.GetMaxLength(fieldSchema)
if err != nil {
return err
}
defaultValueLength := len(fieldSchema.GetDefaultValue().GetStringData())
if int64(defaultValueLength) > maxLength {
msg := fmt.Sprintf("the length (%d) of string exceeds max length (%d)", defaultValueLength, maxLength)
return merr.WrapErrParameterInvalid("valid length string", "string length exceeds max length", msg)
}
default:
panic("default value unsupport data type")
}
@ -146,9 +156,9 @@ func (t *createCollectionTask) validateSchema(schema *schemapb.CollectionSchema)
return merr.WrapErrParameterInvalid("collection name matches schema name", "don't match", msg)
}
err := defaultValueTypeMatch(schema)
err := checkDefaultValue(schema)
if err != nil {
log.Error("default value type mismatch field schema type")
log.Error("has invalid default value")
return err
}

View File

@ -266,6 +266,37 @@ func Test_createCollectionTask_validateSchema(t *testing.T) {
assert.ErrorIs(t, err8, merr.ErrParameterInvalid)
})
t.Run("default value length exceeds", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
task := createCollectionTask{
Req: &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection},
CollectionName: collectionName,
},
}
schema := &schemapb.CollectionSchema{
Name: collectionName,
Fields: []*schemapb.FieldSchema{
{
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.MaxLengthKey,
Value: "2",
},
},
DefaultValue: &schemapb.ValueField{
Data: &schemapb.ValueField_StringData{
StringData: "abc",
},
},
},
},
}
err := task.validateSchema(schema)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
})
t.Run("normal case", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
task := createCollectionTask{

View File

@ -1,4 +1,4 @@
package proxy
package parameterutil
import (
"fmt"

View File

@ -1,4 +1,4 @@
package proxy
package parameterutil
import (
"testing"