Add test cases for create collection (#7759)

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
pull/7766/head
dragondriver 2021-09-11 18:06:02 +08:00 committed by GitHub
parent ab7ded2202
commit 5a3fc37584
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 353 additions and 14 deletions

View File

@ -0,0 +1,269 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package proxy
import (
"context"
"strconv"
"testing"
"time"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/util/uniquegenerator"
"github.com/stretchr/testify/assert"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
)
func TestCreateCollectionTask(t *testing.T) {
Params.Init()
rc := NewRootCoordMock()
ctx := context.Background()
shardsNum := int32(2)
prefix := "TestCreateCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
int64Field := "int64"
floatVecField := "fvec"
dim := 128
schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
var marshaledSchema []byte
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
task := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
t.Run("on enqueue", func(t *testing.T) {
err := task.OnEnqueue()
assert.NoError(t, err)
assert.Equal(t, commonpb.MsgType_CreateCollection, task.Type())
})
t.Run("ctx", func(t *testing.T) {
traceCtx := task.TraceCtx()
assert.NotNil(t, traceCtx)
})
t.Run("id", func(t *testing.T) {
id := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
task.SetID(id)
assert.Equal(t, id, task.ID())
})
t.Run("name", func(t *testing.T) {
assert.Equal(t, CreateCollectionTaskName, task.Name())
})
t.Run("ts", func(t *testing.T) {
ts := Timestamp(time.Now().UnixNano())
task.SetTs(ts)
assert.Equal(t, ts, task.BeginTs())
assert.Equal(t, ts, task.EndTs())
})
t.Run("process task", func(t *testing.T) {
var err error
err = task.PreExecute(ctx)
assert.NoError(t, err)
err = task.Execute(ctx)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, task.result.ErrorCode)
// recreate -> fail
err = task.Execute(ctx)
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, task.result.ErrorCode)
err = task.PostExecute(ctx)
assert.NoError(t, err)
})
t.Run("PreExecute", func(t *testing.T) {
var err error
err = task.PreExecute(ctx)
assert.NoError(t, err)
task.Schema = []byte{0x1, 0x2, 0x3, 0x4}
err = task.PreExecute(ctx)
assert.Error(t, err)
task.Schema = marshaledSchema
task.ShardsNum = Params.MaxShardNum + 1
err = task.PreExecute(ctx)
assert.Error(t, err)
task.ShardsNum = shardsNum
reqBackup := proto.Clone(task.CreateCollectionRequest).(*milvuspb.CreateCollectionRequest)
schemaBackup := proto.Clone(schema).(*schemapb.CollectionSchema)
schemaWithTooManyFields := &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
Fields: make([]*schemapb.FieldSchema, Params.MaxFieldNum+1),
}
marshaledSchemaWithTooManyFields, err := proto.Marshal(schemaWithTooManyFields)
assert.NoError(t, err)
task.CreateCollectionRequest.Schema = marshaledSchemaWithTooManyFields
err = task.PreExecute(ctx)
assert.Error(t, err)
task.CreateCollectionRequest = reqBackup
// ValidateCollectionName
schema.Name = " " // empty
emptyNameSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
task.CreateCollectionRequest.Schema = emptyNameSchema
err = task.PreExecute(ctx)
assert.Error(t, err)
schema.Name = prefix
for i := 0; i < int(Params.MaxNameLength); i++ {
schema.Name += strconv.Itoa(i % 10)
}
tooLongNameSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
task.CreateCollectionRequest.Schema = tooLongNameSchema
err = task.PreExecute(ctx)
assert.Error(t, err)
schema.Name = "$" // invalid first char
invalidFirstCharSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
task.CreateCollectionRequest.Schema = invalidFirstCharSchema
err = task.PreExecute(ctx)
assert.Error(t, err)
// ValidateDuplicatedFieldName
schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
schema.Fields = append(schema.Fields, schema.Fields[0])
duplicatedFieldsSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
task.CreateCollectionRequest.Schema = duplicatedFieldsSchema
err = task.PreExecute(ctx)
assert.Error(t, err)
// ValidatePrimaryKey
schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
for idx := range schema.Fields {
schema.Fields[idx].IsPrimaryKey = false
}
noPrimaryFieldsSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
task.CreateCollectionRequest.Schema = noPrimaryFieldsSchema
err = task.PreExecute(ctx)
assert.Error(t, err)
// ValidateFieldName
schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
for idx := range schema.Fields {
schema.Fields[idx].Name = "$"
}
invalidFieldNameSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
task.CreateCollectionRequest.Schema = invalidFieldNameSchema
err = task.PreExecute(ctx)
assert.Error(t, err)
// ValidateVectorField
schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
for idx := range schema.Fields {
if schema.Fields[idx].DataType == schemapb.DataType_FloatVector ||
schema.Fields[idx].DataType == schemapb.DataType_BinaryVector {
schema.Fields[idx].TypeParams = nil
}
}
noDimSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
task.CreateCollectionRequest.Schema = noDimSchema
err = task.PreExecute(ctx)
assert.Error(t, err)
schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
for idx := range schema.Fields {
if schema.Fields[idx].DataType == schemapb.DataType_FloatVector ||
schema.Fields[idx].DataType == schemapb.DataType_BinaryVector {
schema.Fields[idx].TypeParams = []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "not int",
},
}
}
}
dimNotIntSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
task.CreateCollectionRequest.Schema = dimNotIntSchema
err = task.PreExecute(ctx)
assert.Error(t, err)
schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
for idx := range schema.Fields {
if schema.Fields[idx].DataType == schemapb.DataType_FloatVector ||
schema.Fields[idx].DataType == schemapb.DataType_BinaryVector {
schema.Fields[idx].TypeParams = []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(int(Params.MaxDimension) + 1),
},
}
}
}
tooLargeDimSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
task.CreateCollectionRequest.Schema = tooLargeDimSchema
err = task.PreExecute(ctx)
assert.Error(t, err)
schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
schema.Fields[1].DataType = schemapb.DataType_BinaryVector
schema.Fields[1].TypeParams = []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(int(Params.MaxDimension) + 1),
},
}
binaryTooLargeDimSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
task.CreateCollectionRequest.Schema = binaryTooLargeDimSchema
err = task.PreExecute(ctx)
assert.Error(t, err)
})
}

View File

@ -48,12 +48,12 @@ import (
)
const (
InsertTaskName = "insertTask"
InsertTaskName = "InsertTask"
CreateCollectionTaskName = "CreateCollectionTask"
DropCollectionTaskName = "DropCollectionTask"
SearchTaskName = "searchTask"
SearchTaskName = "SearchTask"
RetrieveTaskName = "RetrieveTask"
QueryTaskName = "queryTask"
QueryTaskName = "QueryTask"
AnnsFieldKey = "anns_field"
TopKKey = "topk"
MetricTypeKey = "metric_type"
@ -67,12 +67,12 @@ const (
DropPartitionTaskName = "DropPartitionTask"
HasPartitionTaskName = "HasPartitionTask"
ShowPartitionTaskName = "ShowPartitionTask"
CreateIndexTaskName = "createIndexTask"
DescribeIndexTaskName = "describeIndexTask"
DropIndexTaskName = "dropIndexTask"
GetIndexStateTaskName = "getIndexStateTask"
GetIndexBuildProgressTaskName = "getIndexBuildProgressTask"
FlushTaskName = "flushTask"
CreateIndexTaskName = "CreateIndexTask"
DescribeIndexTaskName = "DescribeIndexTask"
DropIndexTaskName = "DropIndexTask"
GetIndexStateTaskName = "GetIndexStateTask"
GetIndexBuildProgressTaskName = "GetIndexBuildProgressTask"
FlushTaskName = "FlushTask"
LoadCollectionTaskName = "LoadCollectionTask"
ReleaseCollectionTaskName = "ReleaseCollectionTask"
LoadPartitionTaskName = "LoadPartitionsTask"
@ -1112,6 +1112,8 @@ func (cct *createCollectionTask) SetTs(ts Timestamp) {
func (cct *createCollectionTask) OnEnqueue() error {
cct.Base = &commonpb.MsgBase{}
cct.Base.MsgType = commonpb.MsgType_CreateCollection
cct.Base.SourceID = Params.ProxyID
return nil
}
@ -1121,8 +1123,11 @@ func (cct *createCollectionTask) PreExecute(ctx context.Context) error {
cct.schema = &schemapb.CollectionSchema{}
err := proto.Unmarshal(cct.Schema, cct.schema)
if err != nil {
return err
}
cct.schema.AutoID = false
cct.CreateCollectionRequest.Schema, _ = proto.Marshal(cct.schema)
cct.CreateCollectionRequest.Schema, err = proto.Marshal(cct.schema)
if err != nil {
return err
}

View File

@ -1,8 +1,13 @@
package proxy
import (
"fmt"
"strconv"
"testing"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
@ -12,6 +17,70 @@ import (
// TODO(dragondriver): add more test cases
func constructCollectionSchema(
int64Field, floatVecField string,
dim int,
collectionName string,
) *schemapb.CollectionSchema {
pk := &schemapb.FieldSchema{
FieldID: 0,
Name: int64Field,
IsPrimaryKey: true,
Description: "",
DataType: schemapb.DataType_Int64,
TypeParams: nil,
IndexParams: nil,
AutoID: true,
}
fVec := &schemapb.FieldSchema{
FieldID: 0,
Name: floatVecField,
IsPrimaryKey: false,
Description: "",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
},
IndexParams: nil,
AutoID: false,
}
return &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
pk,
fVec,
},
}
}
func constructCreateCollectionRequest(
schema *schemapb.CollectionSchema,
dbName, collectionName string,
shardsNum int32,
) *milvuspb.CreateCollectionRequest {
bs, err := proto.Marshal(schema)
if err != nil {
panic(
fmt.Sprintf(
"failed to marshal collection schema, schema: %v, error: %v",
schema,
err))
}
return &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Schema: bs,
ShardsNum: shardsNum,
}
}
func TestGetNumRowsOfScalarField(t *testing.T) {
cases := []struct {
datas interface{}
@ -508,7 +577,3 @@ func TestTranslateOutputFields(t *testing.T) {
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
}
func TestCreateCollectionTask(t *testing.T) {
}