diff --git a/internal/allocator/id.go b/internal/allocator/id.go index 29a49f141a..6bdd969d3c 100644 --- a/internal/allocator/id.go +++ b/internal/allocator/id.go @@ -35,6 +35,7 @@ type UniqueID = typeutil.UniqueID type IDAllocator struct { Allocator + etcdAddr []string masterAddress string masterClient types.MasterService @@ -46,7 +47,7 @@ type IDAllocator struct { PeerID UniqueID } -func NewIDAllocator(ctx context.Context, masterAddr string) (*IDAllocator, error) { +func NewIDAllocator(ctx context.Context, masterAddr string, etcdAddr []string) (*IDAllocator, error) { ctx1, cancel := context.WithCancel(ctx) a := &IDAllocator{ @@ -56,6 +57,7 @@ func NewIDAllocator(ctx context.Context, masterAddr string) (*IDAllocator, error Role: "IDAllocator", }, countPerRPC: IDCountPerRPC, + etcdAddr: etcdAddr, masterAddress: masterAddr, } a.TChan = &EmptyTicker{} @@ -69,7 +71,8 @@ func NewIDAllocator(ctx context.Context, masterAddr string) (*IDAllocator, error func (ia *IDAllocator) Start() error { var err error - ia.masterClient, err = msc.NewClient(ia.masterAddress, 20*time.Second) + + ia.masterClient, err = msc.NewClient(ia.masterAddress, ia.etcdAddr, 20*time.Second) if err != nil { panic(err) } diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go index e814b0f0f3..d00787a540 100644 --- a/internal/datanode/data_node.go +++ b/internal/datanode/data_node.go @@ -138,9 +138,8 @@ func (node *DataNode) SetDataServiceInterface(ds types.DataService) error { func (node *DataNode) Init() error { ctx := context.Background() - node.session = sessionutil.NewSession(ctx, []string{Params.EtcdAddress}, typeutil.DataNodeRole, - Params.IP+":"+strconv.Itoa(Params.Port), false) - node.session.Init() + node.session = sessionutil.NewSession(ctx, []string{Params.EtcdAddress}) + node.session.Init(typeutil.DataNodeRole, Params.IP+":"+strconv.Itoa(Params.Port), false) req := &datapb.RegisterNodeRequest{ Base: &commonpb.MsgBase{ diff --git a/internal/dataservice/server.go b/internal/dataservice/server.go index d6369d01a7..80939a168c 100644 --- a/internal/dataservice/server.go +++ b/internal/dataservice/server.go @@ -108,9 +108,8 @@ func (s *Server) SetMasterClient(masterClient types.MasterService) { } func (s *Server) Init() error { - s.session = sessionutil.NewSession(s.ctx, []string{Params.EtcdAddress}, typeutil.DataServiceRole, - Params.IP, true) - s.session.Init() + s.session = sessionutil.NewSession(s.ctx, []string{Params.EtcdAddress}) + s.session.Init(typeutil.DataServiceRole, Params.IP, true) return nil } diff --git a/internal/dataservice/server_test.go b/internal/dataservice/server_test.go index 7c95036a48..f0e6ab1df8 100644 --- a/internal/dataservice/server_test.go +++ b/internal/dataservice/server_test.go @@ -23,6 +23,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/retry" + "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/stretchr/testify/assert" "go.etcd.io/etcd/clientv3" @@ -787,7 +788,7 @@ func newTestServer(t *testing.T) *Server { etcdCli, err := initEtcd(Params.EtcdAddress) assert.Nil(t, err) - _, err = etcdCli.Delete(context.Background(), "/session", clientv3.WithPrefix()) + _, err = etcdCli.Delete(context.Background(), sessionutil.DefaultServiceRoot, clientv3.WithPrefix()) assert.Nil(t, err) svr, err := CreateServer(context.TODO(), factory) diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index 0ec31b52bb..e3493e31ce 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -70,7 +70,7 @@ func NewServer(ctx context.Context, factory msgstream.Factory) (*Server, error) msFactory: factory, grpcErrChan: make(chan error), newMasterServiceClient: func(s string) (types.MasterService, error) { - return msc.NewClient(s, 20*time.Second) + return msc.NewClient(s, []string{dn.Params.EtcdAddress}, 20*time.Second) }, newDataServiceClient: func(s string) types.DataService { return dsc.NewClient(Params.DataServiceAddress) diff --git a/internal/distributed/dataservice/service.go b/internal/distributed/dataservice/service.go index d16b1cb868..ab296e9794 100644 --- a/internal/distributed/dataservice/service.go +++ b/internal/distributed/dataservice/service.go @@ -68,7 +68,7 @@ func NewServer(ctx context.Context, factory msgstream.Factory) (*Server, error) cancel: cancel, grpcErrChan: make(chan error), newMasterServiceClient: func(s string) (types.MasterService, error) { - return msc.NewClient(s, 10*time.Second) + return msc.NewClient(s, []string{dataservice.Params.EtcdAddress}, 10*time.Second) }, } s.dataService, err = dataservice.CreateServer(s.ctx, factory) diff --git a/internal/distributed/masterservice/client/client.go b/internal/distributed/masterservice/client/client.go index 68c87f4674..3f51ef223c 100644 --- a/internal/distributed/masterservice/client/client.go +++ b/internal/distributed/masterservice/client/client.go @@ -13,12 +13,16 @@ package grpcmasterserviceclient import ( "context" + "fmt" + "path" "time" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/masterpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/internal/util/typeutil" otgrpc "github.com/opentracing-contrib/go-grpc" "github.com/opentracing/opentracing-go" "google.golang.org/grpc" @@ -30,30 +34,59 @@ type GrpcClient struct { conn *grpc.ClientConn //inner member - addr string - timeout time.Duration - grpcTimeout time.Duration - retry int + addr string + timeout time.Duration + reconnTry int + recallTry int + + sess *sessionutil.Session } -func NewClient(addr string, timeout time.Duration) (*GrpcClient, error) { +func getMasterServiceAddr(sess *sessionutil.Session) (string, error) { + msess, err := sess.GetSessions(typeutil.MasterServiceRole) + if err != nil { + return "", err + } + key := path.Join(sessionutil.DefaultServiceRoot, typeutil.MasterServiceRole) + var ms *sessionutil.Session + var ok bool + if ms, ok = msess[key]; !ok { + return "", fmt.Errorf("number of master service is incorrect, %d", len(msess)) + } + return ms.Address, nil +} + +func NewClient(addr string, etcdAddr []string, timeout time.Duration) (*GrpcClient, error) { + sess := sessionutil.NewSession(context.Background(), etcdAddr) + + if addr == "" { + var err error + if addr, err = getMasterServiceAddr(sess); err != nil { + return nil, err + } + } + return &GrpcClient{ - grpcClient: nil, - conn: nil, - addr: addr, - timeout: timeout, - grpcTimeout: time.Second * 5, - retry: 300, + grpcClient: nil, + conn: nil, + addr: addr, + timeout: timeout, + reconnTry: 300, + recallTry: 3, + sess: sess, }, nil } -func (c *GrpcClient) Init() error { +func (c *GrpcClient) reconnect() error { + addr, err := getMasterServiceAddr(c.sess) + if err != nil { + return nil + } tracer := opentracing.GlobalTracer() ctx, cancel := context.WithTimeout(context.Background(), c.timeout) defer cancel() - var err error - for i := 0; i < c.retry; i++ { - if c.conn, err = grpc.DialContext(ctx, c.addr, grpc.WithInsecure(), grpc.WithBlock(), + for i := 0; i < c.reconnTry; i++ { + if c.conn, err = grpc.DialContext(ctx, addr, grpc.WithInsecure(), grpc.WithBlock(), grpc.WithUnaryInterceptor( otgrpc.OpenTracingClientInterceptor(tracer)), grpc.WithStreamInterceptor( @@ -68,6 +101,27 @@ func (c *GrpcClient) Init() error { return nil } +func (c *GrpcClient) Init() error { + tracer := opentracing.GlobalTracer() + ctx, cancel := context.WithTimeout(context.Background(), c.timeout) + defer cancel() + var err error + for i := 0; i < c.reconnTry; i++ { + if c.conn, err = grpc.DialContext(ctx, c.addr, grpc.WithInsecure(), grpc.WithBlock(), + grpc.WithUnaryInterceptor( + otgrpc.OpenTracingClientInterceptor(tracer)), + grpc.WithStreamInterceptor( + otgrpc.OpenTracingStreamClientInterceptor(tracer))); err == nil { + break + } + } + if err != nil { + return c.reconnect() + } + c.grpcClient = masterpb.NewMasterServiceClient(c.conn) + return nil +} + func (c *GrpcClient) Start() error { return nil } @@ -75,91 +129,171 @@ func (c *GrpcClient) Stop() error { return c.conn.Close() } +func (c *GrpcClient) recall(caller func() (interface{}, error)) (interface{}, error) { + ret, err := caller() + if err == nil { + return ret, nil + } + for i := 0; i < c.recallTry; i++ { + err = c.reconnect() + if err == nil { + ret, err = caller() + if err == nil { + return ret, nil + } + } + } + return ret, err +} + // TODO: timeout need to be propagated through ctx func (c *GrpcClient) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { - return c.grpcClient.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) + }) + return ret.(*internalpb.ComponentStates), err } func (c *GrpcClient) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - return c.grpcClient.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + }) + return ret.(*milvuspb.StringResponse), err } //just define a channel, not used currently func (c *GrpcClient) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - return c.grpcClient.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + }) + return ret.(*milvuspb.StringResponse), err } //receive ddl from rpc and time tick from proxy service, and put them into this channel func (c *GrpcClient) GetDdChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - return c.grpcClient.GetDdChannel(ctx, &internalpb.GetDdChannelRequest{}) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetDdChannel(ctx, &internalpb.GetDdChannelRequest{}) + }) + return ret.(*milvuspb.StringResponse), err } //DDL request func (c *GrpcClient) CreateCollection(ctx context.Context, in *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { - return c.grpcClient.CreateCollection(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.CreateCollection(ctx, in) + }) + return ret.(*commonpb.Status), err } func (c *GrpcClient) DropCollection(ctx context.Context, in *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { - return c.grpcClient.DropCollection(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.DropCollection(ctx, in) + }) + return ret.(*commonpb.Status), err } + func (c *GrpcClient) HasCollection(ctx context.Context, in *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { - return c.grpcClient.HasCollection(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.HasCollection(ctx, in) + }) + return ret.(*milvuspb.BoolResponse), err } func (c *GrpcClient) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { - return c.grpcClient.DescribeCollection(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.DescribeCollection(ctx, in) + }) + return ret.(*milvuspb.DescribeCollectionResponse), err } func (c *GrpcClient) ShowCollections(ctx context.Context, in *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { - return c.grpcClient.ShowCollections(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.ShowCollections(ctx, in) + }) + return ret.(*milvuspb.ShowCollectionsResponse), err } - func (c *GrpcClient) CreatePartition(ctx context.Context, in *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { - return c.grpcClient.CreatePartition(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.CreatePartition(ctx, in) + }) + return ret.(*commonpb.Status), err } func (c *GrpcClient) DropPartition(ctx context.Context, in *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { - return c.grpcClient.DropPartition(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.DropPartition(ctx, in) + }) + return ret.(*commonpb.Status), err } func (c *GrpcClient) HasPartition(ctx context.Context, in *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { - return c.grpcClient.HasPartition(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.HasPartition(ctx, in) + }) + return ret.(*milvuspb.BoolResponse), err } func (c *GrpcClient) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { - return c.grpcClient.ShowPartitions(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.ShowPartitions(ctx, in) + }) + return ret.(*milvuspb.ShowPartitionsResponse), err } //index builder service func (c *GrpcClient) CreateIndex(ctx context.Context, in *milvuspb.CreateIndexRequest) (*commonpb.Status, error) { - return c.grpcClient.CreateIndex(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.CreateIndex(ctx, in) + }) + return ret.(*commonpb.Status), err } func (c *GrpcClient) DropIndex(ctx context.Context, in *milvuspb.DropIndexRequest) (*commonpb.Status, error) { - return c.grpcClient.DropIndex(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.DropIndex(ctx, in) + }) + return ret.(*commonpb.Status), err } func (c *GrpcClient) DescribeIndex(ctx context.Context, in *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) { - return c.grpcClient.DescribeIndex(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.DescribeIndex(ctx, in) + }) + return ret.(*milvuspb.DescribeIndexResponse), err } //global timestamp allocator func (c *GrpcClient) AllocTimestamp(ctx context.Context, in *masterpb.AllocTimestampRequest) (*masterpb.AllocTimestampResponse, error) { - return c.grpcClient.AllocTimestamp(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.AllocTimestamp(ctx, in) + }) + return ret.(*masterpb.AllocTimestampResponse), err } func (c *GrpcClient) AllocID(ctx context.Context, in *masterpb.AllocIDRequest) (*masterpb.AllocIDResponse, error) { - return c.grpcClient.AllocID(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.AllocID(ctx, in) + }) + return ret.(*masterpb.AllocIDResponse), err } // UpdateChannelTimeTick used to handle ChannelTimeTickMsg func (c *GrpcClient) UpdateChannelTimeTick(ctx context.Context, in *internalpb.ChannelTimeTickMsg) (*commonpb.Status, error) { - return c.grpcClient.UpdateChannelTimeTick(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.UpdateChannelTimeTick(ctx, in) + }) + return ret.(*commonpb.Status), err } //receiver time tick from proxy service, and put it into this channel func (c *GrpcClient) DescribeSegment(ctx context.Context, in *milvuspb.DescribeSegmentRequest) (*milvuspb.DescribeSegmentResponse, error) { - return c.grpcClient.DescribeSegment(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.DescribeSegment(ctx, in) + }) + return ret.(*milvuspb.DescribeSegmentResponse), err } func (c *GrpcClient) ShowSegments(ctx context.Context, in *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) { - return c.grpcClient.ShowSegments(ctx, in) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.ShowSegments(ctx, in) + }) + return ret.(*milvuspb.ShowSegmentsResponse), err } diff --git a/internal/distributed/masterservice/masterservice_test.go b/internal/distributed/masterservice/masterservice_test.go index f977415b49..12e4b56703 100644 --- a/internal/distributed/masterservice/masterservice_test.go +++ b/internal/distributed/masterservice/masterservice_test.go @@ -35,6 +35,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/retry" + "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/stretchr/testify/assert" "go.etcd.io/etcd/clientv3" @@ -86,7 +87,7 @@ func TestGrpcService(t *testing.T) { etcdCli, err := initEtcd(cms.Params.EtcdAddress) assert.Nil(t, err) - _, err = etcdCli.Delete(ctx, "/session", clientv3.WithPrefix()) + _, err = etcdCli.Delete(ctx, sessionutil.DefaultServiceRoot, clientv3.WithPrefix()) assert.Nil(t, err) err = core.Init() @@ -171,7 +172,7 @@ func TestGrpcService(t *testing.T) { svr.masterService.UpdateStateCode(internalpb.StateCode_Healthy) - cli, err := grpcmasterserviceclient.NewClient(Params.Address, 3*time.Second) + cli, err := grpcmasterserviceclient.NewClient(Params.Address, []string{cms.Params.EtcdAddress}, 3*time.Second) assert.Nil(t, err) err = cli.Init() @@ -871,7 +872,7 @@ func TestRun(t *testing.T) { etcdCli, err := initEtcd(cms.Params.EtcdAddress) assert.Nil(t, err) - _, err = etcdCli.Delete(ctx, "/session", clientv3.WithPrefix()) + _, err = etcdCli.Delete(ctx, sessionutil.DefaultServiceRoot, clientv3.WithPrefix()) assert.Nil(t, err) err = svr.Run() assert.Nil(t, err) diff --git a/internal/distributed/proxynode/service.go b/internal/distributed/proxynode/service.go index 0aab245e5b..ded087a135 100644 --- a/internal/distributed/proxynode/service.go +++ b/internal/distributed/proxynode/service.go @@ -184,7 +184,7 @@ func (s *Server) init() error { masterServiceAddr := Params.MasterAddress log.Debug("proxynode", zap.String("master address", masterServiceAddr)) timeout := 3 * time.Second - s.masterServiceClient, err = grpcmasterserviceclient.NewClient(masterServiceAddr, timeout) + s.masterServiceClient, err = grpcmasterserviceclient.NewClient(masterServiceAddr, []string{proxynode.Params.EtcdAddress}, timeout) if err != nil { return err } diff --git a/internal/distributed/querynode/service.go b/internal/distributed/querynode/service.go index c768dfc919..56f88011f2 100644 --- a/internal/distributed/querynode/service.go +++ b/internal/distributed/querynode/service.go @@ -130,7 +130,7 @@ func (s *Server) init() error { log.Debug("Master service", zap.String("address", addr)) log.Debug("Init master service client ...") - masterService, err := msc.NewClient(addr, 20*time.Second) + masterService, err := msc.NewClient(addr, []string{qn.Params.EtcdAddress}, 20*time.Second) if err != nil { panic(err) } diff --git a/internal/distributed/queryservice/service.go b/internal/distributed/queryservice/service.go index 492c23201a..65c83138c3 100644 --- a/internal/distributed/queryservice/service.go +++ b/internal/distributed/queryservice/service.go @@ -107,7 +107,7 @@ func (s *Server) init() error { log.Debug("Master service", zap.String("address", Params.MasterAddress)) log.Debug("Init master service client ...") - masterService, err := msc.NewClient(Params.MasterAddress, 20*time.Second) + masterService, err := msc.NewClient(Params.MasterAddress, []string{qs.Params.EtcdAddress}, 20*time.Second) if err != nil { panic(err) diff --git a/internal/indexnode/indexnode.go b/internal/indexnode/indexnode.go index 1ed458be6c..1361b2e4b5 100644 --- a/internal/indexnode/indexnode.go +++ b/internal/indexnode/indexnode.go @@ -79,9 +79,8 @@ func NewIndexNode(ctx context.Context) (*IndexNode, error) { func (i *IndexNode) Init() error { ctx := context.Background() - i.session = sessionutil.NewSession(ctx, []string{Params.EtcdAddress}, typeutil.IndexNodeRole, - Params.IP+":"+strconv.Itoa(Params.Port), false) - i.session.Init() + i.session = sessionutil.NewSession(ctx, []string{Params.EtcdAddress}) + i.session.Init(typeutil.IndexNodeRole, Params.IP+":"+strconv.Itoa(Params.Port), false) err := funcutil.WaitForComponentHealthy(ctx, i.serviceClient, "IndexService", 1000000, time.Millisecond*200) if err != nil { diff --git a/internal/indexservice/indexservice.go b/internal/indexservice/indexservice.go index 516af34f84..73fb4c82bd 100644 --- a/internal/indexservice/indexservice.go +++ b/internal/indexservice/indexservice.go @@ -89,9 +89,8 @@ func (i *IndexService) Init() error { log.Debug("indexservice", zap.String("etcd address", Params.EtcdAddress)) ctx := context.Background() - i.session = sessionutil.NewSession(ctx, []string{Params.EtcdAddress}, typeutil.IndexServiceRole, - Params.Address, true) - i.session.Init() + i.session = sessionutil.NewSession(ctx, []string{Params.EtcdAddress}) + i.session.Init(typeutil.IndexServiceRole, Params.Address, true) connectEtcdFn := func() error { etcdClient, err := clientv3.New(clientv3.Config{Endpoints: []string{Params.EtcdAddress}}) diff --git a/internal/masterservice/master_service.go b/internal/masterservice/master_service.go index 5d86cf3683..f3755aba38 100644 --- a/internal/masterservice/master_service.go +++ b/internal/masterservice/master_service.go @@ -821,9 +821,8 @@ func (c *Core) BuildIndex(segID typeutil.UniqueID, field *schemapb.FieldSchema, func (c *Core) Init() error { var initError error = nil c.initOnce.Do(func() { - c.session = sessionutil.NewSession(c.ctx, []string{Params.EtcdAddress}, typeutil.MasterServiceRole, - Params.Address, true) - c.session.Init() + c.session = sessionutil.NewSession(c.ctx, []string{Params.EtcdAddress}) + c.session.Init(typeutil.MasterServiceRole, Params.Address, true) connectEtcdFn := func() error { if c.etcdCli, initError = clientv3.New(clientv3.Config{Endpoints: []string{Params.EtcdAddress}, DialTimeout: 5 * time.Second}); initError != nil { diff --git a/internal/masterservice/master_service_test.go b/internal/masterservice/master_service_test.go index 33995be4b2..56f737c858 100644 --- a/internal/masterservice/master_service_test.go +++ b/internal/masterservice/master_service_test.go @@ -16,6 +16,7 @@ import ( "encoding/json" "fmt" "math/rand" + "path" "sync" "testing" "time" @@ -218,10 +219,10 @@ func TestMasterService(t *testing.T) { etcdCli, err := clientv3.New(clientv3.Config{Endpoints: []string{Params.EtcdAddress}, DialTimeout: 5 * time.Second}) assert.Nil(t, err) - _, err = etcdCli.Delete(ctx, ProxyNodeSessionPrefix, clientv3.WithPrefix()) + _, err = etcdCli.Delete(ctx, sessionutil.DefaultServiceRoot, clientv3.WithPrefix()) assert.Nil(t, err) defer func() { - _, _ = etcdCli.Delete(ctx, ProxyNodeSessionPrefix, clientv3.WithPrefix()) + _, _ = etcdCli.Delete(ctx, sessionutil.DefaultServiceRoot, clientv3.WithPrefix()) }() pm := &proxyMock{ @@ -253,14 +254,6 @@ func TestMasterService(t *testing.T) { err = core.SetQueryService(qm) assert.Nil(t, err) - //TODO initialize master's session manager before core init - /* - self := sessionutil.NewSession("masterservice", funcutil.GetLocalIP()+":"+strconv.Itoa(53100), true) - sm := sessionutil.NewSessionManager(ctx, Params.EtcdAddress, Params.MetaRootPath, self) - sm.Init() - sessionutil.SetGlobalSessionManager(sm) - */ - err = core.Init() assert.Nil(t, err) @@ -1458,9 +1451,9 @@ func TestMasterService(t *testing.T) { s2, err := json.Marshal(&p2) assert.Nil(t, err) - _, err = core.etcdCli.Put(ctx2, ProxyNodeSessionPrefix+"-1", string(s1)) + _, err = core.etcdCli.Put(ctx2, path.Join(sessionutil.DefaultServiceRoot, typeutil.ProxyNodeRole)+"-1", string(s1)) assert.Nil(t, err) - _, err = core.etcdCli.Put(ctx2, ProxyNodeSessionPrefix+"-2", string(s2)) + _, err = core.etcdCli.Put(ctx2, path.Join(sessionutil.DefaultServiceRoot, typeutil.ProxyNodeRole)+"-2", string(s2)) assert.Nil(t, err) time.Sleep(time.Second) diff --git a/internal/masterservice/timeticksync.go b/internal/masterservice/timeticksync.go index 50bec1fb70..c727a4fe7e 100644 --- a/internal/masterservice/timeticksync.go +++ b/internal/masterservice/timeticksync.go @@ -15,6 +15,7 @@ import ( "context" "encoding/json" "fmt" + "path" "sync" "github.com/coreos/etcd/mvcc/mvccpb" @@ -38,9 +39,6 @@ type timetickSync struct { sendChan chan map[typeutil.UniqueID]*internalpb.ChannelTimeTickMsg } -// ProxyNodeSessionPrefix used for etcd watch -const ProxyNodeSessionPrefix = "session/proxynode" - func newTimeTickSync(ctx context.Context, factory msgstream.Factory, cli *clientv3.Client) (*timetickSync, error) { tss := timetickSync{ lock: sync.Mutex{}, @@ -109,7 +107,8 @@ func (t *timetickSync) UpdateTimeTick(in *internalpb.ChannelTimeTickMsg) error { // StartWatch watch proxy node change and process all channels' timetick msg func (t *timetickSync) StartWatch() { - rch := t.etcdCli.Watch(t.ctx, ProxyNodeSessionPrefix, clientv3.WithPrefix(), clientv3.WithCreatedNotify()) + proxyNodePrefix := path.Join(sessionutil.DefaultServiceRoot, typeutil.ProxyNodeRole) + rch := t.etcdCli.Watch(t.ctx, proxyNodePrefix, clientv3.WithPrefix(), clientv3.WithCreatedNotify()) for { select { case <-t.ctx.Done(): diff --git a/internal/proxynode/proxy_node.go b/internal/proxynode/proxy_node.go index dbea072d83..e43c8dc788 100644 --- a/internal/proxynode/proxy_node.go +++ b/internal/proxynode/proxy_node.go @@ -89,9 +89,8 @@ func (node *ProxyNode) Init() error { // todo wait for proxyservice state changed to Healthy ctx := context.Background() - node.session = sessionutil.NewSession(ctx, []string{Params.EtcdAddress}, typeutil.ProxyNodeRole, - Params.NetworkAddress, false) - node.session.Init() + node.session = sessionutil.NewSession(ctx, []string{Params.EtcdAddress}) + node.session.Init(typeutil.ProxyNodeRole, Params.NetworkAddress, false) err := funcutil.WaitForComponentHealthy(ctx, node.proxyService, "ProxyService", 1000000, time.Millisecond*200) if err != nil { @@ -177,7 +176,7 @@ func (node *ProxyNode) Init() error { log.Debug("create query message stream ...") masterAddr := Params.MasterAddress - idAllocator, err := allocator.NewIDAllocator(node.ctx, masterAddr) + idAllocator, err := allocator.NewIDAllocator(node.ctx, masterAddr, []string{Params.EtcdAddress}) if err != nil { return err diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index eb614975a9..e86e6fd722 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -121,9 +121,8 @@ func NewQueryNodeWithoutID(ctx context.Context, factory msgstream.Factory) *Quer func (node *QueryNode) Init() error { ctx := context.Background() - node.session = sessionutil.NewSession(ctx, []string{Params.EtcdAddress}, typeutil.QueryNodeRole, - Params.QueryNodeIP+":"+strconv.FormatInt(Params.QueryNodePort, 10), false) - node.session.Init() + node.session = sessionutil.NewSession(ctx, []string{Params.EtcdAddress}) + node.session.Init(typeutil.QueryNodeRole, Params.QueryNodeIP+":"+strconv.FormatInt(Params.QueryNodePort, 10), false) C.SegcoreInit() registerReq := &queryPb.RegisterNodeRequest{ diff --git a/internal/queryservice/queryservice.go b/internal/queryservice/queryservice.go index 7d80d93e09..028b15e660 100644 --- a/internal/queryservice/queryservice.go +++ b/internal/queryservice/queryservice.go @@ -59,9 +59,8 @@ type QueryService struct { func (qs *QueryService) Init() error { ctx := context.Background() - qs.session = sessionutil.NewSession(ctx, []string{Params.EtcdAddress}, typeutil.QueryServiceRole, - Params.Address, true) - qs.session.Init() + qs.session = sessionutil.NewSession(ctx, []string{Params.EtcdAddress}) + qs.session.Init(typeutil.QueryServiceRole, Params.Address, true) return nil } diff --git a/internal/util/sessionutil/session_util.go b/internal/util/sessionutil/session_util.go index 11c85caa88..03a3267622 100644 --- a/internal/util/sessionutil/session_util.go +++ b/internal/util/sessionutil/session_util.go @@ -16,10 +16,10 @@ import ( "go.uber.org/zap" ) -const defaultServiceRoot = "/session/" -const defaultIDKey = "id" -const defaultRetryTimes = 30 -const defaultTTL = 10 +const DefaultServiceRoot = "/session/" +const DefaultIDKey = "id" +const DefaultRetryTimes = 30 +const DefaultTTL = 10 // Session is a struct to store service's session, including ServerID, ServerName, // Address. @@ -38,14 +38,11 @@ type Session struct { // NewSession is a helper to build Session object.LeaseID will be assigned after // registeration. -func NewSession(ctx context.Context, etcdAddress []string, serverName, address string, exclusive bool) *Session { +func NewSession(ctx context.Context, etcdAddress []string) *Session { ctx, cancel := context.WithCancel(ctx) session := &Session{ - ctx: ctx, - ServerName: serverName, - Address: address, - Exclusive: exclusive, - cancel: cancel, + ctx: ctx, + cancel: cancel, } connectEtcdFn := func() error { @@ -65,7 +62,10 @@ func NewSession(ctx context.Context, etcdAddress []string, serverName, address s // Init will initialize base struct in the SessionManager, including getServerID, // and process keepAliveResponse -func (s *Session) Init() { +func (s *Session) Init(serverName, address string, exclusive bool) { + s.ServerName = serverName + s.Address = address + s.Exclusive = exclusive s.checkIDExist() serverID, err := s.getServerID() if err != nil { @@ -82,23 +82,23 @@ func (s *Session) Init() { // GetServerID gets id from etcd with key: metaRootPath + "/services/id" // Each server get ServerID and add one to id. func (s *Session) getServerID() (int64, error) { - return s.getServerIDWithKey(defaultIDKey, defaultRetryTimes) + return s.getServerIDWithKey(DefaultIDKey, DefaultRetryTimes) } func (s *Session) checkIDExist() { s.etcdCli.Txn(s.ctx).If( clientv3.Compare( - clientv3.Version(path.Join(defaultServiceRoot, defaultIDKey)), + clientv3.Version(path.Join(DefaultServiceRoot, DefaultIDKey)), "=", 0)). - Then(clientv3.OpPut(path.Join(defaultServiceRoot, defaultIDKey), "1")).Commit() + Then(clientv3.OpPut(path.Join(DefaultServiceRoot, DefaultIDKey), "1")).Commit() } func (s *Session) getServerIDWithKey(key string, retryTimes int) (int64, error) { res := int64(0) getServerIDWithKeyFn := func() error { - getResp, err := s.etcdCli.Get(s.ctx, path.Join(defaultServiceRoot, key)) + getResp, err := s.etcdCli.Get(s.ctx, path.Join(DefaultServiceRoot, key)) if err != nil { return nil } @@ -113,10 +113,10 @@ func (s *Session) getServerIDWithKey(key string, retryTimes int) (int64, error) } txnResp, err := s.etcdCli.Txn(s.ctx).If( clientv3.Compare( - clientv3.Value(path.Join(defaultServiceRoot, defaultIDKey)), + clientv3.Value(path.Join(DefaultServiceRoot, DefaultIDKey)), "=", value)). - Then(clientv3.OpPut(path.Join(defaultServiceRoot, defaultIDKey), strconv.FormatInt(valueInt+1, 10))).Commit() + Then(clientv3.OpPut(path.Join(DefaultServiceRoot, DefaultIDKey), strconv.FormatInt(valueInt+1, 10))).Commit() if err != nil { return err } @@ -149,7 +149,7 @@ func (s *Session) getServerIDWithKey(key string, retryTimes int) (int64, error) func (s *Session) registerService() (<-chan *clientv3.LeaseKeepAliveResponse, error) { var ch <-chan *clientv3.LeaseKeepAliveResponse registerFn := func() error { - resp, err := s.etcdCli.Grant(s.ctx, defaultTTL) + resp, err := s.etcdCli.Grant(s.ctx, DefaultTTL) if err != nil { log.Error("register service", zap.Error(err)) return err @@ -167,10 +167,10 @@ func (s *Session) registerService() (<-chan *clientv3.LeaseKeepAliveResponse, er } txnResp, err := s.etcdCli.Txn(s.ctx).If( clientv3.Compare( - clientv3.Version(path.Join(defaultServiceRoot, key)), + clientv3.Version(path.Join(DefaultServiceRoot, key)), "=", 0)). - Then(clientv3.OpPut(path.Join(defaultServiceRoot, key), string(sessionJSON), clientv3.WithLease(resp.ID))).Commit() + Then(clientv3.OpPut(path.Join(DefaultServiceRoot, key), string(sessionJSON), clientv3.WithLease(resp.ID))).Commit() if err != nil { fmt.Printf("compare and swap error %s\n. maybe the key has registered", err) @@ -188,7 +188,7 @@ func (s *Session) registerService() (<-chan *clientv3.LeaseKeepAliveResponse, er } return nil } - err := retry.Retry(defaultRetryTimes, time.Millisecond*200, registerFn) + err := retry.Retry(DefaultRetryTimes, time.Millisecond*200, registerFn) if err != nil { return ch, nil } @@ -221,7 +221,7 @@ func (s *Session) processKeepAliveResponse(ch <-chan *clientv3.LeaseKeepAliveRes // GetSessions will get all sessions registered in etcd. func (s *Session) GetSessions(prefix string) (map[string]*Session, error) { res := make(map[string]*Session) - key := path.Join(defaultServiceRoot, prefix) + key := path.Join(DefaultServiceRoot, prefix) resp, err := s.etcdCli.Get(s.ctx, key, clientv3.WithPrefix(), clientv3.WithSort(clientv3.SortByKey, clientv3.SortAscend)) if err != nil { @@ -245,7 +245,7 @@ func (s *Session) GetSessions(prefix string) (map[string]*Session, error) { func (s *Session) WatchServices(prefix string) (addChannel <-chan *Session, delChannel <-chan *Session) { addCh := make(chan *Session, 10) delCh := make(chan *Session, 10) - rch := s.etcdCli.Watch(s.ctx, path.Join(defaultServiceRoot, prefix), clientv3.WithPrefix(), clientv3.WithPrevKV()) + rch := s.etcdCli.Watch(s.ctx, path.Join(DefaultServiceRoot, prefix), clientv3.WithPrefix(), clientv3.WithPrevKV()) go func() { for { select { diff --git a/internal/util/sessionutil/session_util_test.go b/internal/util/sessionutil/session_util_test.go index f30e51b1cd..5a897d6fb0 100644 --- a/internal/util/sessionutil/session_util_test.go +++ b/internal/util/sessionutil/session_util_test.go @@ -26,7 +26,7 @@ func TestGetServerIDConcurrently(t *testing.T) { cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}}) assert.Nil(t, err) etcdKV := etcdkv.NewEtcdKV(cli, "") - _, err = cli.Delete(ctx, "/session", clientv3.WithPrefix()) + _, err = cli.Delete(ctx, DefaultServiceRoot, clientv3.WithPrefix()) assert.Nil(t, err) defer etcdKV.Close() @@ -35,7 +35,7 @@ func TestGetServerIDConcurrently(t *testing.T) { var wg sync.WaitGroup var muList sync.Mutex = sync.Mutex{} - s := NewSession(ctx, []string{etcdAddr}, "test", "testAddr", false) + s := NewSession(ctx, []string{etcdAddr}) res := make([]int64, 0) getIDFunc := func() { @@ -71,15 +71,16 @@ func TestInit(t *testing.T) { cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}}) assert.Nil(t, err) etcdKV := etcdkv.NewEtcdKV(cli, "") - _, err = cli.Delete(ctx, "/session", clientv3.WithPrefix()) + _, err = cli.Delete(ctx, DefaultServiceRoot, clientv3.WithPrefix()) assert.Nil(t, err) defer etcdKV.Close() defer etcdKV.RemoveWithPrefix("") - s := NewSession(ctx, []string{etcdAddr}, "test", "testAddr", false) - assert.NotEqual(t, 0, s.leaseID) - assert.NotEqual(t, 0, s.ServerID) + s := NewSession(ctx, []string{etcdAddr}) + s.Init("test", "testAddr", false) + assert.NotEqual(t, int64(0), s.leaseID) + assert.NotEqual(t, int64(0), s.ServerID) } func TestUpdateSessions(t *testing.T) { @@ -94,7 +95,7 @@ func TestUpdateSessions(t *testing.T) { cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}}) assert.Nil(t, err) etcdKV := etcdkv.NewEtcdKV(cli, "") - _, err = cli.Delete(ctx, "/session", clientv3.WithPrefix()) + _, err = cli.Delete(ctx, DefaultServiceRoot, clientv3.WithPrefix()) assert.Nil(t, err) defer etcdKV.Close() @@ -103,7 +104,7 @@ func TestUpdateSessions(t *testing.T) { var wg sync.WaitGroup var muList sync.Mutex = sync.Mutex{} - s := NewSession(ctx, []string{etcdAddr}, "test", "testAddr", false) + s := NewSession(ctx, []string{etcdAddr}) sessions, err := s.GetSessions("test") assert.Nil(t, err) @@ -113,8 +114,8 @@ func TestUpdateSessions(t *testing.T) { sList := []*Session{} getIDFunc := func() { - singleS := NewSession(ctx, []string{etcdAddr}, "test", "testAddr", false) - singleS.Init() + singleS := NewSession(ctx, []string{etcdAddr}) + singleS.Init("test", "testAddr", false) muList.Lock() sList = append(sList, singleS) muList.Unlock()