From c982e368dd2763ef875b78d402d5b6425a3cf26b Mon Sep 17 00:00:00 2001 From: groot Date: Mon, 27 Sep 2021 19:00:22 +0800 Subject: [PATCH] Add unittest for distributed/datanode (#8695) Signed-off-by: yhmo --- .../distributed/datacoord/client/client.go | 15 +- .../datacoord/client/client_test.go | 182 ++++++++++++++++++ .../distributed/datanode/client/client.go | 17 +- .../datanode/client/client_test.go | 106 ++++++++++ 4 files changed, 313 insertions(+), 7 deletions(-) create mode 100644 internal/distributed/datacoord/client/client_test.go create mode 100644 internal/distributed/datanode/client/client_test.go diff --git a/internal/distributed/datacoord/client/client.go b/internal/distributed/datacoord/client/client.go index 16425127f1..4623024994 100644 --- a/internal/distributed/datacoord/client/client.go +++ b/internal/distributed/datacoord/client/client.go @@ -45,9 +45,15 @@ type Client struct { sess *sessionutil.Session addr string + + getGrpcClient func() (datapb.DataCoordClient, error) } -func (c *Client) getGrpcClient() (datapb.DataCoordClient, error) { +func (c *Client) setGetGrpcClientFunc() { + c.getGrpcClient = c.getGrpcClientFunc +} + +func (c *Client) getGrpcClientFunc() (datapb.DataCoordClient, error) { c.grpcClientMtx.RLock() if c.grpcClient != nil { defer c.grpcClientMtx.RUnlock() @@ -106,11 +112,14 @@ func NewClient(ctx context.Context, metaRoot string, etcdEndpoints []string) (*C return nil, err } ctx, cancel := context.WithCancel(ctx) - return &Client{ + client := &Client{ ctx: ctx, cancel: cancel, sess: sess, - }, nil + } + + client.setGetGrpcClientFunc() + return client, nil } func (c *Client) Init() error { diff --git a/internal/distributed/datacoord/client/client_test.go b/internal/distributed/datacoord/client/client_test.go new file mode 100644 index 0000000000..75cf1cfb8a --- /dev/null +++ b/internal/distributed/datacoord/client/client_test.go @@ -0,0 +1,182 @@ +package grpcdatacoordclient + +import ( + "context" + "errors" + "testing" + + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/proxy" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" +) + +type MockDataCoordClient struct { + err error +} + +func (m *MockDataCoordClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { + return &internalpb.ComponentStates{}, m.err +} + +func (m *MockDataCoordClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.err +} + +func (m *MockDataCoordClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.err +} + +func (m *MockDataCoordClient) Flush(ctx context.Context, in *datapb.FlushRequest, opts ...grpc.CallOption) (*datapb.FlushResponse, error) { + return &datapb.FlushResponse{}, m.err +} + +func (m *MockDataCoordClient) AssignSegmentID(ctx context.Context, in *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) { + return &datapb.AssignSegmentIDResponse{}, m.err +} + +func (m *MockDataCoordClient) GetSegmentInfo(ctx context.Context, in *datapb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*datapb.GetSegmentInfoResponse, error) { + return &datapb.GetSegmentInfoResponse{}, m.err +} + +func (m *MockDataCoordClient) GetSegmentStates(ctx context.Context, in *datapb.GetSegmentStatesRequest, opts ...grpc.CallOption) (*datapb.GetSegmentStatesResponse, error) { + return &datapb.GetSegmentStatesResponse{}, m.err +} + +func (m *MockDataCoordClient) GetInsertBinlogPaths(ctx context.Context, in *datapb.GetInsertBinlogPathsRequest, opts ...grpc.CallOption) (*datapb.GetInsertBinlogPathsResponse, error) { + return &datapb.GetInsertBinlogPathsResponse{}, m.err +} + +func (m *MockDataCoordClient) GetCollectionStatistics(ctx context.Context, in *datapb.GetCollectionStatisticsRequest, opts ...grpc.CallOption) (*datapb.GetCollectionStatisticsResponse, error) { + return &datapb.GetCollectionStatisticsResponse{}, m.err +} + +func (m *MockDataCoordClient) GetPartitionStatistics(ctx context.Context, in *datapb.GetPartitionStatisticsRequest, opts ...grpc.CallOption) (*datapb.GetPartitionStatisticsResponse, error) { + return &datapb.GetPartitionStatisticsResponse{}, m.err +} + +func (m *MockDataCoordClient) GetSegmentInfoChannel(ctx context.Context, in *datapb.GetSegmentInfoChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.err +} + +func (m *MockDataCoordClient) SaveBinlogPaths(ctx context.Context, in *datapb.SaveBinlogPathsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.err +} + +func (m *MockDataCoordClient) GetRecoveryInfo(ctx context.Context, in *datapb.GetRecoveryInfoRequest, opts ...grpc.CallOption) (*datapb.GetRecoveryInfoResponse, error) { + return &datapb.GetRecoveryInfoResponse{}, m.err +} + +func (m *MockDataCoordClient) GetFlushedSegments(ctx context.Context, in *datapb.GetFlushedSegmentsRequest, opts ...grpc.CallOption) (*datapb.GetFlushedSegmentsResponse, error) { + return &datapb.GetFlushedSegmentsResponse{}, m.err +} + +func (m *MockDataCoordClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + return &milvuspb.GetMetricsResponse{}, m.err +} + +func Test_NewClient(t *testing.T) { + proxy.Params.InitOnce() + + ctx := context.Background() + client, err := NewClient(ctx, proxy.Params.MetaRootPath, proxy.Params.EtcdEndpoints) + assert.Nil(t, err) + assert.NotNil(t, client) + + err = client.Init() + assert.Nil(t, err) + + err = client.Start() + assert.Nil(t, err) + + err = client.Register() + assert.Nil(t, err) + + checkFunc := func(retNotNil bool) { + retCheck := func(notNil bool, ret interface{}, err error) { + if notNil { + assert.NotNil(t, ret) + assert.Nil(t, err) + } else { + assert.Nil(t, ret) + assert.NotNil(t, err) + } + } + + r1, err := client.GetComponentStates(ctx) + retCheck(retNotNil, r1, err) + + r2, err := client.GetTimeTickChannel(ctx) + retCheck(retNotNil, r2, err) + + r3, err := client.GetStatisticsChannel(ctx) + retCheck(retNotNil, r3, err) + + r4, err := client.Flush(ctx, nil) + retCheck(retNotNil, r4, err) + + r5, err := client.AssignSegmentID(ctx, nil) + retCheck(retNotNil, r5, err) + + r6, err := client.GetSegmentInfo(ctx, nil) + retCheck(retNotNil, r6, err) + + r7, err := client.GetSegmentStates(ctx, nil) + retCheck(retNotNil, r7, err) + + r8, err := client.GetInsertBinlogPaths(ctx, nil) + retCheck(retNotNil, r8, err) + + r9, err := client.GetCollectionStatistics(ctx, nil) + retCheck(retNotNil, r9, err) + + r10, err := client.GetPartitionStatistics(ctx, nil) + retCheck(retNotNil, r10, err) + + r11, err := client.GetSegmentInfoChannel(ctx) + retCheck(retNotNil, r11, err) + + // r12, err := client.SaveBinlogPaths(ctx, nil) + // retCheck(retNotNil, r12, err) + + r13, err := client.GetRecoveryInfo(ctx, nil) + retCheck(retNotNil, r13, err) + + r14, err := client.GetFlushedSegments(ctx, nil) + retCheck(retNotNil, r14, err) + + r15, err := client.GetMetrics(ctx, nil) + retCheck(retNotNil, r15, err) + } + + client.getGrpcClient = func() (datapb.DataCoordClient, error) { + return &MockDataCoordClient{err: nil}, errors.New("dummy") + } + checkFunc(false) + + // special case since this method didn't use recall() + ret, err := client.SaveBinlogPaths(ctx, nil) + assert.NotNil(t, ret) + assert.Nil(t, err) + + client.getGrpcClient = func() (datapb.DataCoordClient, error) { + return &MockDataCoordClient{err: errors.New("dummy")}, nil + } + checkFunc(false) + + // special case since this method didn't use recall() + ret, err = client.SaveBinlogPaths(ctx, nil) + assert.NotNil(t, ret) + assert.NotNil(t, err) + + client.getGrpcClient = func() (datapb.DataCoordClient, error) { + return &MockDataCoordClient{err: nil}, nil + } + checkFunc(true) + + err = client.Stop() + assert.Nil(t, err) +} diff --git a/internal/distributed/datanode/client/client.go b/internal/distributed/datanode/client/client.go index 50a0b4bbe7..3cf0168257 100644 --- a/internal/distributed/datanode/client/client.go +++ b/internal/distributed/datanode/client/client.go @@ -45,9 +45,15 @@ type Client struct { addr string retryOptions []retry.Option + + getGrpcClient func() (datapb.DataNodeClient, error) } -func (c *Client) getGrpcClient() (datapb.DataNodeClient, error) { +func (c *Client) setGetGrpcClientFunc() { + c.getGrpcClient = c.getGrpcClientFunc +} + +func (c *Client) getGrpcClientFunc() (datapb.DataNodeClient, error) { c.grpcMtx.RLock() if c.grpc != nil { defer c.grpcMtx.RUnlock() @@ -89,17 +95,20 @@ func NewClient(ctx context.Context, addr string, retryOptions ...retry.Option) ( } ctx, cancel := context.WithCancel(ctx) - return &Client{ + client := &Client{ ctx: ctx, cancel: cancel, addr: addr, retryOptions: retryOptions, - }, nil + } + + client.setGetGrpcClientFunc() + return client, nil } func (c *Client) Init() error { Params.Init() - return c.connect(retry.Attempts(20)) + return nil } func (c *Client) connect(retryOptions ...retry.Option) error { diff --git a/internal/distributed/datanode/client/client_test.go b/internal/distributed/datanode/client/client_test.go new file mode 100644 index 0000000000..09af7c45d3 --- /dev/null +++ b/internal/distributed/datanode/client/client_test.go @@ -0,0 +1,106 @@ +package grpcdatanodeclient + +import ( + "context" + "errors" + "testing" + + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/proxy" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" +) + +type MockDataNodeClient struct { + err error +} + +func (m *MockDataNodeClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { + return &internalpb.ComponentStates{}, m.err +} + +func (m *MockDataNodeClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.err +} + +func (m *MockDataNodeClient) WatchDmChannels(ctx context.Context, in *datapb.WatchDmChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.err +} + +func (m *MockDataNodeClient) FlushSegments(ctx context.Context, in *datapb.FlushSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.err +} + +func (m *MockDataNodeClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + return &milvuspb.GetMetricsResponse{}, m.err +} + +func Test_NewClient(t *testing.T) { + proxy.Params.InitOnce() + + ctx := context.Background() + client, err := NewClient(ctx, "") + assert.Nil(t, client) + assert.NotNil(t, err) + + client, err = NewClient(ctx, "test") + assert.Nil(t, err) + assert.NotNil(t, client) + + err = client.Init() + assert.Nil(t, err) + + err = client.Start() + assert.Nil(t, err) + + err = client.Register() + assert.Nil(t, err) + + checkFunc := func(retNotNil bool) { + retCheck := func(notNil bool, ret interface{}, err error) { + if notNil { + assert.NotNil(t, ret) + assert.Nil(t, err) + } else { + assert.Nil(t, ret) + assert.NotNil(t, err) + } + } + + r1, err := client.GetComponentStates(ctx) + retCheck(retNotNil, r1, err) + + r2, err := client.GetStatisticsChannel(ctx) + retCheck(retNotNil, r2, err) + + r3, err := client.WatchDmChannels(ctx, nil) + retCheck(retNotNil, r3, err) + + r4, err := client.FlushSegments(ctx, nil) + retCheck(retNotNil, r4, err) + + r5, err := client.GetMetrics(ctx, nil) + retCheck(retNotNil, r5, err) + } + + client.getGrpcClient = func() (datapb.DataNodeClient, error) { + return &MockDataNodeClient{err: nil}, errors.New("dummy") + } + checkFunc(false) + + client.getGrpcClient = func() (datapb.DataNodeClient, error) { + return &MockDataNodeClient{err: errors.New("dummy")}, nil + } + checkFunc(false) + + client.getGrpcClient = func() (datapb.DataNodeClient, error) { + return &MockDataNodeClient{err: nil}, nil + } + checkFunc(true) + + err = client.Stop() + assert.Nil(t, err) +}