Check overflow for inserted integer (#24142)

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
pull/24152/head
Jiquan Long 2023-05-16 20:19:22 +08:00 committed by GitHub
parent c31af5edfb
commit 6965495b9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 267 additions and 7 deletions

View File

@ -169,7 +169,8 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
return err
}
if err := newValidateUtil(withNANCheck()).Validate(it.insertMsg.GetFieldsData(), schema, it.insertMsg.NRows()); err != nil {
if err := newValidateUtil(withNANCheck(), withOverflowCheck()).
Validate(it.insertMsg.GetFieldsData(), schema, it.insertMsg.NRows()); err != nil {
return err
}

View File

@ -194,7 +194,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
return err
}
if err := newValidateUtil(withNANCheck()).
if err := newValidateUtil(withNANCheck(), withOverflowCheck()).
Validate(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema, it.upsertMsg.InsertMsg.NRows()); err != nil {
return err
}

View File

@ -2,6 +2,7 @@ package proxy
import (
"fmt"
"math"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/pkg/log"
@ -13,8 +14,9 @@ import (
)
type validateUtil struct {
checkNAN bool
checkMaxLen bool
checkNAN bool
checkMaxLen bool
checkOverflow bool
}
type validateOption func(*validateUtil)
@ -31,6 +33,12 @@ func withMaxLenCheck() validateOption {
}
}
func withOverflowCheck() validateOption {
return func(v *validateUtil) {
v.checkOverflow = true
}
}
func (v *validateUtil) apply(opts ...validateOption) {
for _, opt := range opts {
opt(v)
@ -75,6 +83,10 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, schema *schemapb.Col
if err := v.checkJSONFieldData(field, fieldSchema); err != nil {
return err
}
case schemapb.DataType_Int8, schemapb.DataType_Int16:
if err := v.checkIntegerFieldData(field, fieldSchema); err != nil {
return err
}
default:
}
}
@ -273,6 +285,27 @@ func (v *validateUtil) checkJSONFieldData(field *schemapb.FieldData, fieldSchema
return nil
}
func (v *validateUtil) checkIntegerFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
if !v.checkOverflow {
return nil
}
data := field.GetScalars().GetIntData().GetData()
if data == nil {
msg := fmt.Sprintf("field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need int array", "got nil", msg)
}
switch fieldSchema.GetDataType() {
case schemapb.DataType_Int8:
return verifyOverflowByRange(data, math.MinInt8, math.MaxInt8)
case schemapb.DataType_Int16:
return verifyOverflowByRange(data, math.MinInt16, math.MaxInt16)
}
return nil
}
func verifyLengthPerRow[E interface{ ~string | ~[]byte }](strArr []E, maxLength int64) error {
for i, s := range strArr {
if int64(len(s)) > maxLength {
@ -284,10 +317,21 @@ func verifyLengthPerRow[E interface{ ~string | ~[]byte }](strArr []E, maxLength
return nil
}
func verifyOverflowByRange(arr []int32, lb int64, ub int64) error {
for idx, e := range arr {
if lb > int64(e) || ub < int64(e) {
msg := fmt.Sprintf("the %dth element (%d) out of range: [%d, %d]", idx, e, lb, ub)
return merr.WrapErrParameterInvalid("integer doesn't overflow", "out of range", msg)
}
}
return nil
}
func newValidateUtil(opts ...validateOption) *validateUtil {
v := &validateUtil{
checkNAN: true,
checkMaxLen: false,
checkNAN: true,
checkMaxLen: false,
checkOverflow: false,
}
v.apply(opts...)

View File

@ -853,6 +853,39 @@ func Test_validateUtil_Validate(t *testing.T) {
assert.Error(t, err)
})
t.Run("has overflow", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldName: "test1",
Type: schemapb.DataType_Int8,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: []int32{int32(math.MinInt8) - 1, int32(math.MaxInt8) + 1},
},
},
},
},
},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: "test1",
FieldID: 101,
DataType: schemapb.DataType_Int8,
},
},
}
v := newValidateUtil(withOverflowCheck())
err := v.Validate(data, schema, 2)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
data := []*schemapb.FieldData{
{
@ -905,6 +938,19 @@ func Test_validateUtil_Validate(t *testing.T) {
},
},
},
{
FieldName: "test5",
Type: schemapb.DataType_Int8,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: []int32{int32(math.MinInt8) + 1, int32(math.MaxInt8) - 1},
},
},
},
},
},
}
schema := &schemapb.CollectionSchema{
@ -947,10 +993,15 @@ func Test_validateUtil_Validate(t *testing.T) {
FieldID: 104,
DataType: schemapb.DataType_JSON,
},
{
Name: "test5",
FieldID: 105,
DataType: schemapb.DataType_Int8,
},
},
}
v := newValidateUtil(withNANCheck(), withMaxLenCheck())
v := newValidateUtil(withNANCheck(), withMaxLenCheck(), withOverflowCheck())
err := v.Validate(data, schema, 2)
@ -1835,3 +1886,167 @@ func Test_validateUtil_fillWithDefaultValue(t *testing.T) {
})
}
func Test_verifyOverflowByRange(t *testing.T) {
var err error
err = verifyOverflowByRange(
[]int32{int32(math.MinInt8 - 1)},
math.MinInt8,
math.MaxInt8,
)
assert.Error(t, err)
err = verifyOverflowByRange(
[]int32{int32(math.MaxInt8 + 1)},
math.MinInt8,
math.MaxInt8,
)
assert.Error(t, err)
err = verifyOverflowByRange(
[]int32{int32(math.MinInt8 - 1), int32(math.MaxInt8 + 1)},
math.MinInt8,
math.MaxInt8,
)
assert.Error(t, err)
err = verifyOverflowByRange(
[]int32{int32(math.MaxInt8 + 1), int32(math.MinInt8 - 1)},
math.MinInt8,
math.MaxInt8,
)
assert.Error(t, err)
err = verifyOverflowByRange(
[]int32{1, 2, 3, int32(math.MinInt8 - 1), int32(math.MaxInt8 + 1)},
math.MinInt8,
math.MaxInt8,
)
assert.Error(t, err)
err = verifyOverflowByRange(
[]int32{1, 2, 3, int32(math.MinInt8 + 1), int32(math.MaxInt8 - 1)},
math.MinInt8,
math.MaxInt8,
)
assert.NoError(t, err)
}
func Test_validateUtil_checkIntegerFieldData(t *testing.T) {
t.Run("no check", func(t *testing.T) {
v := newValidateUtil()
assert.NoError(t, v.checkIntegerFieldData(nil, nil))
})
t.Run("tiny int, type mismatch", func(t *testing.T) {
v := newValidateUtil(withOverflowCheck())
f := &schemapb.FieldSchema{
DataType: schemapb.DataType_Int8,
}
data := &schemapb.FieldData{
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{},
},
},
}
err := v.checkIntegerFieldData(data, f)
assert.Error(t, err)
})
t.Run("tiny int, overflow", func(t *testing.T) {
v := newValidateUtil(withOverflowCheck())
f := &schemapb.FieldSchema{
DataType: schemapb.DataType_Int8,
}
data := &schemapb.FieldData{
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: []int32{int32(math.MinInt8 - 1)},
},
},
},
},
}
err := v.checkIntegerFieldData(data, f)
assert.Error(t, err)
})
t.Run("tiny int, normal case", func(t *testing.T) {
v := newValidateUtil(withOverflowCheck())
f := &schemapb.FieldSchema{
DataType: schemapb.DataType_Int8,
}
data := &schemapb.FieldData{
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: []int32{int32(math.MinInt8 + 1), int32(math.MaxInt8 - 1)},
},
},
},
},
}
err := v.checkIntegerFieldData(data, f)
assert.NoError(t, err)
})
t.Run("small int, overflow", func(t *testing.T) {
v := newValidateUtil(withOverflowCheck())
f := &schemapb.FieldSchema{
DataType: schemapb.DataType_Int16,
}
data := &schemapb.FieldData{
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: []int32{int32(math.MinInt16 - 1)},
},
},
},
},
}
err := v.checkIntegerFieldData(data, f)
assert.Error(t, err)
})
t.Run("small int, normal case", func(t *testing.T) {
v := newValidateUtil(withOverflowCheck())
f := &schemapb.FieldSchema{
DataType: schemapb.DataType_Int16,
}
data := &schemapb.FieldData{
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: []int32{int32(math.MinInt16 + 1), int32(math.MaxInt16 - 1)},
},
},
},
},
}
err := v.checkIntegerFieldData(data, f)
assert.NoError(t, err)
})
}