enhance: check field data type of input (#32777)

issue: #32769

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
pull/32792/head
Jiquan Long 2024-05-06 21:15:29 +08:00 committed by GitHub
parent c5191e9b28
commit 600db9d99e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 213 additions and 13 deletions

View File

@ -97,10 +97,22 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, schema *schemapb.Col
if err := v.checkJSONFieldData(field, fieldSchema); err != nil {
return err
}
case schemapb.DataType_Int8, schemapb.DataType_Int16:
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
if err := v.checkIntegerFieldData(field, fieldSchema); err != nil {
return err
}
case schemapb.DataType_Int64:
if err := v.checkLongFieldData(field, fieldSchema); err != nil {
return err
}
case schemapb.DataType_Float:
if err := v.checkFloatFieldData(field, fieldSchema); err != nil {
return err
}
case schemapb.DataType_Double:
if err := v.checkDoubleFieldData(field, fieldSchema); err != nil {
return err
}
case schemapb.DataType_Array:
if err := v.checkArrayFieldData(field, fieldSchema); err != nil {
return err
@ -366,7 +378,11 @@ func (v *validateUtil) checkBFloat16VectorFieldData(field *schemapb.FieldData, f
}
func (v *validateUtil) checkBinaryVectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
// TODO
bVecArray := field.GetVectors().GetBinaryVector()
if bVecArray == nil {
msg := fmt.Sprintf("binary float vector field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need bytes array", "got nil", msg)
}
return nil
}
@ -449,21 +465,57 @@ func (v *validateUtil) checkJSONFieldData(field *schemapb.FieldData, fieldSchema
}
func (v *validateUtil) checkIntegerFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
if !v.checkOverflow {
return nil
}
data := field.GetScalars().GetIntData().GetData()
if data == nil && fieldSchema.GetDefaultValue() == nil {
msg := fmt.Sprintf("field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need int array", "got nil", msg)
}
switch fieldSchema.GetDataType() {
case schemapb.DataType_Int8:
return verifyOverflowByRange(data, math.MinInt8, math.MaxInt8)
case schemapb.DataType_Int16:
return verifyOverflowByRange(data, math.MinInt16, math.MaxInt16)
if v.checkOverflow {
switch fieldSchema.GetDataType() {
case schemapb.DataType_Int8:
return verifyOverflowByRange(data, math.MinInt8, math.MaxInt8)
case schemapb.DataType_Int16:
return verifyOverflowByRange(data, math.MinInt16, math.MaxInt16)
}
}
return nil
}
func (v *validateUtil) checkLongFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
data := field.GetScalars().GetLongData().GetData()
if data == nil && fieldSchema.GetDefaultValue() == nil {
msg := fmt.Sprintf("field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need long int array", "got nil", msg)
}
return nil
}
func (v *validateUtil) checkFloatFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
data := field.GetScalars().GetFloatData().GetData()
if data == nil && fieldSchema.GetDefaultValue() == nil {
msg := fmt.Sprintf("field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need float32 array", "got nil", msg)
}
if v.checkNAN {
return typeutil.VerifyFloats32(data)
}
return nil
}
func (v *validateUtil) checkDoubleFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
data := field.GetScalars().GetDoubleData().GetData()
if data == nil && fieldSchema.GetDefaultValue() == nil {
msg := fmt.Sprintf("field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need float64(double) array", "got nil", msg)
}
if v.checkNAN {
return typeutil.VerifyFloats64(data)
}
return nil

View File

@ -187,7 +187,16 @@ func Test_validateUtil_checkVarCharFieldData(t *testing.T) {
}
func Test_validateUtil_checkBinaryVectorFieldData(t *testing.T) {
assert.NoError(t, newValidateUtil().checkBinaryVectorFieldData(nil, nil))
v := newValidateUtil()
assert.Error(t, v.checkBinaryVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil))
assert.NoError(t, v.checkBinaryVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: 128,
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: []byte(strings.Repeat("1", 128)),
},
},
}}, nil))
}
func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) {
@ -2539,6 +2548,45 @@ func Test_validateUtil_Validate(t *testing.T) {
},
},
},
{
FieldName: "test6",
Type: schemapb.DataType_Int64,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: []int64{(math.MinInt8) + 1, (math.MaxInt8) - 1},
},
},
},
},
},
{
FieldName: "test7",
Type: schemapb.DataType_Float,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_FloatData{
FloatData: &schemapb.FloatArray{
Data: generateFloat32Array(2),
},
},
},
},
},
{
FieldName: "test8",
Type: schemapb.DataType_Double,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_DoubleData{
DoubleData: &schemapb.DoubleArray{
Data: generateFloat64Array(2),
},
},
},
},
},
}
schema := &schemapb.CollectionSchema{
@ -2662,6 +2710,21 @@ func Test_validateUtil_Validate(t *testing.T) {
},
},
},
{
Name: "test6",
FieldID: 112,
DataType: schemapb.DataType_Int64,
},
{
Name: "test7",
FieldID: 113,
DataType: schemapb.DataType_Float,
},
{
Name: "test8",
FieldID: 114,
DataType: schemapb.DataType_Double,
},
},
}
@ -3599,7 +3662,16 @@ func Test_verifyOverflowByRange(t *testing.T) {
func Test_validateUtil_checkIntegerFieldData(t *testing.T) {
t.Run("no check", func(t *testing.T) {
v := newValidateUtil()
assert.NoError(t, v.checkIntegerFieldData(nil, nil))
assert.Error(t, v.checkIntegerFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Vectors{}}, nil))
assert.NoError(t, v.checkIntegerFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: []int32{1, 2, 3, 4},
},
},
},
}}, nil))
})
t.Run("tiny int, type mismatch", func(t *testing.T) {
@ -3811,3 +3883,79 @@ func Test_validateUtil_checkJSONData(t *testing.T) {
assert.Error(t, err)
})
}
func Test_validateUtil_checkLongFieldData(t *testing.T) {
v := newValidateUtil()
assert.Error(t, v.checkLongFieldData(&schemapb.FieldData{
Field: &schemapb.FieldData_Vectors{},
}, nil))
assert.NoError(t, v.checkLongFieldData(&schemapb.FieldData{
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: []int64{1, 2, 3, 4},
},
},
},
},
}, nil))
}
func Test_validateUtil_checkFloatFieldData(t *testing.T) {
v := newValidateUtil(withNANCheck())
assert.Error(t, v.checkFloatFieldData(&schemapb.FieldData{
Field: &schemapb.FieldData_Vectors{},
}, nil))
assert.NoError(t, v.checkFloatFieldData(&schemapb.FieldData{
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_FloatData{
FloatData: &schemapb.FloatArray{
Data: []float32{1, 2, 3, 4},
},
},
},
},
}, nil))
assert.Error(t, v.checkFloatFieldData(&schemapb.FieldData{
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_FloatData{
FloatData: &schemapb.FloatArray{
Data: []float32{float32(math.NaN())},
},
},
},
},
}, nil))
}
func Test_validateUtil_checkDoubleFieldData(t *testing.T) {
v := newValidateUtil(withNANCheck())
assert.Error(t, v.checkDoubleFieldData(&schemapb.FieldData{
Field: &schemapb.FieldData_Vectors{},
}, nil))
assert.NoError(t, v.checkDoubleFieldData(&schemapb.FieldData{
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_DoubleData{
DoubleData: &schemapb.DoubleArray{
Data: []float64{1, 2, 3, 4},
},
},
},
},
}, nil))
assert.Error(t, v.checkDoubleFieldData(&schemapb.FieldData{
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_DoubleData{
DoubleData: &schemapb.DoubleArray{
Data: []float64{math.NaN()},
},
},
},
},
}, nil))
}