From 0e29c37499a6a68502115a4964ad00dd8863a64d Mon Sep 17 00:00:00 2001 From: Jiquan Long Date: Fri, 17 Jun 2022 14:10:11 +0800 Subject: [PATCH] Forbid system fields in user schema (#17613) Signed-off-by: longjiquan --- internal/rootcoord/task.go | 14 ++++++++++++ internal/rootcoord/task_test.go | 38 +++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/internal/rootcoord/task.go b/internal/rootcoord/task.go index ded6389ed7..3e4b676835 100644 --- a/internal/rootcoord/task.go +++ b/internal/rootcoord/task.go @@ -87,6 +87,15 @@ func (t *CreateCollectionReqTask) Type() commonpb.MsgType { return t.Req.Base.MsgType } +func hasSystemFields(schema *schemapb.CollectionSchema, systemFields []string) bool { + for _, f := range schema.GetFields() { + if funcutil.SliceContain(systemFields, f.GetName()) { + return true + } + } + return false +} + // Execute task execution func (t *CreateCollectionReqTask) Execute(ctx context.Context) error { if t.Type() != commonpb.MsgType_CreateCollection { @@ -108,6 +117,11 @@ func (t *CreateCollectionReqTask) Execute(ctx context.Context) error { zap.Int32("ShardsNum", t.Req.ShardsNum), zap.String("ConsistencyLevel", t.Req.ConsistencyLevel.String())) + if hasSystemFields(&schema, []string{RowIDFieldName, TimeStampFieldName}) { + log.Error("failed to create collection, user schema contain system field") + return fmt.Errorf("schema contains system field: %s, %s", RowIDFieldName, TimeStampFieldName) + } + for idx, field := range schema.Fields { field.FieldID = int64(idx + StartOfUserFieldID) } diff --git a/internal/rootcoord/task_test.go b/internal/rootcoord/task_test.go index a685482cd3..39e6a633a2 100644 --- a/internal/rootcoord/task_test.go +++ b/internal/rootcoord/task_test.go @@ -5,6 +5,12 @@ import ( "errors" "testing" + "github.com/golang/protobuf/proto" + + "github.com/milvus-io/milvus/internal/proto/milvuspb" + + "github.com/milvus-io/milvus/internal/proto/schemapb" + "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus/internal/proto/etcdpb" @@ -113,3 +119,35 @@ func TestDescribeSegmentsReqTask_Execute(t *testing.T) { } assert.NoError(t, tsk.Execute(context.Background())) } + +func Test_hasSystemFields(t *testing.T) { + t.Run("no system fields", func(t *testing.T) { + schema := &schemapb.CollectionSchema{Fields: []*schemapb.FieldSchema{{Name: "not_system_field"}}} + assert.False(t, hasSystemFields(schema, []string{RowIDFieldName, TimeStampFieldName})) + }) + + t.Run("has row id field", func(t *testing.T) { + schema := &schemapb.CollectionSchema{Fields: []*schemapb.FieldSchema{{Name: RowIDFieldName}}} + assert.True(t, hasSystemFields(schema, []string{RowIDFieldName, TimeStampFieldName})) + }) + + t.Run("has timestamp field", func(t *testing.T) { + schema := &schemapb.CollectionSchema{Fields: []*schemapb.FieldSchema{{Name: TimeStampFieldName}}} + assert.True(t, hasSystemFields(schema, []string{RowIDFieldName, TimeStampFieldName})) + }) +} + +func TestCreateCollectionReqTask_Execute_hasSystemFields(t *testing.T) { + schema := &schemapb.CollectionSchema{Name: "test", Fields: []*schemapb.FieldSchema{{Name: TimeStampFieldName}}} + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + task := &CreateCollectionReqTask{ + Req: &milvuspb.CreateCollectionRequest{ + Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection}, + CollectionName: "test", + Schema: marshaledSchema, + }, + } + err = task.Execute(context.Background()) + assert.Error(t, err) +}