Check dimension of inserted records (#22819)

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
pull/22832/head
Jiquan Long 2023-03-17 17:33:58 +08:00 committed by GitHub
parent 4a90490a67
commit dff15c3488
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 5 deletions

View File

@ -31,4 +31,4 @@ IndexBuilderSetSimdType(const char* value) {
void void
IndexBuilderInitGPU(const int32_t gpu_id, const int32_t res_num) { IndexBuilderInitGPU(const int32_t gpu_id, const int32_t res_num) {
milvus::config::KnowhereInitGPU(gpu_id, res_num); milvus::config::KnowhereInitGPU(gpu_id, res_num);
} }

View File

@ -82,6 +82,9 @@ func (it *insertTask) OnEnqueue() error {
} }
func (it *insertTask) checkVectorFieldData() error { func (it *insertTask) checkVectorFieldData() error {
// error won't happen here.
helper, _ := typeutil.CreateSchemaHelper(it.schema)
fields := it.insertMsg.GetFieldsData() fields := it.insertMsg.GetFieldsData()
for _, field := range fields { for _, field := range fields {
if field.GetType() != schemapb.DataType_FloatVector { if field.GetType() != schemapb.DataType_FloatVector {
@ -90,14 +93,21 @@ func (it *insertTask) checkVectorFieldData() error {
vectorField := field.GetVectors() vectorField := field.GetVectors()
if vectorField == nil || vectorField.GetFloatVector() == nil { if vectorField == nil || vectorField.GetFloatVector() == nil {
log.Error("float vector field is illegal, array type mismatch", zap.String("field name", field.GetFieldName()))
return fmt.Errorf("float vector field '%v' is illegal, array type mismatch", field.GetFieldName()) return fmt.Errorf("float vector field '%v' is illegal, array type mismatch", field.GetFieldName())
} }
// error won't happen here.
f, _ := helper.GetFieldFromName(field.GetFieldName())
dim, _ := typeutil.GetDim(f)
floatArray := vectorField.GetFloatVector() floatArray := vectorField.GetFloatVector()
err := typeutil.VerifyFloats32(floatArray.GetData())
if err != nil { // TODO: `NumRows` passed by client may be not trustable.
log.Error("float vector field data is illegal", zap.String("field name", field.GetFieldName()), zap.Error(err)) if uint64(len(floatArray.GetData())) != uint64(dim)*it.insertMsg.GetNumRows() {
return fmt.Errorf("length of inserted vector (%d) not match dim (%d)", len(floatArray.GetData()), dim)
}
if err := typeutil.VerifyFloats32(floatArray.GetData()); err != nil {
return fmt.Errorf("float vector field data is illegal, error: %w", err) return fmt.Errorf("float vector field data is illegal, error: %w", err)
} }
} }

View File

@ -2,6 +2,7 @@ package proxy
import ( import (
"math" "math"
"strconv"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -245,6 +246,7 @@ func TestInsertTask_CheckVectorFieldData(t *testing.T) {
IsPrimaryKey: false, IsPrimaryKey: false,
AutoID: false, AutoID: false,
DataType: schemapb.DataType_FloatVector, DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: strconv.Itoa(dim)}},
}, },
}, },
}, },
@ -300,4 +302,11 @@ func TestInsertTask_CheckVectorFieldData(t *testing.T) {
} }
err = task.checkVectorFieldData() err = task.checkVectorFieldData()
assert.Error(t, err) assert.Error(t, err)
// vector dim not match
task.insertMsg.FieldsData = []*schemapb.FieldData{
newFloatVectorFieldData(fieldName, numRows, dim+1),
}
err = task.checkVectorFieldData()
assert.Error(t, err)
} }