diff --git a/internal/distributed/masterservice/masterservice_test.go b/internal/distributed/masterservice/masterservice_test.go index c7d8a7c451..213ed34d0b 100644 --- a/internal/distributed/masterservice/masterservice_test.go +++ b/internal/distributed/masterservice/masterservice_test.go @@ -961,7 +961,7 @@ func TestRun(t *testing.T) { svr.newIndexServiceClient = func(s, etcdAddress, metaRootPath string, timeout time.Duration) types.IndexService { return &mockIndex{} } - svr.newQueryServiceClient = func(s string) (types.QueryService, error) { + svr.newQueryServiceClient = func(s, metaRootPath, etcdAddress string) (types.QueryService, error) { return &mockQuery{}, nil } diff --git a/internal/distributed/masterservice/server.go b/internal/distributed/masterservice/server.go index 31a3d3c49b..56d4329cf1 100644 --- a/internal/distributed/masterservice/server.go +++ b/internal/distributed/masterservice/server.go @@ -63,7 +63,7 @@ type Server struct { newProxyServiceClient func(string) types.ProxyService newIndexServiceClient func(string, string, string, time.Duration) types.IndexService newDataServiceClient func(string, string, string, time.Duration) types.DataService - newQueryServiceClient func(string) (types.QueryService, error) + newQueryServiceClient func(string, string, string) (types.QueryService, error) closer io.Closer } @@ -113,7 +113,7 @@ func (s *Server) setClient() { } return dsClient } - s.newIndexServiceClient = func(s, etcdAddress, metaRootPath string, timeout time.Duration) types.IndexService { + s.newIndexServiceClient = func(s, metaRootPath, etcdAddress string, timeout time.Duration) types.IndexService { isClient := isc.NewClient(s, metaRootPath, []string{etcdAddress}, timeout) if err := isClient.Init(); err != nil { panic(err) @@ -123,8 +123,8 @@ func (s *Server) setClient() { } return isClient } - s.newQueryServiceClient = func(s string) (types.QueryService, error) { - qsClient, err := qsc.NewClient(s, 5*time.Second) + s.newQueryServiceClient = func(s, metaRootPath, etcdAddress string) (types.QueryService, error) { + qsClient, err := qsc.NewClient(context.Background(), s, metaRootPath, []string{etcdAddress}, 5*time.Second) if err != nil { panic(err) } @@ -214,7 +214,7 @@ func (s *Server) init() error { } if s.newQueryServiceClient != nil { log.Debug("query service", zap.String("address", Params.QueryServiceAddress)) - queryService, _ := s.newQueryServiceClient(Params.QueryServiceAddress) + queryService, _ := s.newQueryServiceClient(Params.QueryServiceAddress, cms.Params.MetaRootPath, cms.Params.EtcdAddress) if err := s.masterService.SetQueryService(queryService); err != nil { panic(err) } diff --git a/internal/distributed/proxynode/service.go b/internal/distributed/proxynode/service.go index f77a01c047..cf19b92d1a 100644 --- a/internal/distributed/proxynode/service.go +++ b/internal/distributed/proxynode/service.go @@ -227,7 +227,7 @@ func (s *Server) init() error { queryServiceAddr := Params.QueryServiceAddress log.Debug("proxynode", zap.String("query server address", queryServiceAddr)) - s.queryServiceClient, err = grpcqueryserviceclient.NewClient(queryServiceAddr, timeout) + s.queryServiceClient, err = grpcqueryserviceclient.NewClient(ctx, queryServiceAddr, proxynode.Params.MetaRootPath, []string{proxynode.Params.EtcdAddress}, timeout) if err != nil { return err } diff --git a/internal/distributed/querynode/service.go b/internal/distributed/querynode/service.go index 32931fab8e..44166ab0e1 100644 --- a/internal/distributed/querynode/service.go +++ b/internal/distributed/querynode/service.go @@ -106,7 +106,7 @@ func (s *Server) init() error { // --- QueryService --- log.Debug("QueryService", zap.String("address", Params.QueryServiceAddress)) log.Debug("Init Query service client ...") - queryService, err := qsc.NewClient(Params.QueryServiceAddress, 20*time.Second) + queryService, err := qsc.NewClient(ctx, Params.QueryServiceAddress, qn.Params.MetaRootPath, []string{qn.Params.EtcdAddress}, 20*time.Second) if err != nil { panic(err) } diff --git a/internal/distributed/queryservice/client/client.go b/internal/distributed/queryservice/client/client.go index 38bd9ab185..be47c4d54c 100644 --- a/internal/distributed/queryservice/client/client.go +++ b/internal/distributed/queryservice/client/client.go @@ -13,8 +13,12 @@ package grpcqueryserviceclient import ( "context" + "fmt" "time" + "github.com/milvus-io/milvus/internal/util/retry" + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/internal/util/typeutil" otgrpc "github.com/opentracing-contrib/go-grpc" "github.com/opentracing/opentracing-go" "go.uber.org/zap" @@ -28,22 +32,43 @@ import ( ) type Client struct { + ctx context.Context grpcClient querypb.QueryServiceClient conn *grpc.ClientConn - addr string - timeout time.Duration - retry int + addr string + timeout time.Duration + sess *sessionutil.Session + reconnTry int + recallTry int } -func NewClient(address string, timeout time.Duration) (*Client, error) { +func getQueryServiceAddress(sess *sessionutil.Session) (string, error) { + key := typeutil.QueryServiceRole + msess, _, err := sess.GetSessions(key) + if err != nil { + return "", err + } + ms, ok := msess[key] + if !ok { + return "", fmt.Errorf("number of master service is incorrect, %d", len(msess)) + } + return ms.Address, nil +} + +// NewClient creates a client for queryservice grpc call. +func NewClient(ctx context.Context, address, metaRootPath string, etcdAddr []string, timeout time.Duration) (*Client, error) { + sess := sessionutil.NewSession(context.Background(), metaRootPath, etcdAddr) return &Client{ + ctx: ctx, grpcClient: nil, conn: nil, addr: address, timeout: timeout, - retry: 300, + reconnTry: 10, + recallTry: 3, + sess: sess, }, nil } @@ -51,25 +76,84 @@ func (c *Client) Init() error { tracer := opentracing.GlobalTracer() ctx, cancel := context.WithTimeout(context.Background(), c.timeout) defer cancel() - var err error - for i := 0; i < c.retry; i++ { - if c.conn, err = grpc.DialContext(ctx, c.addr, grpc.WithInsecure(), grpc.WithBlock(), - grpc.WithUnaryInterceptor( - otgrpc.OpenTracingClientInterceptor(tracer)), - grpc.WithStreamInterceptor( - otgrpc.OpenTracingStreamClientInterceptor(tracer))); err == nil { - break + if c.addr != "" { + connectGrpcFunc := func() error { + log.Debug("queryservice connect ", zap.String("address", c.addr)) + conn, err := grpc.DialContext(ctx, c.addr, grpc.WithInsecure(), grpc.WithBlock(), + grpc.WithUnaryInterceptor( + otgrpc.OpenTracingClientInterceptor(tracer)), + grpc.WithStreamInterceptor( + otgrpc.OpenTracingStreamClientInterceptor(tracer))) + if err != nil { + return err + } + c.conn = conn + return nil } - } - if err != nil { - return err + err := retry.Retry(100000, time.Millisecond*200, connectGrpcFunc) + if err != nil { + return err + } + } else { + return c.reconnect() } c.grpcClient = querypb.NewQueryServiceClient(c.conn) log.Debug("connected to queryService", zap.String("queryService", c.addr)) return nil } +func (c *Client) reconnect() error { + tracer := opentracing.GlobalTracer() + var err error + getQueryServiceAddressFn := func() error { + c.addr, err = getQueryServiceAddress(c.sess) + if err != nil { + return err + } + return nil + } + err = retry.Retry(c.reconnTry, 3*time.Second, getQueryServiceAddressFn) + if err != nil { + return err + } + connectGrpcFunc := func() error { + log.Debug("QueryService connect ", zap.String("address", c.addr)) + conn, err := grpc.DialContext(c.ctx, c.addr, grpc.WithInsecure(), grpc.WithBlock(), + grpc.WithUnaryInterceptor( + otgrpc.OpenTracingClientInterceptor(tracer)), + grpc.WithStreamInterceptor( + otgrpc.OpenTracingStreamClientInterceptor(tracer))) + if err != nil { + return err + } + c.conn = conn + return nil + } + + err = retry.Retry(c.reconnTry, 500*time.Millisecond, connectGrpcFunc) + if err != nil { + return err + } + c.grpcClient = querypb.NewQueryServiceClient(c.conn) + return nil +} +func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error) { + ret, err := caller() + if err == nil { + return ret, nil + } + for i := 0; i < c.recallTry; i++ { + err = c.reconnect() + if err == nil { + ret, err = caller() + if err == nil { + return ret, nil + } + } + } + return ret, err +} func (c *Client) Start() error { return nil @@ -85,53 +169,92 @@ func (c *Client) Register() error { } func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { - return c.grpcClient.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) + }) + return ret.(*internalpb.ComponentStates), err } func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - return c.grpcClient.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + }) + return ret.(*milvuspb.StringResponse), err } func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - return c.grpcClient.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + }) + return ret.(*milvuspb.StringResponse), err } func (c *Client) RegisterNode(ctx context.Context, req *querypb.RegisterNodeRequest) (*querypb.RegisterNodeResponse, error) { - return c.grpcClient.RegisterNode(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.RegisterNode(ctx, req) + }) + return ret.(*querypb.RegisterNodeResponse), err } func (c *Client) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { - return c.grpcClient.ShowCollections(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.ShowCollections(ctx, req) + }) + return ret.(*querypb.ShowCollectionsResponse), err } func (c *Client) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) { - return c.grpcClient.LoadCollection(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.LoadCollection(ctx, req) + }) + return ret.(*commonpb.Status), err } func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { - return c.grpcClient.ReleaseCollection(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.ReleaseCollection(ctx, req) + }) + return ret.(*commonpb.Status), err } func (c *Client) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { - return c.grpcClient.ShowPartitions(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.ShowPartitions(ctx, req) + }) + return ret.(*querypb.ShowPartitionsResponse), err } func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { - return c.grpcClient.LoadPartitions(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.LoadPartitions(ctx, req) + }) + return ret.(*commonpb.Status), err } func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { - return c.grpcClient.ReleasePartitions(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.ReleasePartitions(ctx, req) + }) + return ret.(*commonpb.Status), err } func (c *Client) CreateQueryChannel(ctx context.Context) (*querypb.CreateQueryChannelResponse, error) { - return c.grpcClient.CreateQueryChannel(ctx, &querypb.CreateQueryChannelRequest{}) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.CreateQueryChannel(ctx, &querypb.CreateQueryChannelRequest{}) + }) + return ret.(*querypb.CreateQueryChannelResponse), err } func (c *Client) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) { - return c.grpcClient.GetPartitionStates(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetPartitionStates(ctx, req) + }) + return ret.(*querypb.GetPartitionStatesResponse), err } func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { - return c.grpcClient.GetSegmentInfo(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetSegmentInfo(ctx, req) + }) + return ret.(*querypb.GetSegmentInfoResponse), err }