mirror of https://github.com/milvus-io/milvus.git
Add length check on insert and upsert (#24759)
Signed-off-by: lixinguo <xinguo.li@zilliz.com>
Co-authored-by: lixinguo <xinguo.li@zilliz.com>
pull/24906/head
parent
56380ea75b
commit
db31e88a73
|
@ -200,7 +200,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
if err := newValidateUtil(withNANCheck(), withOverflowCheck()).
|
||||
if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck()).
|
||||
Validate(it.insertMsg.GetFieldsData(), schema, it.insertMsg.NRows()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -207,7 +207,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
if err := newValidateUtil(withNANCheck(), withOverflowCheck()).
|
||||
if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck()).
|
||||
Validate(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema, it.upsertMsg.InsertMsg.NRows()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/parameterutil.go"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
"go.uber.org/zap"
|
||||
|
@ -51,15 +52,6 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, schema *schemapb.Col
|
|||
return err
|
||||
}
|
||||
|
||||
err = v.fillWithDefaultValue(data, helper, numRows)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.checkAligned(data, helper, numRows); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, field := range data {
|
||||
fieldSchema, err := helper.GetFieldFromName(field.GetFieldName())
|
||||
if err != nil {
|
||||
|
@ -91,6 +83,15 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, schema *schemapb.Col
|
|||
}
|
||||
}
|
||||
|
||||
err = v.fillWithDefaultValue(data, helper, numRows)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.checkAligned(data, helper, numRows); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -255,13 +256,13 @@ func (v *validateUtil) checkBinaryVectorFieldData(field *schemapb.FieldData, fie
|
|||
|
||||
func (v *validateUtil) checkVarCharFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
|
||||
strArr := field.GetScalars().GetStringData().GetData()
|
||||
if strArr == nil {
|
||||
if strArr == nil && fieldSchema.GetDefaultValue() == nil {
|
||||
msg := fmt.Sprintf("varchar field '%v' is illegal, array type mismatch", field.GetFieldName())
|
||||
return merr.WrapErrParameterInvalid("need string array", "got nil", msg)
|
||||
}
|
||||
|
||||
if v.checkMaxLen {
|
||||
maxLength, err := GetMaxLength(fieldSchema)
|
||||
maxLength, err := parameterutil.GetMaxLength(fieldSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -291,7 +292,7 @@ func (v *validateUtil) checkIntegerFieldData(field *schemapb.FieldData, fieldSch
|
|||
}
|
||||
|
||||
data := field.GetScalars().GetIntData().GetData()
|
||||
if data == nil {
|
||||
if data == nil && fieldSchema.GetDefaultValue() == nil {
|
||||
msg := fmt.Sprintf("field '%v' is illegal, array type mismatch", field.GetFieldName())
|
||||
return merr.WrapErrParameterInvalid("need int array", "got nil", msg)
|
||||
}
|
||||
|
|
|
@ -35,6 +35,7 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/parameterutil.go"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/zap"
|
||||
|
@ -79,7 +80,7 @@ func (t *createCollectionTask) validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func defaultValueTypeMatch(schema *schemapb.CollectionSchema) error {
|
||||
func checkDefaultValue(schema *schemapb.CollectionSchema) error {
|
||||
for _, fieldSchema := range schema.Fields {
|
||||
if fieldSchema.GetDefaultValue() != nil {
|
||||
switch fieldSchema.GetDefaultValue().Data.(type) {
|
||||
|
@ -120,6 +121,15 @@ func defaultValueTypeMatch(schema *schemapb.CollectionSchema) error {
|
|||
if fieldSchema.GetDataType() != schemapb.DataType_VarChar {
|
||||
return merr.WrapErrParameterInvalid("DataType_VarChar", "not match", "default value type mismatches field schema type")
|
||||
}
|
||||
maxLength, err := parameterutil.GetMaxLength(fieldSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defaultValueLength := len(fieldSchema.GetDefaultValue().GetStringData())
|
||||
if int64(defaultValueLength) > maxLength {
|
||||
msg := fmt.Sprintf("the length (%d) of string exceeds max length (%d)", defaultValueLength, maxLength)
|
||||
return merr.WrapErrParameterInvalid("valid length string", "string length exceeds max length", msg)
|
||||
}
|
||||
default:
|
||||
panic("default value unsupport data type")
|
||||
}
|
||||
|
@ -146,9 +156,9 @@ func (t *createCollectionTask) validateSchema(schema *schemapb.CollectionSchema)
|
|||
return merr.WrapErrParameterInvalid("collection name matches schema name", "don't match", msg)
|
||||
}
|
||||
|
||||
err := defaultValueTypeMatch(schema)
|
||||
err := checkDefaultValue(schema)
|
||||
if err != nil {
|
||||
log.Error("default value type mismatch field schema type")
|
||||
log.Error("has invalid default value")
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -266,6 +266,37 @@ func Test_createCollectionTask_validateSchema(t *testing.T) {
|
|||
assert.ErrorIs(t, err8, merr.ErrParameterInvalid)
|
||||
})
|
||||
|
||||
t.Run("default value length exceeds", func(t *testing.T) {
|
||||
collectionName := funcutil.GenRandomStr()
|
||||
task := createCollectionTask{
|
||||
Req: &milvuspb.CreateCollectionRequest{
|
||||
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection},
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
}
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.MaxLengthKey,
|
||||
Value: "2",
|
||||
},
|
||||
},
|
||||
DefaultValue: &schemapb.ValueField{
|
||||
Data: &schemapb.ValueField_StringData{
|
||||
StringData: "abc",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := task.validateSchema(schema)
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
collectionName := funcutil.GenRandomStr()
|
||||
task := createCollectionTask{
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package proxy
|
||||
package parameterutil
|
||||
|
||||
import (
|
||||
"fmt"
|
|
@ -1,4 +1,4 @@
|
|||
package proxy
|
||||
package parameterutil
|
||||
|
||||
import (
|
||||
"testing"
|
Loading…
Reference in New Issue