mirror of https://github.com/milvus-io/milvus.git
fix: wrong context passing into NewClient, error handling lost in session_util (#30817)
issue: #30799 Signed-off-by: chyezh <chyezh@outlook.com>pull/30788/head
parent
6548182a4a
commit
77477d6340
|
@ -135,15 +135,9 @@ func TestSegmentAllocator2(t *testing.T) {
|
|||
dataCoord.expireTime = Timestamp(500)
|
||||
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2)
|
||||
assert.NoError(t, err)
|
||||
wg := &sync.WaitGroup{}
|
||||
segAllocator.Start()
|
||||
defer segAllocator.Close()
|
||||
|
||||
wg.Add(1)
|
||||
go func(group *sync.WaitGroup) {
|
||||
defer group.Done()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
segAllocator.Close()
|
||||
}(wg)
|
||||
total := uint32(0)
|
||||
for i := 0; i < 10; i++ {
|
||||
ret, err := segAllocator.GetSegmentID(1, 1, "abc", 1, 200)
|
||||
|
@ -154,7 +148,6 @@ func TestSegmentAllocator2(t *testing.T) {
|
|||
time.Sleep(50 * time.Millisecond)
|
||||
_, err = segAllocator.GetSegmentID(1, 1, "abc", segCountPerRPC-10, getLastTick2())
|
||||
assert.Error(t, err)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestSegmentAllocator3(t *testing.T) {
|
||||
|
|
|
@ -340,7 +340,7 @@ func (node *QueryNode) Init() error {
|
|||
}
|
||||
}
|
||||
|
||||
client, err := grpcquerynodeclient.NewClient(ctx, addr, nodeID)
|
||||
client, err := grpcquerynodeclient.NewClient(node.ctx, addr, nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -371,6 +371,7 @@ func (c *ClientBase[T]) verifySession(ctx context.Context) error {
|
|||
if getSessionErr != nil {
|
||||
// Only log but not handle this error as it is an auxiliary logic
|
||||
log.Warn("fail to get session", zap.Error(getSessionErr))
|
||||
return getSessionErr
|
||||
}
|
||||
if coordSess, exist := sessions[c.GetRole()]; exist {
|
||||
if c.GetNodeID() != coordSess.ServerID {
|
||||
|
@ -412,7 +413,7 @@ func (c *ClientBase[T]) checkGrpcErr(ctx context.Context, err error) (needRetry,
|
|||
case funcutil.IsGrpcErr(err, codes.Unimplemented):
|
||||
return false, false, merr.WrapErrServiceUnimplemented(err)
|
||||
case IsServerIDMismatchErr(err):
|
||||
if ok, err := c.checkNodeSessionExist(ctx); !ok {
|
||||
if ok := c.checkNodeSessionExist(ctx); !ok {
|
||||
// if session doesn't exist, no need to retry for datanode/indexnode/querynode/proxy
|
||||
return false, false, err
|
||||
}
|
||||
|
@ -424,16 +425,17 @@ func (c *ClientBase[T]) checkGrpcErr(ctx context.Context, err error) (needRetry,
|
|||
}
|
||||
}
|
||||
|
||||
func (c *ClientBase[T]) checkNodeSessionExist(ctx context.Context) (bool, error) {
|
||||
// checkNodeSessionExist checks if the session of the node exists.
|
||||
// If the session does not exist , it will return false, otherwise it will return true.
|
||||
func (c *ClientBase[T]) checkNodeSessionExist(ctx context.Context) bool {
|
||||
if c.isNode {
|
||||
err := c.verifySession(ctx)
|
||||
if errors.Is(err, merr.ErrNodeNotFound) {
|
||||
if err != nil {
|
||||
log.Warn("failed to verify node session", zap.Error(err))
|
||||
// stop retry
|
||||
return false, err
|
||||
}
|
||||
return !errors.Is(err, merr.ErrNodeNotFound)
|
||||
}
|
||||
return true, nil
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *ClientBase[T]) call(ctx context.Context, caller func(client T) (any, error)) (any, error) {
|
||||
|
@ -461,9 +463,9 @@ func (c *ClientBase[T]) call(ctx context.Context, caller func(client T) (any, er
|
|||
defer cancel()
|
||||
err := retry.Do(ctx, func() error {
|
||||
if wrapper == nil {
|
||||
if ok, err := c.checkNodeSessionExist(ctx); !ok {
|
||||
if ok := c.checkNodeSessionExist(ctx); !ok {
|
||||
// if session doesn't exist, no need to reset connection for datanode/indexnode/querynode
|
||||
return retry.Unrecoverable(err)
|
||||
return retry.Unrecoverable(merr.ErrNodeNotFound)
|
||||
}
|
||||
|
||||
err := errors.Wrap(clientErr, "empty grpc client")
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"go.uber.org/atomic"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/examples/helloworld/helloworld"
|
||||
|
@ -124,7 +125,7 @@ func TestClientBase_NodeSessionNotExist(t *testing.T) {
|
|||
_, err = base.Call(ctx, func(client *mockClient) (any, error) {
|
||||
return struct{}{}, status.Errorf(codes.Unknown, merr.ErrNodeNotMatch.Error())
|
||||
})
|
||||
assert.True(t, errors.Is(err, merr.ErrNodeNotFound))
|
||||
assert.True(t, IsServerIDMismatchErr(err))
|
||||
|
||||
// test querynode/datanode/indexnode/proxy down, return unavailable error
|
||||
base.grpcClientMtx.Lock()
|
||||
|
@ -526,3 +527,45 @@ func TestClientBase_Compression(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.Equal(t, res.(*milvuspb.ComponentStates).GetState().GetNodeID(), randID)
|
||||
}
|
||||
|
||||
func TestVerifySession(t *testing.T) {
|
||||
base := ClientBase[*mockClient]{}
|
||||
mockSession := sessionutil.NewMockSession(t)
|
||||
expectedErr := errors.New("mocked")
|
||||
mockSession.EXPECT().GetSessions(mock.Anything).Return(nil, 0, expectedErr)
|
||||
base.sess = mockSession
|
||||
|
||||
ctx := context.Background()
|
||||
err := base.verifySession(ctx)
|
||||
assert.ErrorIs(t, err, expectedErr)
|
||||
|
||||
base.lastSessionCheck.Store(time.Unix(0, 0))
|
||||
base.NodeID = *atomic.NewInt64(1)
|
||||
base.role = typeutil.RootCoordRole
|
||||
mockSession2 := sessionutil.NewMockSession(t)
|
||||
mockSession2.EXPECT().GetSessions(mock.Anything).Return(
|
||||
map[string]*sessionutil.Session{
|
||||
typeutil.RootCoordRole: {
|
||||
SessionRaw: sessionutil.SessionRaw{
|
||||
ServerID: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
0,
|
||||
nil,
|
||||
)
|
||||
base.sess = mockSession2
|
||||
err = base.verifySession(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
base.lastSessionCheck.Store(time.Unix(0, 0))
|
||||
base.NodeID = *atomic.NewInt64(2)
|
||||
err = base.verifySession(ctx)
|
||||
assert.ErrorIs(t, err, merr.ErrNodeNotMatch)
|
||||
|
||||
base.lastSessionCheck.Store(time.Unix(0, 0))
|
||||
base.NodeID = *atomic.NewInt64(1)
|
||||
base.role = typeutil.QueryNodeRole
|
||||
err = base.verifySession(ctx)
|
||||
assert.ErrorIs(t, err, merr.ErrNodeNotFound)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue