Check if all columns aligned with same num_rows (#22968) (#22981)

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
pull/22987/head
Jiquan Long 2023-03-24 17:32:00 +08:00 committed by GitHub
parent 85bbe19f34
commit 9c3c29db6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 1204 additions and 137 deletions

View File

@ -0,0 +1,31 @@
package proxy
import (
"errors"
"fmt"
"strconv"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
)
// GetMaxLength returns the value stored under the "max_length" key of a
// string-typed field.
//
// The key is looked up in the combination of the field's index params and
// type params. An error is returned when the field is not of a string type,
// when the key is absent, or when its value is not a base-10 integer that
// fits in an int64. The sign and magnitude of the value are not validated
// here.
func GetMaxLength(field *schemapb.FieldSchema) (int64, error) {
	if !typeutil.IsStringType(field.GetDataType()) {
		msg := fmt.Sprintf("%s is not of string type", field.GetDataType())
		return 0, errors.New(msg)
	}

	// Cap the capacity of the index params before appending so that append
	// always copies into a fresh backing array. Appending directly to the
	// slice owned by the schema could write into its spare capacity, which
	// is a data race when the same schema is validated concurrently.
	indexParams := field.GetIndexParams()
	params := append(indexParams[:len(indexParams):len(indexParams)], field.GetTypeParams()...)

	h := typeutil.NewKvPairs(params)
	maxLengthStr, err := h.Get("max_length")
	if err != nil {
		msg := "max length not found"
		return 0, errors.New(msg)
	}

	// Parse straight into 64 bits so the result does not depend on the
	// platform's int width.
	maxLength, err := strconv.ParseInt(maxLengthStr, 10, 64)
	if err != nil {
		msg := fmt.Sprintf("invalid max length: %s", maxLengthStr)
		return 0, errors.New(msg)
	}
	return maxLength, nil
}

View File

@ -0,0 +1,57 @@
package proxy
import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/stretchr/testify/assert"
)
// TestGetMaxLength covers the three failure modes of GetMaxLength (field is
// not a string, max_length missing, max_length not an integer) and the
// success path.
func TestGetMaxLength(t *testing.T) {
	// varCharWithMaxLength builds a VarChar field whose type params carry
	// the given raw max_length value.
	varCharWithMaxLength := func(raw string) *schemapb.FieldSchema {
		return &schemapb.FieldSchema{
			DataType: schemapb.DataType_VarChar,
			TypeParams: []*commonpb.KeyValuePair{
				{
					Key:   "max_length",
					Value: raw,
				},
			},
		}
	}

	failures := []struct {
		name  string
		field *schemapb.FieldSchema
	}{
		{
			name:  "not string type",
			field: &schemapb.FieldSchema{DataType: schemapb.DataType_Bool},
		},
		{
			name:  "max length not found",
			field: &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar},
		},
		{
			name:  "max length not int",
			field: varCharWithMaxLength("not_int_aha"),
		},
	}
	for _, tc := range failures {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			_, err := GetMaxLength(tc.field)
			assert.Error(t, err)
		})
	}

	t.Run("normal case", func(t *testing.T) {
		got, err := GetMaxLength(varCharWithMaxLength("100"))
		assert.NoError(t, err)
		assert.Equal(t, int64(100), got)
	})
}

View File

@ -148,40 +148,6 @@ func (it *insertTask) checkPrimaryFieldData() error {
return nil
}
// checkVectorFieldData validates every float vector column of the insert
// request and skips columns of any other type.
//
// For each float vector column it checks that the column really carries a
// float vector payload, that the number of floats equals dim * NumRows
// (NumRows being the value carried by the request), and that
// typeutil.VerifyFloats32 accepts the data. The first failure is returned.
func (it *insertTask) checkVectorFieldData() error {
// error won't happen here.
// NOTE(review): the error from CreateSchemaHelper is discarded; this
// presumably relies on the schema having been validated earlier — confirm.
helper, _ := typeutil.CreateSchemaHelper(it.schema)
fields := it.GetFieldsData()
for _, field := range fields {
if field.GetType() != schemapb.DataType_FloatVector {
continue
}
vectorField := field.GetVectors()
if vectorField == nil || vectorField.GetFloatVector() == nil {
return fmt.Errorf("float vector field '%v' is illegal, array type mismatch", field.GetFieldName())
}
// error won't happen here.
// NOTE(review): both the field lookup error and the dim lookup error are
// discarded below.
f, _ := helper.GetFieldFromName(field.GetFieldName())
dim, _ := typeutil.GetDim(f)
floatArray := vectorField.GetFloatVector()
// TODO: `NumRows` passed by client may be not trustable.
// The flat float array must hold exactly dim values per row.
if uint64(len(floatArray.GetData())) != uint64(dim)*it.BaseInsertTask.GetNumRows() {
return fmt.Errorf("length of inserted vector (%d) not match dim (%d)", len(floatArray.GetData()), dim)
}
if err := typeutil.VerifyFloats32(floatArray.GetData()); err != nil {
return fmt.Errorf("float vector field data is illegal, error: %w", err)
}
}
return nil
}
func (it *insertTask) PreExecute(ctx context.Context) error {
sp, ctx := trace.StartSpanFromContextWithOperationName(it.ctx, "Proxy-Insert-PreExecute")
defer sp.Finish()
@ -263,10 +229,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
return err
}
// check vector field data
err = it.checkVectorFieldData()
if err != nil {
log.Error("vector field data is illegal", zap.Int64("msgID", it.Base.MsgID), zap.String("collection name", collectionName), zap.Error(err))
if err := newValidateUtil(withNANCheck()).Validate(it.GetFieldsData(), it.schema, it.NRows()); err != nil {
return err
}

View File

@ -1,8 +1,6 @@
package proxy
import (
"math"
"strconv"
"testing"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
@ -347,92 +345,3 @@ func TestInsertTask_CheckAligned(t *testing.T) {
err = case2.CheckAligned()
assert.NoError(t, err)
}
// TestInsertTask_CheckVectorFieldData runs insertTask.checkVectorFieldData
// against one well-formed float vector column and four malformed ones: a nil
// vector payload, data containing NaN, data containing +Inf, and a column
// generated with the wrong dimension.
func TestInsertTask_CheckVectorFieldData(t *testing.T) {
fieldName := "embeddings"
numRows := 10
dim := 32
// The task declares numRows rows and a schema with a single float vector
// field of dimension dim.
task := insertTask{
BaseInsertTask: BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
Version: internalpb.InsertDataVersion_ColumnBased,
NumRows: uint64(numRows),
},
},
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_CheckVectorFieldData",
Description: "TestInsertTask_CheckVectorFieldData",
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: fieldName,
IsPrimaryKey: false,
AutoID: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: strconv.Itoa(dim)}},
},
},
},
}
// success case
task.FieldsData = []*schemapb.FieldData{
newFloatVectorFieldData(fieldName, numRows, dim),
}
err := task.checkVectorFieldData()
assert.NoError(t, err)
// field is nil
task.FieldsData = []*schemapb.FieldData{
{
Type: schemapb.DataType_FloatVector,
FieldName: fieldName,
Field: &schemapb.FieldData_Vectors{
Vectors: nil,
},
},
}
err = task.checkVectorFieldData()
assert.Error(t, err)
// vector data is not a number
values := generateFloatVectors(numRows, dim)
values[5] = float32(math.NaN())
task.FieldsData[0].Field = &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: values,
},
},
},
}
err = task.checkVectorFieldData()
assert.Error(t, err)
// vector data is infinity
// The same values slice is reused, so index 5 is overwritten from NaN to +Inf.
values[5] = float32(math.Inf(1))
task.FieldsData[0].Field = &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: values,
},
},
},
}
err = task.checkVectorFieldData()
assert.Error(t, err)
// vector dim not match
task.FieldsData = []*schemapb.FieldData{
newFloatVectorFieldData(fieldName, numRows, dim+1),
}
err = task.checkVectorFieldData()
assert.Error(t, err)
}

View File

@ -0,0 +1,195 @@
package proxy
import (
"errors"
"fmt"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
// validateUtil validates column-based field data against a collection
// schema. Its flags select which optional per-field checks are performed.
type validateUtil struct {
// checkNAN makes checkFloatVectorFieldData pass float vector data through
// typeutil.VerifyFloats32.
checkNAN bool
// checkMaxLen makes checkVarCharFieldData compare every string's length
// against the field's max_length.
checkMaxLen bool
}
// validateOption is a functional option that mutates a validateUtil; options
// are applied by newValidateUtil.
type validateOption func(*validateUtil)
// withNANCheck returns an option that turns on the checkNAN flag.
func withNANCheck() validateOption {
	enable := func(target *validateUtil) {
		target.checkNAN = true
	}
	return enable
}
// withMaxLenCheck returns an option that turns on the checkMaxLen flag.
func withMaxLenCheck() validateOption {
	enable := func(target *validateUtil) {
		target.checkMaxLen = true
	}
	return enable
}
// apply runs every option against v, in the order given.
func (v *validateUtil) apply(opts ...validateOption) {
	for i := range opts {
		opts[i](v)
	}
}
// Validate checks the given columns against schema.
//
// It first verifies that every column holds exactly numRows rows, then runs
// the type-specific check for each float vector, binary vector and varchar
// column, selected by the data type declared in the schema. Columns of other
// types get no further checks. The first error encountered is returned.
func (v *validateUtil) Validate(data []*schemapb.FieldData, schema *schemapb.CollectionSchema, numRows uint64) error {
	helper, err := typeutil.CreateSchemaHelper(schema)
	if err != nil {
		return err
	}

	if err = v.checkAligned(data, helper, numRows); err != nil {
		return err
	}

	for _, column := range data {
		columnSchema, err := helper.GetFieldFromName(column.GetFieldName())
		if err != nil {
			return err
		}

		// Pick the per-type checker; types without one are skipped.
		var check func(*schemapb.FieldData, *schemapb.FieldSchema) error
		switch columnSchema.GetDataType() {
		case schemapb.DataType_FloatVector:
			check = v.checkFloatVectorFieldData
		case schemapb.DataType_BinaryVector:
			check = v.checkBinaryVectorFieldData
		case schemapb.DataType_VarChar:
			check = v.checkVarCharFieldData
		default:
			continue
		}

		if err := check(column, columnSchema); err != nil {
			return err
		}
	}

	return nil
}
// checkAligned verifies that every column in data holds exactly numRows rows.
//
// The branch is chosen by the type carried on the column itself. For float
// and binary vector columns the row count is derived from the payload length
// and the dimension declared in the schema; for every other type it is taken
// from funcutil.GetNumRowOfFieldData. Lookup and counting errors are returned
// unchanged, and a mismatch produces an error naming the field and both
// counts.
func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil.SchemaHelper, numRows uint64) error {
	// dimOf resolves the dimension the schema declares for the named field.
	dimOf := func(name string) (int64, error) {
		f, err := schema.GetFieldFromName(name)
		if err != nil {
			return 0, err
		}
		return typeutil.GetDim(f)
	}

	for _, column := range data {
		var (
			rows uint64
			err  error
		)
		name := column.GetFieldName()

		switch column.GetType() {
		case schemapb.DataType_FloatVector:
			var dim int64
			if dim, err = dimOf(name); err == nil {
				rows, err = funcutil.GetNumRowsOfFloatVectorField(column.GetVectors().GetFloatVector().GetData(), dim)
			}
		case schemapb.DataType_BinaryVector:
			var dim int64
			if dim, err = dimOf(name); err == nil {
				rows, err = funcutil.GetNumRowsOfBinaryVectorField(column.GetVectors().GetBinaryVector(), dim)
			}
		default:
			rows, err = funcutil.GetNumRowOfFieldData(column)
		}

		if err != nil {
			return err
		}
		if rows != numRows {
			return fmt.Errorf("the num_rows (%d) of field (%s) is not equal to passed num_rows (%d)", rows, name, numRows)
		}
	}

	return nil
}
// checkFloatVectorFieldData validates the payload of a float vector column.
//
// It fails when the column exposes no float data (nil slice). When checkNAN
// is set, it returns the result of typeutil.VerifyFloats32 on the data.
// fieldSchema is not consulted.
func (v *validateUtil) checkFloatVectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
	values := field.GetVectors().GetFloatVector().GetData()
	if values == nil {
		return errors.New(fmt.Sprintf("float vector field '%v' is illegal, array type mismatch", field.GetFieldName()))
	}

	if !v.checkNAN {
		return nil
	}
	return typeutil.VerifyFloats32(values)
}
// checkBinaryVectorFieldData is a placeholder for binary vector validation.
// It currently performs no checks and always returns nil; both arguments are
// ignored.
func (v *validateUtil) checkBinaryVectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
// TODO: no binary-vector-specific validation is implemented yet.
return nil
}
// checkVarCharFieldData validates the payload of a varchar column.
//
// It fails when the column exposes no string data (nil slice). When
// checkMaxLen is set, it reads the field's max_length via GetMaxLength and
// verifies every row against it; otherwise no further checks are made.
func (v *validateUtil) checkVarCharFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
	rows := field.GetScalars().GetStringData().GetData()

	switch {
	case rows == nil:
		return errors.New(fmt.Sprintf("varchar field '%v' is illegal, array type mismatch", field.GetFieldName()))
	case !v.checkMaxLen:
		return nil
	}

	limit, err := GetMaxLength(fieldSchema)
	if err != nil {
		return err
	}
	return verifyLengthPerRow(rows, limit)
}
// verifyLengthPerRow returns an error describing the first string in strArr
// whose byte length is greater than maxLength. It returns nil when no string
// exceeds the limit, including when strArr is nil or empty.
func verifyLengthPerRow(strArr []string, maxLength int64) error {
	for idx := 0; idx < len(strArr); idx++ {
		size := len(strArr[idx])
		if int64(size) <= maxLength {
			continue
		}
		return errors.New(fmt.Sprintf("the length (%d) of %dth string exceeds max length (%d)", size, idx, maxLength))
	}
	return nil
}
// newValidateUtil builds a validateUtil with the NaN check enabled and the
// max-length check disabled, then applies opts in order on top of those
// defaults.
func newValidateUtil(opts ...validateOption) *validateUtil {
	util := new(validateUtil)
	util.checkNAN = true
	util.checkMaxLen = false

	util.apply(opts...)
	return util
}

View File

@ -0,0 +1,912 @@
package proxy
import (
"math"
"testing"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/stretchr/testify/assert"
)
// Test_verifyLengthPerRow checks that inputs within a 16-byte limit pass and
// that any row longer than the limit fails, wherever it sits in the slice.
func Test_verifyLengthPerRow(t *testing.T) {
	const limit int64 = 16

	accepted := [][]string{
		nil,
		{"111111", "22222"},
	}
	for _, rows := range accepted {
		assert.NoError(t, verifyLengthPerRow(rows, limit))
	}

	rejected := [][]string{
		{"11111111111111111"},
		{"11111111111111111", "222"},
		{"11111", "22222222222222222"},
	}
	for _, rows := range rejected {
		assert.Error(t, verifyLengthPerRow(rows, limit))
	}
}
// Test_validateUtil_checkVarCharFieldData exercises checkVarCharFieldData
// with a column lacking string data, a schema lacking max_length, rows that
// exceed or fit the limit, and the max-length check left disabled.
func Test_validateUtil_checkVarCharFieldData(t *testing.T) {
	// twoStrings builds a fresh varchar column holding two 3-byte strings.
	twoStrings := func() *schemapb.FieldData {
		return &schemapb.FieldData{
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_StringData{
						StringData: &schemapb.StringArray{
							Data: []string{"111", "222"},
						},
					},
				},
			},
		}
	}
	// varCharSchema builds a VarChar field schema with the given max_length.
	varCharSchema := func(maxLength string) *schemapb.FieldSchema {
		return &schemapb.FieldSchema{
			DataType: schemapb.DataType_VarChar,
			TypeParams: []*commonpb.KeyValuePair{
				{
					Key:   "max_length",
					Value: maxLength,
				},
			},
		}
	}

	t.Run("type mismatch", func(t *testing.T) {
		column := &schemapb.FieldData{}
		util := newValidateUtil()
		assert.Error(t, util.checkVarCharFieldData(column, nil))
	})

	t.Run("max length not found", func(t *testing.T) {
		column := twoStrings()
		noLimit := &schemapb.FieldSchema{
			DataType: schemapb.DataType_VarChar,
		}
		util := newValidateUtil(withMaxLenCheck())
		assert.Error(t, util.checkVarCharFieldData(column, noLimit))
	})

	t.Run("length exceeds", func(t *testing.T) {
		column := twoStrings()
		util := newValidateUtil(withMaxLenCheck())
		assert.Error(t, util.checkVarCharFieldData(column, varCharSchema("2")))
	})

	t.Run("normal case", func(t *testing.T) {
		column := twoStrings()
		util := newValidateUtil(withMaxLenCheck())
		assert.NoError(t, util.checkVarCharFieldData(column, varCharSchema("4")))
	})

	t.Run("no check", func(t *testing.T) {
		column := twoStrings()
		util := newValidateUtil()
		assert.NoError(t, util.checkVarCharFieldData(column, varCharSchema("2")))
	})
}
func Test_validateUtil_checkBinaryVectorFieldData(t *testing.T) {
assert.NoError(t, newValidateUtil().checkBinaryVectorFieldData(nil, nil))
}
// Test_validateUtil_checkFloatVectorFieldData exercises
// checkFloatVectorFieldData with a column lacking float data, with the NaN
// check disabled, with NaN data, and with ordinary data.
func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) {
	// floatColumn builds a float vector column carrying the given values.
	floatColumn := func(values ...float32) *schemapb.FieldData {
		return &schemapb.FieldData{
			Field: &schemapb.FieldData_Vectors{
				Vectors: &schemapb.VectorField{
					Data: &schemapb.VectorField_FloatVector{
						FloatVector: &schemapb.FloatArray{
							Data: values,
						},
					},
				},
			},
		}
	}

	t.Run("not float vector", func(t *testing.T) {
		column := &schemapb.FieldData{}
		util := newValidateUtil()
		assert.Error(t, util.checkFloatVectorFieldData(column, nil))
	})

	t.Run("no check", func(t *testing.T) {
		column := floatColumn(1.1, 2.2)
		util := newValidateUtil()
		util.checkNAN = false
		assert.NoError(t, util.checkFloatVectorFieldData(column, nil))
	})

	t.Run("has nan", func(t *testing.T) {
		column := floatColumn(float32(math.NaN()))
		util := newValidateUtil(withNANCheck())
		assert.Error(t, util.checkFloatVectorFieldData(column, nil))
	})

	t.Run("normal case", func(t *testing.T) {
		column := floatColumn(1.1, 2.2)
		util := newValidateUtil(withNANCheck())
		assert.NoError(t, util.checkFloatVectorFieldData(column, nil))
	})
}
func Test_validateUtil_checkAligned(t *testing.T) {
t.Run("float vector column not found", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldName: "test",
Type: schemapb.DataType_FloatVector,
},
}
schema := &schemapb.CollectionSchema{}
h, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err)
v := newValidateUtil()
err = v.checkAligned(data, h, 100)
assert.Error(t, err)
})
t.Run("float vector column dimension not found", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldName: "test",
Type: schemapb.DataType_FloatVector,
},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: "test",
DataType: schemapb.DataType_FloatVector,
},
},
}
h, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err)
v := newValidateUtil()
err = v.checkAligned(data, h, 100)
assert.Error(t, err)
})
t.Run("invalid num rows", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldName: "test",
Type: schemapb.DataType_FloatVector,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: []float32{1.1, 2.2},
},
},
},
},
},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: "test",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "8",
},
},
},
},
}
h, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err)
v := newValidateUtil()
err = v.checkAligned(data, h, 100)
assert.Error(t, err)
})
t.Run("num rows mismatch", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldName: "test",
Type: schemapb.DataType_FloatVector,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: []float32{1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8},
},
},
},
},
},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: "test",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "8",
},
},
},
},
}
h, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err)
v := newValidateUtil()
err = v.checkAligned(data, h, 100)
assert.Error(t, err)
})
//////////////////////////////////////////////////////////////////////
t.Run("binary vector column not found", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldName: "test",
Type: schemapb.DataType_BinaryVector,
},
}
schema := &schemapb.CollectionSchema{}
h, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err)
v := newValidateUtil()
err = v.checkAligned(data, h, 100)
assert.Error(t, err)
})
t.Run("binary vector column dimension not found", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldName: "test",
Type: schemapb.DataType_BinaryVector,
},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: "test",
DataType: schemapb.DataType_BinaryVector,
},
},
}
h, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err)
v := newValidateUtil()
err = v.checkAligned(data, h, 100)
assert.Error(t, err)
})
t.Run("invalid num rows", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldName: "test",
Type: schemapb.DataType_BinaryVector,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: []byte("not128"),
},
},
},
},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: "test",
DataType: schemapb.DataType_BinaryVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "128",
},
},
},
},
}
h, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err)
v := newValidateUtil()
err = v.checkAligned(data, h, 100)
assert.Error(t, err)
})
t.Run("num rows mismatch", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldName: "test",
Type: schemapb.DataType_BinaryVector,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: []byte{'1', '2'},
},
},
},
},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: "test",
DataType: schemapb.DataType_BinaryVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "8",
},
},
},
},
}
h, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err)
v := newValidateUtil()
err = v.checkAligned(data, h, 100)
assert.Error(t, err)
})
//////////////////////////////////////////////////////////////////
t.Run("mismatch", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldName: "test",
Type: schemapb.DataType_VarChar,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: []string{"111", "222"},
},
},
},
},
},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: "test",
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "max_length",
Value: "8",
},
},
},
},
}
h, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err)
v := newValidateUtil()
err = v.checkAligned(data, h, 100)
assert.Error(t, err)
})
/////////////////////////////////////////////////////////////////////
t.Run("normal case", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldName: "test1",
Type: schemapb.DataType_FloatVector,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: generateFloatVectors(10, 8),
},
},
},
},
},
{
FieldName: "test2",
Type: schemapb.DataType_BinaryVector,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: generateBinaryVectors(10, 8),
},
},
},
},
{
FieldName: "test3",
Type: schemapb.DataType_VarChar,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: generateVarCharArray(10, 8),
},
},
},
},
},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: "test1",
FieldID: 101,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "8",
},
},
},
{
Name: "test2",
FieldID: 102,
DataType: schemapb.DataType_BinaryVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "8",
},
},
},
{
Name: "test3",
FieldID: 103,
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "max_length",
Value: "8",
},
},
},
},
}
h, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err)
v := newValidateUtil()
err = v.checkAligned(data, h, 10)
assert.NoError(t, err)
})
}
func Test_validateUtil_Validate(t *testing.T) {
t.Run("nil schema", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldName: "test",
Type: schemapb.DataType_FloatVector,
},
}
v := newValidateUtil()
err := v.Validate(data, nil, 100)
assert.Error(t, err)
})
t.Run("not aligned", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldName: "test",
Type: schemapb.DataType_VarChar,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: []string{"111", "222"},
},
},
},
},
},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: "test",
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "max_length",
Value: "8",
},
},
},
},
}
v := newValidateUtil()
err := v.Validate(data, schema, 100)
assert.Error(t, err)
})
t.Run("has nan", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldName: "test1",
Type: schemapb.DataType_FloatVector,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: []float32{float32(math.NaN()), float32(math.NaN())},
},
},
},
},
},
{
FieldName: "test2",
Type: schemapb.DataType_BinaryVector,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: generateBinaryVectors(2, 8),
},
},
},
},
{
FieldName: "test3",
Type: schemapb.DataType_VarChar,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: generateVarCharArray(2, 8),
},
},
},
},
},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: "test1",
FieldID: 101,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
},
{
Name: "test2",
FieldID: 102,
DataType: schemapb.DataType_BinaryVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "8",
},
},
},
{
Name: "test3",
FieldID: 103,
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "max_length",
Value: "8",
},
},
},
},
}
v := newValidateUtil(withNANCheck(), withMaxLenCheck())
err := v.Validate(data, schema, 2)
assert.Error(t, err)
})
t.Run("length exceeds", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldName: "test1",
Type: schemapb.DataType_FloatVector,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: generateFloatVectors(2, 1),
},
},
},
},
},
{
FieldName: "test2",
Type: schemapb.DataType_BinaryVector,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: generateBinaryVectors(2, 8),
},
},
},
},
{
FieldName: "test3",
Type: schemapb.DataType_VarChar,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: []string{"very_long", "very_very_long"},
},
},
},
},
},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: "test1",
FieldID: 101,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
},
{
Name: "test2",
FieldID: 102,
DataType: schemapb.DataType_BinaryVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "8",
},
},
},
{
Name: "test3",
FieldID: 103,
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "max_length",
Value: "2",
},
},
},
},
}
v := newValidateUtil(withNANCheck(), withMaxLenCheck())
err := v.Validate(data, schema, 2)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldName: "test1",
Type: schemapb.DataType_FloatVector,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: generateFloatVectors(10, 8),
},
},
},
},
},
{
FieldName: "test2",
Type: schemapb.DataType_BinaryVector,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: generateBinaryVectors(10, 8),
},
},
},
},
{
FieldName: "test3",
Type: schemapb.DataType_VarChar,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: generateVarCharArray(10, 8),
},
},
},
},
},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: "test1",
FieldID: 101,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "8",
},
},
},
{
Name: "test2",
FieldID: 102,
DataType: schemapb.DataType_BinaryVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "8",
},
},
},
{
Name: "test3",
FieldID: 103,
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "max_length",
Value: "8",
},
},
},
},
}
v := newValidateUtil(withNANCheck(), withMaxLenCheck())
err := v.Validate(data, schema, 10)
assert.NoError(t, err)
})
}

View File

@ -275,7 +275,7 @@ func getNumRowsOfScalarField(datas interface{}) uint64 {
return uint64(realTypeDatas.Len())
}
func getNumRowsOfFloatVectorField(fDatas []float32, dim int64) (uint64, error) {
func GetNumRowsOfFloatVectorField(fDatas []float32, dim int64) (uint64, error) {
if dim <= 0 {
return 0, fmt.Errorf("dim(%d) should be greater than 0", dim)
}
@ -286,7 +286,7 @@ func getNumRowsOfFloatVectorField(fDatas []float32, dim int64) (uint64, error) {
return uint64(int64(l) / dim), nil
}
func getNumRowsOfBinaryVectorField(bDatas []byte, dim int64) (uint64, error) {
func GetNumRowsOfBinaryVectorField(bDatas []byte, dim int64) (uint64, error) {
if dim <= 0 {
return 0, fmt.Errorf("dim(%d) should be greater than 0", dim)
}
@ -328,13 +328,13 @@ func GetNumRowOfFieldData(fieldData *schemapb.FieldData) (uint64, error) {
switch vectorFieldType := vectorField.Data.(type) {
case *schemapb.VectorField_FloatVector:
dim := vectorField.GetDim()
fieldNumRows, err = getNumRowsOfFloatVectorField(vectorField.GetFloatVector().Data, dim)
fieldNumRows, err = GetNumRowsOfFloatVectorField(vectorField.GetFloatVector().Data, dim)
if err != nil {
return 0, err
}
case *schemapb.VectorField_BinaryVector:
dim := vectorField.GetDim()
fieldNumRows, err = getNumRowsOfBinaryVectorField(vectorField.GetBinaryVector(), dim)
fieldNumRows, err = GetNumRowsOfBinaryVectorField(vectorField.GetBinaryVector(), dim)
if err != nil {
return 0, err
}

View File

@ -359,11 +359,11 @@ func TestGetNumRowsOfFloatVectorField(t *testing.T) {
}
for _, test := range cases {
got, err := getNumRowsOfFloatVectorField(test.fDatas, test.dim)
got, err := GetNumRowsOfFloatVectorField(test.fDatas, test.dim)
if test.errIsNil {
assert.Equal(t, nil, err)
if got != test.want {
t.Errorf("getNumRowsOfFloatVectorField(%v, %v) = %v, %v", test.fDatas, test.dim, test.want, nil)
t.Errorf("GetNumRowsOfFloatVectorField(%v, %v) = %v, %v", test.fDatas, test.dim, test.want, nil)
}
} else {
assert.NotEqual(t, nil, err)
@ -392,11 +392,11 @@ func TestGetNumRowsOfBinaryVectorField(t *testing.T) {
}
for _, test := range cases {
got, err := getNumRowsOfBinaryVectorField(test.bDatas, test.dim)
got, err := GetNumRowsOfBinaryVectorField(test.bDatas, test.dim)
if test.errIsNil {
assert.Equal(t, nil, err)
if got != test.want {
t.Errorf("getNumRowsOfBinaryVectorField(%v, %v) = %v, %v", test.bDatas, test.dim, test.want, nil)
t.Errorf("GetNumRowsOfBinaryVectorField(%v, %v) = %v, %v", test.bDatas, test.dim, test.want, nil)
}
} else {
assert.NotEqual(t, nil, err)