From 600db9d99e19e1d399aaa2e26766f53a4ff489b6 Mon Sep 17 00:00:00 2001 From: Jiquan Long Date: Mon, 6 May 2024 21:15:29 +0800 Subject: [PATCH] enhance: check field data type of input (#32777) issue: #32769 Signed-off-by: longjiquan --- internal/proxy/validate_util.go | 74 +++++++++++-- internal/proxy/validate_util_test.go | 152 ++++++++++++++++++++++++++- 2 files changed, 213 insertions(+), 13 deletions(-) diff --git a/internal/proxy/validate_util.go b/internal/proxy/validate_util.go index 15b371e149..f250dbce92 100644 --- a/internal/proxy/validate_util.go +++ b/internal/proxy/validate_util.go @@ -97,10 +97,22 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, schema *schemapb.Col if err := v.checkJSONFieldData(field, fieldSchema); err != nil { return err } - case schemapb.DataType_Int8, schemapb.DataType_Int16: + case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32: if err := v.checkIntegerFieldData(field, fieldSchema); err != nil { return err } + case schemapb.DataType_Int64: + if err := v.checkLongFieldData(field, fieldSchema); err != nil { + return err + } + case schemapb.DataType_Float: + if err := v.checkFloatFieldData(field, fieldSchema); err != nil { + return err + } + case schemapb.DataType_Double: + if err := v.checkDoubleFieldData(field, fieldSchema); err != nil { + return err + } case schemapb.DataType_Array: if err := v.checkArrayFieldData(field, fieldSchema); err != nil { return err @@ -366,7 +378,11 @@ func (v *validateUtil) checkBFloat16VectorFieldData(field *schemapb.FieldData, f } func (v *validateUtil) checkBinaryVectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { - // TODO + bVecArray := field.GetVectors().GetBinaryVector() + if bVecArray == nil { + msg := fmt.Sprintf("binary float vector field '%v' is illegal, array type mismatch", field.GetFieldName()) + return merr.WrapErrParameterInvalid("need bytes array", "got nil", msg) + } return nil } @@ -449,21 +465,57 @@ func (v *validateUtil) checkJSONFieldData(field *schemapb.FieldData, fieldSchema } func (v *validateUtil) checkIntegerFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { - if !v.checkOverflow { - return nil - } - data := field.GetScalars().GetIntData().GetData() if data == nil && fieldSchema.GetDefaultValue() == nil { msg := fmt.Sprintf("field '%v' is illegal, array type mismatch", field.GetFieldName()) return merr.WrapErrParameterInvalid("need int array", "got nil", msg) } - switch fieldSchema.GetDataType() { - case schemapb.DataType_Int8: - return verifyOverflowByRange(data, math.MinInt8, math.MaxInt8) - case schemapb.DataType_Int16: - return verifyOverflowByRange(data, math.MinInt16, math.MaxInt16) + if v.checkOverflow { + switch fieldSchema.GetDataType() { + case schemapb.DataType_Int8: + return verifyOverflowByRange(data, math.MinInt8, math.MaxInt8) + case schemapb.DataType_Int16: + return verifyOverflowByRange(data, math.MinInt16, math.MaxInt16) + } + } + + return nil +} + +func (v *validateUtil) checkLongFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { + data := field.GetScalars().GetLongData().GetData() + if data == nil && fieldSchema.GetDefaultValue() == nil { + msg := fmt.Sprintf("field '%v' is illegal, array type mismatch", field.GetFieldName()) + return merr.WrapErrParameterInvalid("need long int array", "got nil", msg) + } + + return nil +} + +func (v *validateUtil) checkFloatFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { + data := field.GetScalars().GetFloatData().GetData() + if data == nil && fieldSchema.GetDefaultValue() == nil { + msg := fmt.Sprintf("field '%v' is illegal, array type mismatch", field.GetFieldName()) + return merr.WrapErrParameterInvalid("need float32 array", "got nil", msg) + } + + if v.checkNAN { + return typeutil.VerifyFloats32(data) + } + + return nil +} + +func (v *validateUtil) checkDoubleFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { + data := field.GetScalars().GetDoubleData().GetData() + if data == nil && fieldSchema.GetDefaultValue() == nil { + msg := fmt.Sprintf("field '%v' is illegal, array type mismatch", field.GetFieldName()) + return merr.WrapErrParameterInvalid("need float64(double) array", "got nil", msg) + } + + if v.checkNAN { + return typeutil.VerifyFloats64(data) } return nil diff --git a/internal/proxy/validate_util_test.go b/internal/proxy/validate_util_test.go index 8a4329a0c2..e1363444af 100644 --- a/internal/proxy/validate_util_test.go +++ b/internal/proxy/validate_util_test.go @@ -187,7 +187,16 @@ func Test_validateUtil_checkVarCharFieldData(t *testing.T) { } func Test_validateUtil_checkBinaryVectorFieldData(t *testing.T) { - assert.NoError(t, newValidateUtil().checkBinaryVectorFieldData(nil, nil)) + v := newValidateUtil() + assert.Error(t, v.checkBinaryVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil)) + assert.NoError(t, v.checkBinaryVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: 128, + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: []byte(strings.Repeat("1", 128)), + }, + }, + }}, nil)) } func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) { @@ -2539,6 +2548,45 @@ func Test_validateUtil_Validate(t *testing.T) { }, }, }, + { + FieldName: "test6", + Type: schemapb.DataType_Int64, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{(math.MinInt8) + 1, (math.MaxInt8) - 1}, + }, + }, + }, + }, + }, + { + FieldName: "test7", + Type: schemapb.DataType_Float, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: generateFloat32Array(2), + }, + }, + }, + }, + }, + { + FieldName: "test8", + Type: schemapb.DataType_Double, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: generateFloat64Array(2), + }, + }, + }, + }, + }, } schema := &schemapb.CollectionSchema{ @@ -2662,6 +2710,21 @@ func Test_validateUtil_Validate(t *testing.T) { }, }, }, + { + Name: "test6", + FieldID: 112, + DataType: schemapb.DataType_Int64, + }, + { + Name: "test7", + FieldID: 113, + DataType: schemapb.DataType_Float, + }, + { + Name: "test8", + FieldID: 114, + DataType: schemapb.DataType_Double, + }, }, } @@ -3599,7 +3662,16 @@ func Test_verifyOverflowByRange(t *testing.T) { func Test_validateUtil_checkIntegerFieldData(t *testing.T) { t.Run("no check", func(t *testing.T) { v := newValidateUtil() - assert.NoError(t, v.checkIntegerFieldData(nil, nil)) + assert.Error(t, v.checkIntegerFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Vectors{}}, nil)) + assert.NoError(t, v.checkIntegerFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{1, 2, 3, 4}, + }, + }, + }, + }}, nil)) }) t.Run("tiny int, type mismatch", func(t *testing.T) { @@ -3811,3 +3883,79 @@ func Test_validateUtil_checkJSONData(t *testing.T) { assert.Error(t, err) }) } + +func Test_validateUtil_checkLongFieldData(t *testing.T) { + v := newValidateUtil() + assert.Error(t, v.checkLongFieldData(&schemapb.FieldData{ + Field: &schemapb.FieldData_Vectors{}, + }, nil)) + assert.NoError(t, v.checkLongFieldData(&schemapb.FieldData{ + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4}, + }, + }, + }, + }, + }, nil)) +} + +func Test_validateUtil_checkFloatFieldData(t *testing.T) { + v := newValidateUtil(withNANCheck()) + assert.Error(t, v.checkFloatFieldData(&schemapb.FieldData{ + Field: &schemapb.FieldData_Vectors{}, + }, nil)) + assert.NoError(t, v.checkFloatFieldData(&schemapb.FieldData{ + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: []float32{1, 2, 3, 4}, + }, + }, + }, + }, + }, nil)) + assert.Error(t, v.checkFloatFieldData(&schemapb.FieldData{ + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: []float32{float32(math.NaN())}, + }, + }, + }, + }, + }, nil)) +} + +func Test_validateUtil_checkDoubleFieldData(t *testing.T) { + v := newValidateUtil(withNANCheck()) + assert.Error(t, v.checkDoubleFieldData(&schemapb.FieldData{ + Field: &schemapb.FieldData_Vectors{}, + }, nil)) + assert.NoError(t, v.checkDoubleFieldData(&schemapb.FieldData{ + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: []float64{1, 2, 3, 4}, + }, + }, + }, + }, + }, nil)) + assert.Error(t, v.checkDoubleFieldData(&schemapb.FieldData{ + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: []float64{math.NaN()}, + }, + }, + }, + }, + }, nil)) +}