diff --git a/cmd/master/main.go b/cmd/master/main.go index d05fd102e5..ff3ca927c7 100644 --- a/cmd/master/main.go +++ b/cmd/master/main.go @@ -6,6 +6,7 @@ import ( "log" "os" "os/signal" + "strconv" "syscall" "github.com/zilliztech/milvus-distributed/internal/conf" @@ -24,7 +25,11 @@ func main() { // Creates server. ctx, cancel := context.WithCancel(context.Background()) - svr, err := master.CreateServer(ctx) + etcdAddr := conf.Config.Etcd.Address + etcdAddr += ":" + etcdAddr += strconv.FormatInt(int64(conf.Config.Etcd.Port), 10) + + svr, err := master.CreateServer(ctx, conf.Config.Etcd.Rootpath, conf.Config.Etcd.Rootpath, conf.Config.Etcd.Rootpath, []string{etcdAddr}) if err != nil { log.Print("create server failed", zap.Error(err)) } @@ -42,7 +47,9 @@ func main() { cancel() }() - if err := svr.Run(); err != nil { + grpcPort := int64(conf.Config.Master.Port) + + if err := svr.Run(grpcPort); err != nil { log.Fatal("run server failed", zap.Error(err)) } diff --git a/internal/master/README.md b/internal/master/README.md index 90493cef77..343cef711d 100644 --- a/internal/master/README.md +++ b/internal/master/README.md @@ -22,3 +22,4 @@ go run cmd/master.go ### example if master create a collection with uuid ```46e468ee-b34a-419d-85ed-80c56bfa4e90``` the corresponding key in etcd is $(ETCD_ROOT_PATH)/collection/46e468ee-b34a-419d-85ed-80c56bfa4e90 + diff --git a/internal/master/collection_task.go b/internal/master/collection_task.go index e6207e41a2..afa9f260fe 100644 --- a/internal/master/collection_task.go +++ b/internal/master/collection_task.go @@ -1,11 +1,11 @@ package master import ( - "encoding/json" "errors" "log" "strconv" + "github.com/golang/protobuf/proto" "github.com/zilliztech/milvus-distributed/internal/util/typeutil" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" @@ -60,53 +60,41 @@ func (t *createCollectionTask) Ts() (Timestamp, error) { if t.req == nil { return 0, errors.New("null request") } - return Timestamp(t.req.Timestamp), nil + return t.req.Timestamp, nil } func (t *createCollectionTask) Execute() error { if t.req == nil { - _ = t.Notify() return errors.New("null request") } var schema schemapb.CollectionSchema - err := json.Unmarshal(t.req.Schema.Value, &schema) + err := proto.UnmarshalMerge(t.req.Schema.Value, &schema) if err != nil { - _ = t.Notify() - return errors.New("unmarshal CollectionSchema failed") + return err } - // TODO: allocate collection id - var collectionId UniqueID = 0 - // TODO: allocate timestamp - var collectionCreateTime Timestamp = 0 + collectionId, err := allocGlobalId() + if err != nil { + return err + } + + ts, err := t.Ts() + if err != nil { + return err + } collection := etcdpb.CollectionMeta{ Id: collectionId, Schema: &schema, - CreateTime: collectionCreateTime, + CreateTime: ts, // TODO: initial segment? SegmentIds: make([]UniqueID, 0), // TODO: initial partition? PartitionTags: make([]string, 0), } - collectionJson, err := json.Marshal(&collection) - if err != nil { - _ = t.Notify() - return errors.New("marshal collection failed") - } - - err = (*t.kvBase).Save(collectionMetaPrefix+strconv.FormatInt(collectionId, 10), string(collectionJson)) - if err != nil { - _ = t.Notify() - return errors.New("save collection failed") - } - - t.mt.collId2Meta[collectionId] = collection - - _ = t.Notify() - return nil + return t.mt.AddCollection(&collection) } ////////////////////////////////////////////////////////////////////////// @@ -127,14 +115,12 @@ func (t *dropCollectionTask) Ts() (Timestamp, error) { func (t *dropCollectionTask) Execute() error { if t.req == nil { - _ = t.Notify() return errors.New("null request") } collectionName := t.req.CollectionName.CollectionName collectionMeta, err := t.mt.GetCollectionByName(collectionName) if err != nil { - _ = t.Notify() return err } @@ -142,13 +128,11 @@ func (t *dropCollectionTask) Execute() error { err = (*t.kvBase).Remove(collectionMetaPrefix + strconv.FormatInt(collectionId, 10)) if err != nil { - _ = t.Notify() - return errors.New("save collection failed") + return err } delete(t.mt.collId2Meta, collectionId) - _ = t.Notify() return nil } @@ -170,7 +154,6 @@ func (t *hasCollectionTask) Ts() (Timestamp, error) { func (t *hasCollectionTask) Execute() error { if t.req == nil { - _ = t.Notify() return errors.New("null request") } @@ -180,7 +163,6 @@ func (t *hasCollectionTask) Execute() error { t.hasCollection = true } - _ = t.Notify() return nil } @@ -202,14 +184,12 @@ func (t *describeCollectionTask) Ts() (Timestamp, error) { func (t *describeCollectionTask) Execute() error { if t.req == nil { - _ = t.Notify() return errors.New("null request") } collectionName := t.req.CollectionName collection, err := t.mt.GetCollectionByName(collectionName.CollectionName) if err != nil { - _ = t.Notify() return err } @@ -222,7 +202,6 @@ func (t *describeCollectionTask) Execute() error { t.description = &description - _ = t.Notify() return nil } @@ -244,7 +223,6 @@ func (t *showCollectionsTask) Ts() (Timestamp, error) { func (t *showCollectionsTask) Execute() error { if t.req == nil { - _ = t.Notify() return errors.New("null request") } @@ -262,6 +240,5 @@ func (t *showCollectionsTask) Execute() error { t.stringListResponse = &stringListResponse - _ = t.Notify() return nil } diff --git a/internal/master/grpc_service.go b/internal/master/grpc_service.go index 01dabfa128..78e9a6bc81 100644 --- a/internal/master/grpc_service.go +++ b/internal/master/grpc_service.go @@ -19,14 +19,13 @@ func (s *Master) CreateCollection(ctx context.Context, in *internalpb.CreateColl req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: &s.mt, - cv: make(chan int), + mt: s.mt, + cv: make(chan error), }, } - var err = s.scheduler.Enqueue(&t) + var err = s.scheduler.Enqueue(t) if err != nil { - err := errors.New("Enqueue failed") return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, Reason: "Enqueue failed", @@ -35,10 +34,9 @@ func (s *Master) CreateCollection(ctx context.Context, in *internalpb.CreateColl err = t.WaitToFinish(ctx) if err != nil { - err := errors.New("WaitToFinish failed") return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, - Reason: "WaitToFinish failed", + Reason: "create collection failed", }, err } @@ -52,12 +50,12 @@ func (s *Master) DropCollection(ctx context.Context, in *internalpb.DropCollecti req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: &s.mt, - cv: make(chan int), + mt: s.mt, + cv: make(chan error), }, } - var err = s.scheduler.Enqueue(&t) + var err = s.scheduler.Enqueue(t) if err != nil { err := errors.New("Enqueue failed") return &commonpb.Status{ @@ -85,13 +83,13 @@ func (s *Master) HasCollection(ctx context.Context, in *internalpb.HasCollection req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: &s.mt, - cv: make(chan int), + mt: s.mt, + cv: make(chan error), }, hasCollection: false, } - var err = s.scheduler.Enqueue(&t) + var err = s.scheduler.Enqueue(t) if err != nil { err := errors.New("Enqueue failed") return &servicepb.BoolResponse{ @@ -128,13 +126,13 @@ func (s *Master) DescribeCollection(ctx context.Context, in *internalpb.Describe req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: &s.mt, - cv: make(chan int), + mt: s.mt, + cv: make(chan error), }, description: nil, } - var err = s.scheduler.Enqueue(&t) + var err = s.scheduler.Enqueue(t) if err != nil { err := errors.New("Enqueue failed") return t.(*describeCollectionTask).description, err @@ -154,13 +152,13 @@ func (s *Master) ShowCollections(ctx context.Context, in *internalpb.ShowCollect req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: &s.mt, - cv: make(chan int), + mt: s.mt, + cv: make(chan error), }, stringListResponse: nil, } - var err = s.scheduler.Enqueue(&t) + var err = s.scheduler.Enqueue(t) if err != nil { err := errors.New("Enqueue failed") return t.(*showCollectionsTask).stringListResponse, err @@ -181,12 +179,12 @@ func (s *Master) CreatePartition(ctx context.Context, in *internalpb.CreateParti req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: &s.mt, - cv: make(chan int), + mt: s.mt, + cv: make(chan error), }, } - var err = s.scheduler.Enqueue(&t) + var err = s.scheduler.Enqueue(t) if err != nil { err := errors.New("Enqueue failed") return &commonpb.Status{ @@ -214,12 +212,12 @@ func (s *Master) DropPartition(ctx context.Context, in *internalpb.DropPartition req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: &s.mt, - cv: make(chan int), + mt: s.mt, + cv: make(chan error), }, } - var err = s.scheduler.Enqueue(&t) + var err = s.scheduler.Enqueue(t) if err != nil { err := errors.New("Enqueue failed") return &commonpb.Status{ @@ -247,13 +245,13 @@ func (s *Master) HasPartition(ctx context.Context, in *internalpb.HasPartitionRe req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: &s.mt, - cv: make(chan int), + mt: s.mt, + cv: make(chan error), }, hasPartition: false, } - var err = s.scheduler.Enqueue(&t) + var err = s.scheduler.Enqueue(t) if err != nil { err := errors.New("Enqueue failed") return &servicepb.BoolResponse{ @@ -290,13 +288,13 @@ func (s *Master) DescribePartition(ctx context.Context, in *internalpb.DescribeP req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: &s.mt, - cv: make(chan int), + mt: s.mt, + cv: make(chan error), }, description: nil, } - var err = s.scheduler.Enqueue(&t) + var err = s.scheduler.Enqueue(t) if err != nil { err := errors.New("Enqueue failed") return t.(*describePartitionTask).description, err @@ -316,13 +314,13 @@ func (s *Master) ShowPartitions(ctx context.Context, in *internalpb.ShowPartitio req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: &s.mt, - cv: make(chan int), + mt: s.mt, + cv: make(chan error), }, stringListResponse: nil, } - var err = s.scheduler.Enqueue(&t) + var err = s.scheduler.Enqueue(t) if err != nil { err := errors.New("Enqueue failed") return t.(*showPartitionTask).stringListResponse, err diff --git a/internal/master/grpc_service_test.go b/internal/master/grpc_service_test.go new file mode 100644 index 0000000000..7736b53f75 --- /dev/null +++ b/internal/master/grpc_service_test.go @@ -0,0 +1,140 @@ +package master + +import ( + "context" + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" + "github.com/zilliztech/milvus-distributed/internal/proto/masterpb" + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" + "go.etcd.io/etcd/clientv3" + "google.golang.org/grpc" + "testing" +) + +func TestMaster_CreateCollection(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + etcdCli, err := clientv3.New(clientv3.Config{Endpoints: []string{"127.0.0.1:2379"}}) + assert.Nil(t, err) + _, err = etcdCli.Delete(ctx, "/test/root", clientv3.WithPrefix()) + assert.Nil(t, err) + + svr, err := CreateServer(ctx, "/test/root/kv", "/test/root/meta", "/test/root/meta/tso", []string{"127.0.0.1:2379"}) + assert.Nil(t, err) + err = svr.Run(10001) + assert.Nil(t, err) + + conn, err := grpc.DialContext(ctx, "127.0.0.1:10001", grpc.WithInsecure(), grpc.WithBlock()) + assert.Nil(t, err) + defer conn.Close() + + cli := masterpb.NewMasterClient(conn) + sch := schemapb.CollectionSchema{ + Name: "col1", + Description: "test collection", + AutoId: false, + Fields: []*schemapb.FieldSchema{ + { + Name: "col1_f1", + Description: "test collection filed 1", + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "col1_f1_tk1", + Value: "col1_f1_tv1", + }, + { + Key: "col1_f1_tk2", + Value: "col1_f1_tv2", + }, + }, + IndexParams: []*commonpb.KeyValuePair{ + { + Key: "col1_f1_ik1", + Value: "col1_f1_iv1", + }, + { + Key: "col1_f1_ik2", + Value: "col1_f1_iv2", + }, + }, + }, + { + Name: "col1_f2", + Description: "test collection filed 2", + DataType: schemapb.DataType_VECTOR_BINARY, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "col1_f2_tk1", + Value: "col1_f2_tv1", + }, + { + Key: "col1_f2_tk2", + Value: "col1_f2_tv2", + }, + }, + IndexParams: []*commonpb.KeyValuePair{ + { + Key: "col1_f2_ik1", + Value: "col1_f2_iv1", + }, + { + Key: "col1_f2_ik2", + Value: "col1_f2_iv2", + }, + }, + }, + }, + } + schema_bytes, err := proto.Marshal(&sch) + assert.Nil(t, err) + + req := internalpb.CreateCollectionRequest{ + MsgType: internalpb.MsgType_kCreateCollection, + ReqId: 1, + Timestamp: 11, + ProxyId: 1, + Schema: &commonpb.Blob{Value: schema_bytes}, + } + st, err := cli.CreateCollection(ctx, &req) + assert.Nil(t, err) + assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS) + + coll_meta, err := svr.mt.GetCollectionByName(sch.Name) + assert.Nil(t, err) + t.Logf("collection id = %d", coll_meta.Id) + assert.Equal(t, coll_meta.CreateTime, uint64(11)) + assert.Equal(t, coll_meta.Schema.Name, "col1") + assert.Equal(t, coll_meta.Schema.AutoId, false) + assert.Equal(t, len(coll_meta.Schema.Fields), 2) + assert.Equal(t, coll_meta.Schema.Fields[0].Name, "col1_f1") + assert.Equal(t, coll_meta.Schema.Fields[1].Name, "col1_f2") + assert.Equal(t, coll_meta.Schema.Fields[0].DataType, schemapb.DataType_VECTOR_FLOAT) + assert.Equal(t, coll_meta.Schema.Fields[1].DataType, schemapb.DataType_VECTOR_BINARY) + assert.Equal(t, len(coll_meta.Schema.Fields[0].TypeParams), 2) + assert.Equal(t, len(coll_meta.Schema.Fields[0].IndexParams), 2) + assert.Equal(t, len(coll_meta.Schema.Fields[1].TypeParams), 2) + assert.Equal(t, len(coll_meta.Schema.Fields[1].IndexParams), 2) + assert.Equal(t, coll_meta.Schema.Fields[0].TypeParams[0].Key, "col1_f1_tk1") + assert.Equal(t, coll_meta.Schema.Fields[0].TypeParams[1].Key, "col1_f1_tk2") + assert.Equal(t, coll_meta.Schema.Fields[0].TypeParams[0].Value, "col1_f1_tv1") + assert.Equal(t, coll_meta.Schema.Fields[0].TypeParams[1].Value, "col1_f1_tv2") + assert.Equal(t, coll_meta.Schema.Fields[0].IndexParams[0].Key, "col1_f1_ik1") + assert.Equal(t, coll_meta.Schema.Fields[0].IndexParams[1].Key, "col1_f1_ik2") + assert.Equal(t, coll_meta.Schema.Fields[0].IndexParams[0].Value, "col1_f1_iv1") + assert.Equal(t, coll_meta.Schema.Fields[0].IndexParams[1].Value, "col1_f1_iv2") + + assert.Equal(t, coll_meta.Schema.Fields[1].TypeParams[0].Key, "col1_f2_tk1") + assert.Equal(t, coll_meta.Schema.Fields[1].TypeParams[1].Key, "col1_f2_tk2") + assert.Equal(t, coll_meta.Schema.Fields[1].TypeParams[0].Value, "col1_f2_tv1") + assert.Equal(t, coll_meta.Schema.Fields[1].TypeParams[1].Value, "col1_f2_tv2") + assert.Equal(t, coll_meta.Schema.Fields[1].IndexParams[0].Key, "col1_f2_ik1") + assert.Equal(t, coll_meta.Schema.Fields[1].IndexParams[1].Key, "col1_f2_ik2") + assert.Equal(t, coll_meta.Schema.Fields[1].IndexParams[0].Value, "col1_f2_iv1") + assert.Equal(t, coll_meta.Schema.Fields[1].IndexParams[1].Value, "col1_f2_iv2") + + svr.Close() +} diff --git a/internal/master/master.go b/internal/master/master.go index 1ad0320ffb..f21e5ea471 100644 --- a/internal/master/master.go +++ b/internal/master/master.go @@ -3,6 +3,7 @@ package master import ( "context" "fmt" + "github.com/zilliztech/milvus-distributed/internal/errors" "log" "math/rand" "net" @@ -51,21 +52,18 @@ type Master struct { kvBase *kv.EtcdKV scheduler *ddRequestScheduler - mt metaTable + mt *metaTable // Add callback functions at different stages startCallbacks []func() closeCallbacks []func() } -func newKVBase() *kv.EtcdKV { - etcdAddr := conf.Config.Etcd.Address - etcdAddr += ":" - etcdAddr += strconv.FormatInt(int64(conf.Config.Etcd.Port), 10) +func newKVBase(kv_root string, etcdAddr []string) *kv.EtcdKV { cli, _ := clientv3.New(clientv3.Config{ - Endpoints: []string{etcdAddr}, + Endpoints: etcdAddr, DialTimeout: 5 * time.Second, }) - kvBase := kv.NewEtcdKV(cli, conf.Config.Etcd.Rootpath) + kvBase := kv.NewEtcdKV(cli, kv_root) return kvBase } @@ -76,12 +74,26 @@ func Init() { } // CreateServer creates the UNINITIALIZED pd server with given configuration. -func CreateServer(ctx context.Context) (*Master, error) { +func CreateServer(ctx context.Context, kv_root_path string, meta_root_path, tso_root_path string, etcdAddr []string) (*Master, error) { + rand.Seed(time.Now().UnixNano()) Init() + + etcdClient, err := clientv3.New(clientv3.Config{Endpoints: etcdAddr}) + if err != nil { + return nil, err + } + etcdkv := kv.NewEtcdKV(etcdClient, meta_root_path) + metakv, err := NewMetaTable(etcdkv) + if err != nil { + return nil, err + } + m := &Master{ ctx: ctx, startTimestamp: time.Now().Unix(), - kvBase: newKVBase(), + kvBase: newKVBase(kv_root_path, etcdAddr), + scheduler: NewDDRequestScheduler(), + mt: metakv, ssChan: make(chan internalpb.SegmentStatistics, 10), pc: informer.NewPulsarClient(), } @@ -141,13 +153,13 @@ func (s *Master) IsClosed() bool { } // Run runs the pd server. -func (s *Master) Run() error { +func (s *Master) Run(grpcPort int64) error { if err := s.startServer(s.ctx); err != nil { return err } - s.startServerLoop(s.ctx) + s.startServerLoop(s.ctx, grpcPort) return nil } @@ -162,18 +174,28 @@ func (s *Master) LoopContext() context.Context { return s.serverLoopCtx } -func (s *Master) startServerLoop(ctx context.Context) { +func (s *Master) startServerLoop(ctx context.Context, grpcPort int64) { s.serverLoopCtx, s.serverLoopCancel = context.WithCancel(ctx) - s.serverLoopWg.Add(3) //go s.Se - go s.grpcLoop() - go s.pulsarLoop() + + s.serverLoopWg.Add(1) + go s.grpcLoop(grpcPort) + + //s.serverLoopWg.Add(1) + //go s.pulsarLoop() + + s.serverLoopWg.Add(1) + go s.tasksExecutionLoop() + + s.serverLoopWg.Add(1) go s.segmentStatisticsLoop() + } func (s *Master) stopServerLoop() { if s.grpcServer != nil { s.grpcServer.GracefulStop() + log.Printf("server is cloded, exit grpc server") } s.serverLoopCancel() s.serverLoopWg.Wait() @@ -184,11 +206,11 @@ func (s *Master) StartTimestamp() int64 { return s.startTimestamp } -func (s *Master) grpcLoop() { +func (s *Master) grpcLoop(grpcPort int64) { defer s.serverLoopWg.Done() defaultGRPCPort := ":" - defaultGRPCPort += strconv.FormatInt(int64(conf.Config.Master.Port), 10) + defaultGRPCPort += strconv.FormatInt(grpcPort, 10) lis, err := net.Listen("tcp", defaultGRPCPort) if err != nil { log.Printf("failed to listen: %v", err) @@ -235,7 +257,7 @@ func (s *Master) pulsarLoop() { s.ssChan <- m consumer.Ack(msg) case <-ctx.Done(): - log.Print("server is closed, exit etcd leader loop") + log.Print("server is closed, exit pulsar loop") return } } @@ -248,18 +270,16 @@ func (s *Master) tasksExecutionLoop() { for { select { case task := <-s.scheduler.reqQueue: - timeStamp, err := (*task).Ts() + timeStamp, err := (task).Ts() if err != nil { log.Println(err) } else { if timeStamp < s.scheduler.scheduleTimeStamp { - _ = (*task).NotifyTimeout() + task.Notify(errors.Errorf("input timestamp = %d, schduler timestamp = %d", timeStamp, s.scheduler.scheduleTimeStamp)) } else { s.scheduler.scheduleTimeStamp = timeStamp - err := (*task).Execute() - if err != nil { - log.Println("request execution failed caused by error:", err) - } + err = task.Execute() + task.Notify(err) } } case <-ctx.Done(): @@ -280,7 +300,7 @@ func (s *Master) segmentStatisticsLoop() { case ss := <-s.ssChan: controller.ComputeCloseTime(ss, s.kvBase) case <-ctx.Done(): - log.Print("server is closed, exit etcd leader loop") + log.Print("server is closed, exit segment statistics loop") return } } diff --git a/internal/master/partition_task.go b/internal/master/partition_task.go index f53f4e8303..733efb9dcb 100644 --- a/internal/master/partition_task.go +++ b/internal/master/partition_task.go @@ -59,7 +59,6 @@ func (t *createPartitionTask) Ts() (Timestamp, error) { func (t *createPartitionTask) Execute() error { if t.req == nil { - _ = t.Notify() return errors.New("null request") } @@ -67,7 +66,6 @@ func (t *createPartitionTask) Execute() error { collectionName := partitionName.CollectionName collectionMeta, err := t.mt.GetCollectionByName(collectionName) if err != nil { - _ = t.Notify() return err } @@ -75,18 +73,15 @@ func (t *createPartitionTask) Execute() error { collectionJson, err := json.Marshal(&collectionMeta) if err != nil { - _ = t.Notify() - return errors.New("marshal collection failed") + return err } collectionId := collectionMeta.Id err = (*t.kvBase).Save(partitionMetaPrefix+strconv.FormatInt(collectionId, 10), string(collectionJson)) if err != nil { - _ = t.Notify() - return errors.New("save collection failed") + return err } - _ = t.Notify() return nil } @@ -108,7 +103,6 @@ func (t *dropPartitionTask) Ts() (Timestamp, error) { func (t *dropPartitionTask) Execute() error { if t.req == nil { - _ = t.Notify() return errors.New("null request") } @@ -116,7 +110,6 @@ func (t *dropPartitionTask) Execute() error { collectionName := partitionName.CollectionName collectionMeta, err := t.mt.GetCollectionByName(collectionName) if err != nil { - _ = t.Notify() return err } @@ -127,18 +120,15 @@ func (t *dropPartitionTask) Execute() error { collectionJson, err := json.Marshal(&collectionMeta) if err != nil { - _ = t.Notify() - return errors.New("marshal collection failed") + return err } collectionId := collectionMeta.Id err = (*t.kvBase).Save(partitionMetaPrefix+strconv.FormatInt(collectionId, 10), string(collectionJson)) if err != nil { - _ = t.Notify() - return errors.New("save collection failed") + return err } - _ = t.Notify() return nil } @@ -160,7 +150,6 @@ func (t *hasPartitionTask) Ts() (Timestamp, error) { func (t *hasPartitionTask) Execute() error { if t.req == nil { - _ = t.Notify() return errors.New("null request") } @@ -173,7 +162,6 @@ func (t *hasPartitionTask) Execute() error { t.hasPartition = t.mt.HasPartition(collectionMeta.Id, partitionName.Tag) - _ = t.Notify() return nil } @@ -195,7 +183,6 @@ func (t *describePartitionTask) Ts() (Timestamp, error) { func (t *describePartitionTask) Execute() error { if t.req == nil { - _ = t.Notify() return errors.New("null request") } @@ -210,7 +197,6 @@ func (t *describePartitionTask) Execute() error { t.description = &description - _ = t.Notify() return nil } @@ -232,7 +218,6 @@ func (t *showPartitionTask) Ts() (Timestamp, error) { func (t *showPartitionTask) Execute() error { if t.req == nil { - _ = t.Notify() return errors.New("null request") } @@ -252,6 +237,5 @@ func (t *showPartitionTask) Execute() error { t.stringListResponse = &stringListResponse - _ = t.Notify() return nil } diff --git a/internal/master/scheduler.go b/internal/master/scheduler.go index 2c13fa3564..6dcc3bc1df 100644 --- a/internal/master/scheduler.go +++ b/internal/master/scheduler.go @@ -1,7 +1,9 @@ package master +import "math/rand" + type ddRequestScheduler struct { - reqQueue chan *task + reqQueue chan task scheduleTimeStamp Timestamp } @@ -9,12 +11,17 @@ func NewDDRequestScheduler() *ddRequestScheduler { const channelSize = 1024 rs := ddRequestScheduler{ - reqQueue: make(chan *task, channelSize), + reqQueue: make(chan task, channelSize), } return &rs } -func (rs *ddRequestScheduler) Enqueue(task *task) error { +func (rs *ddRequestScheduler) Enqueue(task task) error { rs.reqQueue <- task return nil } + +//TODO, allocGlobalId +func allocGlobalId() (UniqueID, error) { + return rand.Int63(), nil +} diff --git a/internal/master/task.go b/internal/master/task.go index d29dbd8bb1..0c6bab5f4c 100644 --- a/internal/master/task.go +++ b/internal/master/task.go @@ -13,7 +13,7 @@ import ( type baseTask struct { kvBase *kv.EtcdKV mt *metaTable - cv chan int + cv chan error } type task interface { @@ -21,27 +21,23 @@ type task interface { Ts() (Timestamp, error) Execute() error WaitToFinish(ctx context.Context) error - Notify() error - NotifyTimeout() error + Notify(err error) } -func (bt *baseTask) Notify() error { - bt.cv <- 0 - return nil -} - -func (bt *baseTask) NotifyTimeout() error { - bt.cv <- 0 - return errors.New("request timeout") +func (bt *baseTask) Notify(err error) { + bt.cv <- err } func (bt *baseTask) WaitToFinish(ctx context.Context) error { for { select { case <-ctx.Done(): - return nil - case <-bt.cv: - return nil + return errors.Errorf("context done") + case err, ok := <-bt.cv: + if !ok { + return errors.Errorf("notify chan closed") + } + return err } } }