mirror of https://github.com/milvus-io/milvus.git
Fix 6419, check if num_rows is greater than zero in proxy (#6439)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>pull/6485/head
parent
eac2b8a8f9
commit
d896e3119e
|
@ -0,0 +1,48 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
// TODO(dragondriver): add more common error type
|
||||
|
||||
func errInvalidNumRows(numRows uint32) error {
|
||||
return fmt.Errorf("invalid num_rows: %d", numRows)
|
||||
}
|
||||
|
||||
func errNumRowsLessThanOrEqualToZero(numRows uint32) error {
|
||||
return fmt.Errorf("num_rows(%d) should be greater than 0", numRows)
|
||||
}
|
||||
|
||||
func errNumRowsOfFieldDataMismatchPassed(idx int, fieldNumRows, passedNumRows uint32) error {
|
||||
return fmt.Errorf("the num_rows(%d) of %dth field is not equal to passed NumRows(%d)", fieldNumRows, idx, passedNumRows)
|
||||
}
|
||||
|
||||
var errEmptyFieldData = errors.New("empty field data")
|
||||
|
||||
func errFieldsLessThanNeeded(fieldsNum, needed int) error {
|
||||
return fmt.Errorf("the length(%d) of passed fields is less than needed(%d)", fieldsNum, needed)
|
||||
}
|
||||
|
||||
func errUnsupportedDataType(dType schemapb.DataType) error {
|
||||
return fmt.Errorf("%v is not supported now", dType)
|
||||
}
|
||||
|
||||
func errUnsupportedDType(dType string) error {
|
||||
return fmt.Errorf("%s is not supported now", dType)
|
||||
}
|
||||
|
||||
func errInvalidDim(dim int) error {
|
||||
return fmt.Errorf("invalid dim: %d", dim)
|
||||
}
|
||||
|
||||
func errDimLessThanOrEqualToZero(dim int) error {
|
||||
return fmt.Errorf("dim(%d) should be greater than 0", dim)
|
||||
}
|
||||
|
||||
func errDimShouldDivide8(dim int) error {
|
||||
return fmt.Errorf("dim(%d) should divide 8", dim)
|
||||
}
|
|
@ -197,6 +197,133 @@ func (it *InsertTask) OnEnqueue() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func getNumRowsOfScalarField(datas interface{}) uint32 {
|
||||
realTypeDatas := reflect.ValueOf(datas)
|
||||
return uint32(realTypeDatas.Len())
|
||||
}
|
||||
|
||||
func getNumRowsOfFloatVectorField(fDatas []float32, dim int64) (uint32, error) {
|
||||
if dim <= 0 {
|
||||
return 0, errDimLessThanOrEqualToZero(int(dim))
|
||||
}
|
||||
l := len(fDatas)
|
||||
if int64(l)%dim != 0 {
|
||||
return 0, fmt.Errorf("the length(%d) of float data should divide the dim(%d)", l, dim)
|
||||
}
|
||||
return uint32(int(int64(l) / dim)), nil
|
||||
}
|
||||
|
||||
func getNumRowsOfBinaryVectorField(bDatas []byte, dim int64) (uint32, error) {
|
||||
if dim <= 0 {
|
||||
return 0, errDimLessThanOrEqualToZero(int(dim))
|
||||
}
|
||||
if dim%8 != 0 {
|
||||
return 0, errDimShouldDivide8(int(dim))
|
||||
}
|
||||
l := len(bDatas)
|
||||
if (8*int64(l))%dim != 0 {
|
||||
return 0, fmt.Errorf("the num(%d) of all bits should divide the dim(%d)", 8*l, dim)
|
||||
}
|
||||
return uint32(int((8 * int64(l)) / dim)), nil
|
||||
}
|
||||
|
||||
func (it *InsertTask) checkLengthOfFieldsData() error {
|
||||
neededFieldsNum := 0
|
||||
for _, field := range it.schema.Fields {
|
||||
if !field.AutoID {
|
||||
neededFieldsNum++
|
||||
}
|
||||
}
|
||||
|
||||
if len(it.req.FieldsData) < neededFieldsNum {
|
||||
return errFieldsLessThanNeeded(len(it.req.FieldsData), neededFieldsNum)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *InsertTask) checkRowNums() error {
|
||||
if it.req.NumRows <= 0 {
|
||||
return errNumRowsLessThanOrEqualToZero(it.req.NumRows)
|
||||
}
|
||||
|
||||
if err := it.checkLengthOfFieldsData(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rowNums := it.req.NumRows
|
||||
|
||||
for i, field := range it.req.FieldsData {
|
||||
switch field.Field.(type) {
|
||||
case *schemapb.FieldData_Scalars:
|
||||
scalarField := field.GetScalars()
|
||||
switch scalarField.Data.(type) {
|
||||
case *schemapb.ScalarField_BoolData:
|
||||
fieldNumRows := getNumRowsOfScalarField(scalarField.GetBoolData().Data)
|
||||
if fieldNumRows != rowNums {
|
||||
return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
|
||||
}
|
||||
case *schemapb.ScalarField_IntData:
|
||||
fieldNumRows := getNumRowsOfScalarField(scalarField.GetIntData().Data)
|
||||
if fieldNumRows != rowNums {
|
||||
return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
|
||||
}
|
||||
case *schemapb.ScalarField_LongData:
|
||||
fieldNumRows := getNumRowsOfScalarField(scalarField.GetLongData().Data)
|
||||
if fieldNumRows != rowNums {
|
||||
return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
|
||||
}
|
||||
case *schemapb.ScalarField_FloatData:
|
||||
fieldNumRows := getNumRowsOfScalarField(scalarField.GetFloatData().Data)
|
||||
if fieldNumRows != rowNums {
|
||||
return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
|
||||
}
|
||||
case *schemapb.ScalarField_DoubleData:
|
||||
fieldNumRows := getNumRowsOfScalarField(scalarField.GetDoubleData().Data)
|
||||
if fieldNumRows != rowNums {
|
||||
return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
|
||||
}
|
||||
case *schemapb.ScalarField_BytesData:
|
||||
return errUnsupportedDType("bytes")
|
||||
case *schemapb.ScalarField_StringData:
|
||||
return errUnsupportedDType("string")
|
||||
case nil:
|
||||
continue
|
||||
default:
|
||||
continue
|
||||
}
|
||||
case *schemapb.FieldData_Vectors:
|
||||
vectorField := field.GetVectors()
|
||||
switch vectorField.Data.(type) {
|
||||
case *schemapb.VectorField_FloatVector:
|
||||
dim := vectorField.GetDim()
|
||||
fieldNumRows, err := getNumRowsOfFloatVectorField(vectorField.GetFloatVector().Data, dim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fieldNumRows != rowNums {
|
||||
return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
|
||||
}
|
||||
case *schemapb.VectorField_BinaryVector:
|
||||
dim := vectorField.GetDim()
|
||||
fieldNumRows, err := getNumRowsOfBinaryVectorField(vectorField.GetBinaryVector(), dim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fieldNumRows != rowNums {
|
||||
return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
|
||||
}
|
||||
case nil:
|
||||
continue
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(dragondriver): ignore the order of fields in request, use the order of CollectionSchema to reorganize data
|
||||
func (it *InsertTask) transferColumnBasedRequestToRowBasedData() error {
|
||||
dTypes := make([]schemapb.DataType, 0, len(it.req.FieldsData))
|
||||
|
@ -441,11 +568,16 @@ func (it *InsertTask) transferColumnBasedRequestToRowBasedData() error {
|
|||
|
||||
func (it *InsertTask) checkFieldAutoID() error {
|
||||
// TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields
|
||||
rowNums := it.req.NumRows
|
||||
if len(it.req.FieldsData) == 0 || rowNums == 0 {
|
||||
return fmt.Errorf("do not contain any data")
|
||||
if it.req.NumRows <= 0 {
|
||||
return errNumRowsLessThanOrEqualToZero(it.req.NumRows)
|
||||
}
|
||||
|
||||
if err := it.checkLengthOfFieldsData(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rowNums := it.req.NumRows
|
||||
|
||||
primaryFieldName := ""
|
||||
autoIDFieldName := ""
|
||||
autoIDLoc := -1
|
||||
|
@ -611,6 +743,11 @@ func (it *InsertTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
it.schema = collSchema
|
||||
|
||||
err = it.checkRowNums()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = it.checkFieldAutoID()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -0,0 +1,388 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetNumRowsOfScalarField(t *testing.T) {
|
||||
cases := []struct {
|
||||
datas interface{}
|
||||
want uint32
|
||||
}{
|
||||
{[]bool{}, 0},
|
||||
{[]bool{true, false}, 2},
|
||||
{[]int32{}, 0},
|
||||
{[]int32{1, 2}, 2},
|
||||
{[]int64{}, 0},
|
||||
{[]int64{1, 2}, 2},
|
||||
{[]float32{}, 0},
|
||||
{[]float32{1.0, 2.0}, 2},
|
||||
{[]float64{}, 0},
|
||||
{[]float64{1.0, 2.0}, 2},
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
if got := getNumRowsOfScalarField(test.datas); got != test.want {
|
||||
t.Errorf("getNumRowsOfScalarField(%v) = %v", test.datas, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNumRowsOfFloatVectorField(t *testing.T) {
|
||||
cases := []struct {
|
||||
fDatas []float32
|
||||
dim int64
|
||||
want uint32
|
||||
errIsNil bool
|
||||
}{
|
||||
{[]float32{}, -1, 0, false}, // dim <= 0
|
||||
{[]float32{}, 0, 0, false}, // dim <= 0
|
||||
{[]float32{1.0}, 128, 0, false}, // length % dim != 0
|
||||
{[]float32{}, 128, 0, true},
|
||||
{[]float32{1.0, 2.0}, 2, 1, true},
|
||||
{[]float32{1.0, 2.0, 3.0, 4.0}, 2, 2, true},
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
got, err := getNumRowsOfFloatVectorField(test.fDatas, test.dim)
|
||||
if test.errIsNil {
|
||||
assert.Equal(t, nil, err)
|
||||
if got != test.want {
|
||||
t.Errorf("getNumRowsOfFloatVectorField(%v, %v) = %v, %v", test.fDatas, test.dim, test.want, nil)
|
||||
}
|
||||
} else {
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNumRowsOfBinaryVectorField(t *testing.T) {
|
||||
cases := []struct {
|
||||
bDatas []byte
|
||||
dim int64
|
||||
want uint32
|
||||
errIsNil bool
|
||||
}{
|
||||
{[]byte{}, -1, 0, false}, // dim <= 0
|
||||
{[]byte{}, 0, 0, false}, // dim <= 0
|
||||
{[]byte{1.0}, 128, 0, false}, // length % dim != 0
|
||||
{[]byte{}, 128, 0, true},
|
||||
{[]byte{1.0}, 1, 0, false}, // dim % 8 != 0
|
||||
{[]byte{1.0}, 4, 0, false}, // dim % 8 != 0
|
||||
{[]byte{1.0, 2.0}, 8, 2, true},
|
||||
{[]byte{1.0, 2.0}, 16, 1, true},
|
||||
{[]byte{1.0, 2.0, 3.0, 4.0}, 8, 4, true},
|
||||
{[]byte{1.0, 2.0, 3.0, 4.0}, 16, 2, true},
|
||||
{[]byte{1.0}, 128, 0, false}, // (8*l) % dim != 0
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
got, err := getNumRowsOfBinaryVectorField(test.bDatas, test.dim)
|
||||
if test.errIsNil {
|
||||
assert.Equal(t, nil, err)
|
||||
if got != test.want {
|
||||
t.Errorf("getNumRowsOfBinaryVectorField(%v, %v) = %v, %v", test.bDatas, test.dim, test.want, nil)
|
||||
}
|
||||
} else {
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertTask_checkLengthOfFieldsData(t *testing.T) {
|
||||
var err error
|
||||
|
||||
// schema is empty, though won't happened in system
|
||||
case1 := InsertTask{
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Name: "TestInsertTask_checkLengthOfFieldsData",
|
||||
Description: "TestInsertTask_checkLengthOfFieldsData",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{},
|
||||
},
|
||||
req: &milvuspb.InsertRequest{
|
||||
DbName: "TestInsertTask_checkLengthOfFieldsData",
|
||||
CollectionName: "TestInsertTask_checkLengthOfFieldsData",
|
||||
PartitionName: "TestInsertTask_checkLengthOfFieldsData",
|
||||
FieldsData: nil,
|
||||
},
|
||||
}
|
||||
err = case1.checkLengthOfFieldsData()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// schema has two fields, neither of them are autoID
|
||||
case2 := InsertTask{
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Name: "TestInsertTask_checkLengthOfFieldsData",
|
||||
Description: "TestInsertTask_checkLengthOfFieldsData",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
AutoID: false,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
AutoID: false,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// passed fields is empty
|
||||
case2.req = &milvuspb.InsertRequest{}
|
||||
err = case2.checkLengthOfFieldsData()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// the num of passed fields is less than needed
|
||||
case2.req = &milvuspb.InsertRequest{
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = case2.checkLengthOfFieldsData()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// satisfied
|
||||
case2.req = &milvuspb.InsertRequest{
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = case2.checkLengthOfFieldsData()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// schema has two field, one of them are autoID
|
||||
case3 := InsertTask{
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Name: "TestInsertTask_checkLengthOfFieldsData",
|
||||
Description: "TestInsertTask_checkLengthOfFieldsData",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
AutoID: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
AutoID: false,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// passed fields is empty
|
||||
case3.req = &milvuspb.InsertRequest{}
|
||||
err = case3.checkLengthOfFieldsData()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// satisfied
|
||||
case3.req = &milvuspb.InsertRequest{
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = case3.checkLengthOfFieldsData()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// schema has one field which is autoID
|
||||
case4 := InsertTask{
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Name: "TestInsertTask_checkLengthOfFieldsData",
|
||||
Description: "TestInsertTask_checkLengthOfFieldsData",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
AutoID: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// passed fields is empty
|
||||
// satisfied
|
||||
case4.req = &milvuspb.InsertRequest{}
|
||||
err = case4.checkLengthOfFieldsData()
|
||||
assert.Equal(t, nil, err)
|
||||
}
|
||||
|
||||
func TestInsertTask_checkRowNums(t *testing.T) {
|
||||
var err error
|
||||
|
||||
// passed NumRows is less than 0
|
||||
case1 := InsertTask{
|
||||
req: &milvuspb.InsertRequest{
|
||||
NumRows: 0,
|
||||
},
|
||||
}
|
||||
err = case1.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
// checkLengthOfFieldsData was already checked by TestInsertTask_checkLengthOfFieldsData
|
||||
|
||||
numRows := 20
|
||||
dim := 128
|
||||
case2 := InsertTask{
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Name: "TestInsertTask_checkRowNums",
|
||||
Description: "TestInsertTask_checkRowNums",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{DataType: schemapb.DataType_Bool},
|
||||
{DataType: schemapb.DataType_Int8},
|
||||
{DataType: schemapb.DataType_Int16},
|
||||
{DataType: schemapb.DataType_Int32},
|
||||
{DataType: schemapb.DataType_Int64},
|
||||
{DataType: schemapb.DataType_Float},
|
||||
{DataType: schemapb.DataType_Double},
|
||||
{DataType: schemapb.DataType_FloatVector},
|
||||
{DataType: schemapb.DataType_BinaryVector},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// satisfied
|
||||
case2.req = &milvuspb.InsertRequest{
|
||||
NumRows: uint32(numRows),
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
newScalarFieldData(schemapb.DataType_Bool, "Bool", numRows),
|
||||
newScalarFieldData(schemapb.DataType_Int8, "Int8", numRows),
|
||||
newScalarFieldData(schemapb.DataType_Int16, "Int16", numRows),
|
||||
newScalarFieldData(schemapb.DataType_Int32, "Int32", numRows),
|
||||
newScalarFieldData(schemapb.DataType_Int64, "Int64", numRows),
|
||||
newScalarFieldData(schemapb.DataType_Float, "Float", numRows),
|
||||
newScalarFieldData(schemapb.DataType_Double, "Double", numRows),
|
||||
newFloatVectorFieldData("FloatVector", numRows, dim),
|
||||
newBinaryVectorFieldData("BinaryVector", numRows, dim),
|
||||
},
|
||||
}
|
||||
err = case2.checkRowNums()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// less bool data
|
||||
case2.req.FieldsData[0] = newScalarFieldData(schemapb.DataType_Bool, "Bool", numRows/2)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// more bool data
|
||||
case2.req.FieldsData[0] = newScalarFieldData(schemapb.DataType_Bool, "Bool", numRows*2)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// revert
|
||||
case2.req.FieldsData[0] = newScalarFieldData(schemapb.DataType_Bool, "Bool", numRows)
|
||||
err = case2.checkRowNums()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// less int8 data
|
||||
case2.req.FieldsData[1] = newScalarFieldData(schemapb.DataType_Int8, "Int8", numRows/2)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// more int8 data
|
||||
case2.req.FieldsData[1] = newScalarFieldData(schemapb.DataType_Int8, "Int8", numRows*2)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// revert
|
||||
case2.req.FieldsData[1] = newScalarFieldData(schemapb.DataType_Int8, "Int8", numRows)
|
||||
err = case2.checkRowNums()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// less int16 data
|
||||
case2.req.FieldsData[2] = newScalarFieldData(schemapb.DataType_Int16, "Int16", numRows/2)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// more int16 data
|
||||
case2.req.FieldsData[2] = newScalarFieldData(schemapb.DataType_Int16, "Int16", numRows*2)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// revert
|
||||
case2.req.FieldsData[2] = newScalarFieldData(schemapb.DataType_Int16, "Int16", numRows)
|
||||
err = case2.checkRowNums()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// less int32 data
|
||||
case2.req.FieldsData[3] = newScalarFieldData(schemapb.DataType_Int32, "Int32", numRows/2)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// more int32 data
|
||||
case2.req.FieldsData[3] = newScalarFieldData(schemapb.DataType_Int32, "Int32", numRows*2)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// revert
|
||||
case2.req.FieldsData[3] = newScalarFieldData(schemapb.DataType_Int32, "Int32", numRows)
|
||||
err = case2.checkRowNums()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// less int64 data
|
||||
case2.req.FieldsData[4] = newScalarFieldData(schemapb.DataType_Int64, "Int64", numRows/2)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// more int64 data
|
||||
case2.req.FieldsData[4] = newScalarFieldData(schemapb.DataType_Int64, "Int64", numRows*2)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// revert
|
||||
case2.req.FieldsData[4] = newScalarFieldData(schemapb.DataType_Int64, "Int64", numRows)
|
||||
err = case2.checkRowNums()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// less float data
|
||||
case2.req.FieldsData[5] = newScalarFieldData(schemapb.DataType_Float, "Float", numRows/2)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// more float data
|
||||
case2.req.FieldsData[5] = newScalarFieldData(schemapb.DataType_Float, "Float", numRows*2)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// revert
|
||||
case2.req.FieldsData[5] = newScalarFieldData(schemapb.DataType_Float, "Float", numRows)
|
||||
err = case2.checkRowNums()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// less double data
|
||||
case2.req.FieldsData[6] = newScalarFieldData(schemapb.DataType_Double, "Double", numRows/2)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// more double data
|
||||
case2.req.FieldsData[6] = newScalarFieldData(schemapb.DataType_Double, "Double", numRows*2)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// revert
|
||||
case2.req.FieldsData[6] = newScalarFieldData(schemapb.DataType_Double, "Double", numRows)
|
||||
err = case2.checkRowNums()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// less float vectors
|
||||
case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows/2, dim)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// more float vectors
|
||||
case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows*2, dim)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// revert
|
||||
case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows, dim)
|
||||
err = case2.checkRowNums()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// less binary vectors
|
||||
case2.req.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows/2, dim)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// more binary vectors
|
||||
case2.req.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows*2, dim)
|
||||
err = case2.checkRowNums()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// revert
|
||||
case2.req.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows, dim)
|
||||
err = case2.checkRowNums()
|
||||
assert.Equal(t, nil, err)
|
||||
}
|
|
@ -0,0 +1,197 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
func generateBoolArray(numRows int) []bool {
|
||||
ret := make([]bool, 0, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
ret = append(ret, rand.Int()%2 == 0)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func generateInt8Array(numRows int) []int8 {
|
||||
ret := make([]int8, 0, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
ret = append(ret, int8(rand.Int()))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func generateInt16Array(numRows int) []int16 {
|
||||
ret := make([]int16, 0, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
ret = append(ret, int16(rand.Int()))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func generateInt32Array(numRows int) []int32 {
|
||||
ret := make([]int32, 0, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
ret = append(ret, int32(rand.Int()))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func generateInt64Array(numRows int) []int64 {
|
||||
ret := make([]int64, 0, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
ret = append(ret, int64(rand.Int()))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func generateFloat32Array(numRows int) []float32 {
|
||||
ret := make([]float32, 0, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
ret = append(ret, rand.Float32())
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func generateFloat64Array(numRows int) []float64 {
|
||||
ret := make([]float64, 0, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
ret = append(ret, rand.Float64())
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func generateFloatVectors(numRows, dim int) []float32 {
|
||||
total := numRows * dim
|
||||
ret := make([]float32, 0, total)
|
||||
for i := 0; i < total; i++ {
|
||||
ret = append(ret, rand.Float32())
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func generateBinaryVectors(numRows, dim int) []byte {
|
||||
total := (numRows * dim) / 8
|
||||
ret := make([]byte, total)
|
||||
_, err := rand.Read(ret)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func newScalarFieldData(dType schemapb.DataType, fieldName string, numRows int) *schemapb.FieldData {
|
||||
ret := &schemapb.FieldData{
|
||||
Type: dType,
|
||||
FieldName: fieldName,
|
||||
Field: nil,
|
||||
}
|
||||
|
||||
switch dType {
|
||||
case schemapb.DataType_Bool:
|
||||
ret.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: generateBoolArray(numRows),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Int8:
|
||||
ret.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: generateInt32Array(numRows),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Int16:
|
||||
ret.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: generateInt32Array(numRows),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Int32:
|
||||
ret.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: generateInt32Array(numRows),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Int64:
|
||||
ret.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: generateInt64Array(numRows),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Float:
|
||||
ret.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_FloatData{
|
||||
FloatData: &schemapb.FloatArray{
|
||||
Data: generateFloat32Array(numRows),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Double:
|
||||
ret.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_DoubleData{
|
||||
DoubleData: &schemapb.DoubleArray{
|
||||
Data: generateFloat64Array(numRows),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func newFloatVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: fieldName,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: generateFloatVectors(numRows, dim),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newBinaryVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_BinaryVector,
|
||||
FieldName: fieldName,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: generateBinaryVectors(numRows, dim),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue