From 77477d6340a7a075596170b882050bd740c51f81 Mon Sep 17 00:00:00 2001 From: chyezh Date: Wed, 28 Feb 2024 10:40:09 +0800 Subject: [PATCH] fix: wrong context passing into NewClient, error handling lost in session_util (#30817) issue: #30799 Signed-off-by: chyezh --- internal/proxy/segment_test.go | 9 +---- internal/querynodev2/server.go | 2 +- internal/util/grpcclient/client.go | 18 +++++----- internal/util/grpcclient/client_test.go | 45 ++++++++++++++++++++++++- 4 files changed, 56 insertions(+), 18 deletions(-) diff --git a/internal/proxy/segment_test.go b/internal/proxy/segment_test.go index 0d7d8c9485..eeafe3b317 100644 --- a/internal/proxy/segment_test.go +++ b/internal/proxy/segment_test.go @@ -135,15 +135,9 @@ func TestSegmentAllocator2(t *testing.T) { dataCoord.expireTime = Timestamp(500) segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2) assert.NoError(t, err) - wg := &sync.WaitGroup{} segAllocator.Start() + defer segAllocator.Close() - wg.Add(1) - go func(group *sync.WaitGroup) { - defer group.Done() - time.Sleep(100 * time.Millisecond) - segAllocator.Close() - }(wg) total := uint32(0) for i := 0; i < 10; i++ { ret, err := segAllocator.GetSegmentID(1, 1, "abc", 1, 200) @@ -154,7 +148,6 @@ func TestSegmentAllocator2(t *testing.T) { time.Sleep(50 * time.Millisecond) _, err = segAllocator.GetSegmentID(1, 1, "abc", segCountPerRPC-10, getLastTick2()) assert.Error(t, err) - wg.Wait() } func TestSegmentAllocator3(t *testing.T) { diff --git a/internal/querynodev2/server.go b/internal/querynodev2/server.go index aa69a535a3..ae142aa486 100644 --- a/internal/querynodev2/server.go +++ b/internal/querynodev2/server.go @@ -340,7 +340,7 @@ func (node *QueryNode) Init() error { } } - client, err := grpcquerynodeclient.NewClient(ctx, addr, nodeID) + client, err := grpcquerynodeclient.NewClient(node.ctx, addr, nodeID) if err != nil { return nil, err } diff --git a/internal/util/grpcclient/client.go b/internal/util/grpcclient/client.go index 
71a031fa79..3622f18a34 100644 --- a/internal/util/grpcclient/client.go +++ b/internal/util/grpcclient/client.go @@ -371,6 +371,7 @@ func (c *ClientBase[T]) verifySession(ctx context.Context) error { if getSessionErr != nil { // Only log but not handle this error as it is an auxiliary logic log.Warn("fail to get session", zap.Error(getSessionErr)) + return getSessionErr } if coordSess, exist := sessions[c.GetRole()]; exist { if c.GetNodeID() != coordSess.ServerID { @@ -412,7 +413,7 @@ func (c *ClientBase[T]) checkGrpcErr(ctx context.Context, err error) (needRetry, case funcutil.IsGrpcErr(err, codes.Unimplemented): return false, false, merr.WrapErrServiceUnimplemented(err) case IsServerIDMismatchErr(err): - if ok, err := c.checkNodeSessionExist(ctx); !ok { + if ok := c.checkNodeSessionExist(ctx); !ok { // if session doesn't exist, no need to retry for datanode/indexnode/querynode/proxy return false, false, err } @@ -424,16 +425,17 @@ func (c *ClientBase[T]) checkGrpcErr(ctx context.Context, err error) (needRetry, } } -func (c *ClientBase[T]) checkNodeSessionExist(ctx context.Context) (bool, error) { +// checkNodeSessionExist checks if the session of the node exists. +// If the session does not exist, it will return false, otherwise it will return true. 
+func (c *ClientBase[T]) checkNodeSessionExist(ctx context.Context) bool { if c.isNode { err := c.verifySession(ctx) - if errors.Is(err, merr.ErrNodeNotFound) { + if err != nil { log.Warn("failed to verify node session", zap.Error(err)) - // stop retry - return false, err } + return !errors.Is(err, merr.ErrNodeNotFound) } - return true, nil + return true } func (c *ClientBase[T]) call(ctx context.Context, caller func(client T) (any, error)) (any, error) { @@ -461,9 +463,9 @@ func (c *ClientBase[T]) call(ctx context.Context, caller func(client T) (any, er defer cancel() err := retry.Do(ctx, func() error { if wrapper == nil { - if ok, err := c.checkNodeSessionExist(ctx); !ok { + if ok := c.checkNodeSessionExist(ctx); !ok { // if session doesn't exist, no need to reset connection for datanode/indexnode/querynode - return retry.Unrecoverable(err) + return retry.Unrecoverable(merr.ErrNodeNotFound) } err := errors.Wrap(clientErr, "empty grpc client") diff --git a/internal/util/grpcclient/client_test.go b/internal/util/grpcclient/client_test.go index fa691e61e7..045c9b39de 100644 --- a/internal/util/grpcclient/client_test.go +++ b/internal/util/grpcclient/client_test.go @@ -29,6 +29,7 @@ import ( "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "go.uber.org/atomic" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/examples/helloworld/helloworld" @@ -124,7 +125,7 @@ func TestClientBase_NodeSessionNotExist(t *testing.T) { _, err = base.Call(ctx, func(client *mockClient) (any, error) { return struct{}{}, status.Errorf(codes.Unknown, merr.ErrNodeNotMatch.Error()) }) - assert.True(t, errors.Is(err, merr.ErrNodeNotFound)) + assert.True(t, IsServerIDMismatchErr(err)) // test querynode/datanode/indexnode/proxy down, return unavailable error base.grpcClientMtx.Lock() @@ -526,3 +527,45 @@ func TestClientBase_Compression(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 
res.(*milvuspb.ComponentStates).GetState().GetNodeID(), randID) } + +func TestVerifySession(t *testing.T) { + base := ClientBase[*mockClient]{} + mockSession := sessionutil.NewMockSession(t) + expectedErr := errors.New("mocked") + mockSession.EXPECT().GetSessions(mock.Anything).Return(nil, 0, expectedErr) + base.sess = mockSession + + ctx := context.Background() + err := base.verifySession(ctx) + assert.ErrorIs(t, err, expectedErr) + + base.lastSessionCheck.Store(time.Unix(0, 0)) + base.NodeID = *atomic.NewInt64(1) + base.role = typeutil.RootCoordRole + mockSession2 := sessionutil.NewMockSession(t) + mockSession2.EXPECT().GetSessions(mock.Anything).Return( + map[string]*sessionutil.Session{ + typeutil.RootCoordRole: { + SessionRaw: sessionutil.SessionRaw{ + ServerID: 1, + }, + }, + }, + 0, + nil, + ) + base.sess = mockSession2 + err = base.verifySession(ctx) + assert.NoError(t, err) + + base.lastSessionCheck.Store(time.Unix(0, 0)) + base.NodeID = *atomic.NewInt64(2) + err = base.verifySession(ctx) + assert.ErrorIs(t, err, merr.ErrNodeNotMatch) + + base.lastSessionCheck.Store(time.Unix(0, 0)) + base.NodeID = *atomic.NewInt64(1) + base.role = typeutil.QueryNodeRole + err = base.verifySession(ctx) + assert.ErrorIs(t, err, merr.ErrNodeNotFound) +}