Check dimension of inserted records (#22819)

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
pull/22832/head
Jiquan Long 2023-03-17 17:33:58 +08:00 committed by GitHub
parent 4a90490a67
commit dff15c3488
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 5 deletions

View File

@ -82,6 +82,9 @@ func (it *insertTask) OnEnqueue() error {
}
func (it *insertTask) checkVectorFieldData() error {
// error won't happen here.
helper, _ := typeutil.CreateSchemaHelper(it.schema)
fields := it.insertMsg.GetFieldsData()
for _, field := range fields {
if field.GetType() != schemapb.DataType_FloatVector {
@ -90,14 +93,21 @@ func (it *insertTask) checkVectorFieldData() error {
vectorField := field.GetVectors()
if vectorField == nil || vectorField.GetFloatVector() == nil {
log.Error("float vector field is illegal, array type mismatch", zap.String("field name", field.GetFieldName()))
return fmt.Errorf("float vector field '%v' is illegal, array type mismatch", field.GetFieldName())
}
// error won't happen here.
f, _ := helper.GetFieldFromName(field.GetFieldName())
dim, _ := typeutil.GetDim(f)
floatArray := vectorField.GetFloatVector()
err := typeutil.VerifyFloats32(floatArray.GetData())
if err != nil {
log.Error("float vector field data is illegal", zap.String("field name", field.GetFieldName()), zap.Error(err))
// TODO: `NumRows` passed by client may be not trustable.
if uint64(len(floatArray.GetData())) != uint64(dim)*it.insertMsg.GetNumRows() {
return fmt.Errorf("length of inserted vector (%d) not match dim (%d)", len(floatArray.GetData()), dim)
}
if err := typeutil.VerifyFloats32(floatArray.GetData()); err != nil {
return fmt.Errorf("float vector field data is illegal, error: %w", err)
}
}

View File

@ -2,6 +2,7 @@ package proxy
import (
"math"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
@ -245,6 +246,7 @@ func TestInsertTask_CheckVectorFieldData(t *testing.T) {
IsPrimaryKey: false,
AutoID: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: strconv.Itoa(dim)}},
},
},
},
@ -300,4 +302,11 @@ func TestInsertTask_CheckVectorFieldData(t *testing.T) {
}
err = task.checkVectorFieldData()
assert.Error(t, err)
// vector dim not match
task.insertMsg.FieldsData = []*schemapb.FieldData{
newFloatVectorFieldData(fieldName, numRows, dim+1),
}
err = task.checkVectorFieldData()
assert.Error(t, err)
}