diff --git a/internal/proxy/lb_policy.go b/internal/proxy/lb_policy.go index e0971ea400..0270a47475 100644 --- a/internal/proxy/lb_policy.go +++ b/internal/proxy/lb_policy.go @@ -202,7 +202,6 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode.nodeID, workload.channel) return lastErr } - defer lb.clientMgr.ReleaseClientRef(targetNode.nodeID) err = workload.exec(ctx, targetNode.nodeID, client, workload.channel) if err != nil { diff --git a/internal/proxy/lb_policy_test.go b/internal/proxy/lb_policy_test.go index d427dd0b8e..c17388d402 100644 --- a/internal/proxy/lb_policy_test.go +++ b/internal/proxy/lb_policy_test.go @@ -253,7 +253,6 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { // test execute success s.lbBalancer.ExpectedCalls = nil - s.mgr.EXPECT().ReleaseClientRef(mock.Anything) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil) @@ -292,7 +291,6 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { // test get client failed, and retry failed, expected success s.mgr.ExpectedCalls = nil - s.mgr.EXPECT().ReleaseClientRef(mock.Anything) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1) s.lbBalancer.ExpectedCalls = nil s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) @@ -313,7 +311,6 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { s.Error(err) s.mgr.ExpectedCalls = nil - s.mgr.EXPECT().ReleaseClientRef(mock.Anything) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) @@ -334,7 +331,6 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { // test exec failed, then retry success s.mgr.ExpectedCalls = nil - s.mgr.EXPECT().ReleaseClientRef(mock.Anything) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.lbBalancer.ExpectedCalls = nil s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) @@ -362,7 +358,6 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { // test exec timeout s.mgr.ExpectedCalls = nil s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) - s.mgr.EXPECT().ReleaseClientRef(mock.Anything) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() s.qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, context.Canceled).Times(1) @@ -387,7 +382,6 @@ func (s *LBPolicySuite) TestExecute() { ctx := context.Background() mockErr := errors.New("mock error") // test all channel success - s.mgr.EXPECT().ReleaseClientRef(mock.Anything) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil) diff --git a/internal/proxy/look_aside_balancer.go b/internal/proxy/look_aside_balancer.go index 71178ffd46..0aa26a78a7 100644 --- a/internal/proxy/look_aside_balancer.go +++ b/internal/proxy/look_aside_balancer.go @@ -260,7 +260,6 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) { log.RatedInfo(10, "get client failed", zap.Int64("node", node), zap.Error(err)) return struct{}{}, nil } - defer b.clientMgr.ReleaseClientRef(node) resp, err := qn.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) if err != nil { diff --git a/internal/proxy/look_aside_balancer_test.go b/internal/proxy/look_aside_balancer_test.go index d91e24f37d..20cc7c51a6 100644 --- a/internal/proxy/look_aside_balancer_test.go +++ b/internal/proxy/look_aside_balancer_test.go @@ -43,7 +43,6 @@ type LookAsideBalancerSuite struct { func (suite *LookAsideBalancerSuite) SetupTest() { suite.clientMgr = NewMockShardClientManager(suite.T()) - suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything).Maybe() suite.balancer = NewLookAsideBalancer(suite.clientMgr) suite.balancer.Start(context.Background()) @@ -309,7 +308,6 @@ func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() { }, }, nil).Maybe() suite.clientMgr.ExpectedCalls = nil - suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything).Maybe() suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ni nodeInfo) (types.QueryNodeClient, error) { if ni.nodeID == 1 { return qn, nil @@ -373,7 +371,6 @@ func (suite *LookAsideBalancerSuite) TestGetClientFailed() { // test get shard client from client mgr return nil suite.clientMgr.ExpectedCalls = nil - suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything).Maybe() suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("shard client not found")) // expected stopping the health check after failure times reaching the limit suite.Eventually(func() bool { @@ -385,7 +382,6 @@ func (suite *LookAsideBalancerSuite) TestNodeRecover() { // mock qn down for a while and then recover qn3 := mocks.NewMockQueryNodeClient(suite.T()) suite.clientMgr.ExpectedCalls = nil - suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything) suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn3, nil) qn3.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ @@ -424,7 +420,6 @@ func (suite *LookAsideBalancerSuite) TestNodeOffline() { // mock qn down for a while and then recover qn3 := mocks.NewMockQueryNodeClient(suite.T()) suite.clientMgr.ExpectedCalls = nil - suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything) suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn3, nil) qn3.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ diff --git a/internal/proxy/mock_shardclient_manager.go b/internal/proxy/mock_shardclient_manager.go index 878168e0e7..a9c7921efb 100644 --- a/internal/proxy/mock_shardclient_manager.go +++ b/internal/proxy/mock_shardclient_manager.go @@ -113,39 +113,6 @@ func (_c *MockShardClientManager_GetClient_Call) RunAndReturn(run func(context.C return _c } -// ReleaseClientRef provides a mock function with given fields: nodeID -func (_m *MockShardClientManager) ReleaseClientRef(nodeID int64) { - _m.Called(nodeID) -} - -// MockShardClientManager_ReleaseClientRef_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleaseClientRef' -type MockShardClientManager_ReleaseClientRef_Call struct { - *mock.Call -} - -// ReleaseClientRef is a helper method to define mock.On call -// - nodeID int64 -func (_e *MockShardClientManager_Expecter) ReleaseClientRef(nodeID interface{}) *MockShardClientManager_ReleaseClientRef_Call { - return &MockShardClientManager_ReleaseClientRef_Call{Call: _e.mock.On("ReleaseClientRef", nodeID)} -} - -func (_c *MockShardClientManager_ReleaseClientRef_Call) Run(run func(nodeID int64)) *MockShardClientManager_ReleaseClientRef_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64)) - }) - return _c -} - -func (_c *MockShardClientManager_ReleaseClientRef_Call) Return() *MockShardClientManager_ReleaseClientRef_Call { - _c.Call.Return() - return _c -} - -func (_c *MockShardClientManager_ReleaseClientRef_Call) RunAndReturn(run func(int64)) *MockShardClientManager_ReleaseClientRef_Call { - _c.Call.Return(run) - return _c -} - // SetClientCreatorFunc provides a mock function with given fields: creator func (_m *MockShardClientManager) SetClientCreatorFunc(creator queryNodeCreatorFunc) { _m.Called(creator) diff --git a/internal/proxy/shard_client.go b/internal/proxy/shard_client.go index 301fbc9475..83a94cf38b 100644 --- a/internal/proxy/shard_client.go +++ b/internal/proxy/shard_client.go @@ -14,6 +14,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type queryNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) @@ -32,19 +33,29 @@ var errClosed = errors.New("client is closed") type shardClient struct { sync.RWMutex info nodeInfo - isClosed bool - clients []types.QueryNodeClient - idx atomic.Int64 poolSize int - pooling bool + clients []types.QueryNodeClient + creator queryNodeCreatorFunc initialized atomic.Bool - creator queryNodeCreatorFunc + isClosed bool - refCnt *atomic.Int64 + idx atomic.Int64 + lastActiveTs *atomic.Int64 + expiredDuration time.Duration +} + +func newShardClient(info nodeInfo, creator queryNodeCreatorFunc, expiredDuration time.Duration) *shardClient { + return &shardClient{ + info: info, + creator: creator, + lastActiveTs: atomic.NewInt64(time.Now().UnixNano()), + expiredDuration: expiredDuration, + } } func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, error) { + n.lastActiveTs.Store(time.Now().UnixNano()) if !n.initialized.Load() { n.Lock() if !n.initialized.Load() { @@ -52,7 +63,6 @@ func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, err n.Unlock() return nil, err } - n.initialized.Store(true) } n.Unlock() } @@ -66,79 +76,39 @@ func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, err if err != nil { return nil, err } - n.IncRef() return client, nil } } -func (n *shardClient) DecRef() bool { - if n.refCnt.Dec() == 0 { - n.Close() - return true - } - return false -} - -func (n *shardClient) IncRef() { - n.refCnt.Inc() -} - -func (n *shardClient) close() { - n.isClosed = true - - for _, client := range n.clients { - if err := client.Close(); err != nil { - log.Warn("close grpc client failed", zap.Error(err)) - } - } - n.clients = nil -} - -func (n *shardClient) Close() { - n.Lock() - defer n.Unlock() - n.close() -} - -func newShardClient(info nodeInfo, creator queryNodeCreatorFunc) (*shardClient, error) { - num := paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt() - if num <= 0 { - num = 1 - } - - return &shardClient{ - info: nodeInfo{ - nodeID: info.nodeID, - address: info.address, - }, - poolSize: num, - creator: creator, - refCnt: atomic.NewInt64(1), - }, nil -} - func (n *shardClient) initClients(ctx context.Context) error { - clients := make([]types.QueryNodeClient, 0, n.poolSize) - for i := 0; i < n.poolSize; i++ { + poolSize := paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt() + if poolSize <= 0 { + poolSize = 1 + } + + clients := make([]types.QueryNodeClient, 0, poolSize) + for i := 0; i < poolSize; i++ { client, err := n.creator(ctx, n.info.address, n.info.nodeID) if err != nil { // Roll back already created clients for _, c := range clients { c.Close() } + log.Info("failed to create client for node", zap.Int64("nodeID", n.info.nodeID), zap.Error(err)) return errors.Wrap(err, fmt.Sprintf("create client for node=%d failed", n.info.nodeID)) } clients = append(clients, client) } + n.initialized.Store(true) + n.poolSize = poolSize n.clients = clients return nil } -// roundRobinSelectClient selects a client in a round-robin manner func (n *shardClient) roundRobinSelectClient() (types.QueryNodeClient, error) { - n.Lock() - defer n.Unlock() + n.RLock() + defer n.RUnlock() if n.isClosed { return nil, errClosed } @@ -152,23 +122,55 @@ func (n *shardClient) roundRobinSelectClient() (types.QueryNodeClient, error) { return nextClient, nil } +// Notice: close client should only be called by shard client manager. and after close, the client must be removed from the manager. +// 1. the client hasn't been used for a long time +// 2. shard client manager has been closed. +func (n *shardClient) Close(force bool) bool { + n.Lock() + defer n.Unlock() + if force || n.isExpired() { + n.close() + } + + return n.isClosed +} + +func (n *shardClient) isExpired() bool { + return time.Now().UnixNano()-n.lastActiveTs.Load() > n.expiredDuration.Nanoseconds() +} + +func (n *shardClient) close() { + n.isClosed = true + + for _, client := range n.clients { + if err := client.Close(); err != nil { + log.Warn("close grpc client failed", zap.Error(err)) + } + } + n.clients = nil +} + +// roundRobinSelectClient selects a client in a round-robin manner type shardClientMgr interface { GetClient(ctx context.Context, nodeInfo nodeInfo) (types.QueryNodeClient, error) - ReleaseClientRef(nodeID int64) Close() SetClientCreatorFunc(creator queryNodeCreatorFunc) } type shardClientMgrImpl struct { - clients struct { - sync.RWMutex - data map[UniqueID]*shardClient - } + clients *typeutil.ConcurrentMap[UniqueID, *shardClient] clientCreator queryNodeCreatorFunc + closeCh chan struct{} - closeCh chan struct{} + purgeInterval time.Duration + expiredDuration time.Duration } +const ( + defaultPurgeInterval = 600 * time.Second + defaultExpiredDuration = 60 * time.Minute +) + // SessionOpt provides a way to set params in SessionManager type shardClientMgrOpt func(s shardClientMgr) @@ -183,12 +185,11 @@ func defaultQueryNodeClientCreator(ctx context.Context, addr string, nodeID int6 // NewShardClientMgr creates a new shardClientMgr func newShardClientMgr(options ...shardClientMgrOpt) *shardClientMgrImpl { s := &shardClientMgrImpl{ - clients: struct { - sync.RWMutex - data map[UniqueID]*shardClient - }{data: make(map[UniqueID]*shardClient)}, - clientCreator: defaultQueryNodeClientCreator, - closeCh: make(chan struct{}), + clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](), + clientCreator: defaultQueryNodeClientCreator, + closeCh: make(chan struct{}), + purgeInterval: defaultPurgeInterval, + expiredDuration: defaultExpiredDuration, } for _, opt := range options { opt(s) @@ -203,66 +204,44 @@ func (c *shardClientMgrImpl) SetClientCreatorFunc(creator queryNodeCreatorFunc) } func (c *shardClientMgrImpl) GetClient(ctx context.Context, info nodeInfo) (types.QueryNodeClient, error) { - c.clients.RLock() - client, ok := c.clients.data[info.nodeID] - c.clients.RUnlock() - - if !ok { - c.clients.Lock() - // Check again after acquiring the lock - client, ok = c.clients.data[info.nodeID] - if !ok { - // Create a new client if it doesn't exist - newClient, err := newShardClient(info, c.clientCreator) - if err != nil { - c.clients.Unlock() - return nil, err - } - c.clients.data[info.nodeID] = newClient - client = newClient - } - c.clients.Unlock() - } - + client, _ := c.clients.GetOrInsert(info.nodeID, newShardClient(info, c.clientCreator, c.expiredDuration)) return client.getClient(ctx) } +// PurgeClient purges client if it is not used for a long time func (c *shardClientMgrImpl) PurgeClient() { - ticker := time.NewTicker(600 * time.Second) + ticker := time.NewTicker(c.purgeInterval) defer ticker.Stop() + for { select { case <-c.closeCh: return case <-ticker.C: shardLocations := globalMetaCache.ListShardLocation() - c.clients.Lock() - for nodeID, client := range c.clients.data { - if _, ok := shardLocations[nodeID]; !ok { - client.DecRef() - delete(c.clients.data, nodeID) + c.clients.Range(func(key UniqueID, value *shardClient) bool { + if _, ok := shardLocations[key]; !ok { + // if the client is not used for more than 1 hour, and it's not a delegator anymore, should remove it + if value.isExpired() { + closed := value.Close(false) + if closed { + c.clients.Remove(key) + log.Info("remove idle node client", zap.Int64("nodeID", key)) + } + } } - } - c.clients.Unlock() + return true + }) } } } -func (c *shardClientMgrImpl) ReleaseClientRef(nodeID int64) { - c.clients.RLock() - defer c.clients.RUnlock() - if client, ok := c.clients.data[nodeID]; ok { - client.DecRef() - } -} - // Close release clients func (c *shardClientMgrImpl) Close() { - c.clients.Lock() - defer c.clients.Unlock() close(c.closeCh) - for _, s := range c.clients.data { - s.Close() - } - c.clients.data = make(map[UniqueID]*shardClient) + c.clients.Range(func(key UniqueID, value *shardClient) bool { + value.Close(true) + c.clients.Remove(key) + return true + }) } diff --git a/internal/proxy/shard_client_test.go b/internal/proxy/shard_client_test.go index 272b10e06e..ddc8308954 100644 --- a/internal/proxy/shard_client_test.go +++ b/internal/proxy/shard_client_test.go @@ -3,12 +3,15 @@ package proxy import ( "context" "testing" + "time" "github.com/stretchr/testify/assert" + "go.uber.org/atomic" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) func TestShardClientMgr(t *testing.T) { @@ -28,10 +31,8 @@ func TestShardClientMgr(t *testing.T) { _, err := mgr.GetClient(ctx, nodeInfo) assert.Nil(t, err) - mgr.ReleaseClientRef(1) - assert.Equal(t, len(mgr.clients.data), 1) mgr.Close() - assert.Equal(t, len(mgr.clients.data), 0) + assert.Equal(t, mgr.clients.Len(), 0) } func TestShardClient(t *testing.T) { @@ -40,27 +41,120 @@ func TestShardClient(t *testing.T) { } qn := mocks.NewMockQueryNodeClient(t) - qn.EXPECT().Close().Return(nil) + qn.EXPECT().Close().Return(nil).Maybe() creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { return qn, nil } - shardClient, err := newShardClient(nodeInfo, creator) - assert.Nil(t, err) + shardClient := newShardClient(nodeInfo, creator, 3*time.Second) assert.Equal(t, len(shardClient.clients), 0) - assert.Equal(t, int64(1), shardClient.refCnt.Load()) assert.Equal(t, false, shardClient.initialized.Load()) + assert.Equal(t, false, shardClient.isClosed) ctx := context.Background() - _, err = shardClient.getClient(ctx) + _, err := shardClient.getClient(ctx) assert.Nil(t, err) assert.Equal(t, len(shardClient.clients), paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt()) - assert.Equal(t, int64(2), shardClient.refCnt.Load()) - assert.Equal(t, true, shardClient.initialized.Load()) - shardClient.DecRef() - assert.Equal(t, int64(1), shardClient.refCnt.Load()) - - shardClient.DecRef() - assert.Equal(t, int64(0), shardClient.refCnt.Load()) - assert.Equal(t, true, shardClient.isClosed) + // test close + closed := shardClient.Close(false) + assert.False(t, closed) + closed = shardClient.Close(true) + assert.True(t, closed) +} + +func TestPurgeClient(t *testing.T) { + node := nodeInfo{ + nodeID: 1, + } + + returnEmptyResult := atomic.NewBool(false) + + cache := NewMockCache(t) + cache.EXPECT().ListShardLocation().RunAndReturn(func() map[int64]nodeInfo { + if returnEmptyResult.Load() { + return map[int64]nodeInfo{} + } + return map[int64]nodeInfo{ + 1: node, + } + }) + globalMetaCache = cache + + qn := mocks.NewMockQueryNodeClient(t) + qn.EXPECT().Close().Return(nil).Maybe() + creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { + return qn, nil + } + + s := &shardClientMgrImpl{ + clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](), + clientCreator: creator, + closeCh: make(chan struct{}), + purgeInterval: 1 * time.Second, + expiredDuration: 3 * time.Second, + } + + go s.PurgeClient() + defer s.Close() + _, err := s.GetClient(context.Background(), node) + assert.Nil(t, err) + qnClient, ok := s.clients.Get(1) + assert.True(t, ok) + assert.True(t, qnClient.lastActiveTs.Load() > 0) + + time.Sleep(2 * time.Second) + // expected client should not been purged before expiredDuration + assert.Equal(t, s.clients.Len(), 1) + assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() >= 2*time.Second.Nanoseconds()) + + _, err = s.GetClient(context.Background(), node) + assert.Nil(t, err) + time.Sleep(2 * time.Second) + // GetClient should refresh lastActiveTs, expected client should not be purged + assert.Equal(t, s.clients.Len(), 1) + assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() < 3*time.Second.Nanoseconds()) + + time.Sleep(2 * time.Second) + // client reach the expiredDuration, expected client should not be purged + assert.Equal(t, s.clients.Len(), 1) + assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() > 3*time.Second.Nanoseconds()) + + returnEmptyResult.Store(true) + time.Sleep(2 * time.Second) + // remove client from shard location, expected client should be purged + assert.Equal(t, s.clients.Len(), 0) +} + +func BenchmarkShardClientMgr(b *testing.B) { + node := nodeInfo{ + nodeID: 1, + } + cache := NewMockCache(b) + cache.EXPECT().ListShardLocation().Return(map[int64]nodeInfo{ + 1: node, + }).Maybe() + globalMetaCache = cache + qn := mocks.NewMockQueryNodeClient(b) + qn.EXPECT().Close().Return(nil).Maybe() + + creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { + return qn, nil + } + s := &shardClientMgrImpl{ + clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](), + clientCreator: creator, + closeCh: make(chan struct{}), + purgeInterval: 1 * time.Second, + expiredDuration: 10 * time.Second, + } + go s.PurgeClient() + defer s.Close() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := s.GetClient(context.Background(), node) + assert.Nil(b, err) + } + }) } diff --git a/internal/proxy/task_policies.go b/internal/proxy/task_policies.go index 28937fe8e5..f3a047bd38 100644 --- a/internal/proxy/task_policies.go +++ b/internal/proxy/task_policies.go @@ -39,7 +39,6 @@ func RoundRobinPolicy( combineErr = merr.Combine(combineErr, err) continue } - defer mgr.ReleaseClientRef(target.nodeID) err = query(ctx, target.nodeID, qn, channel) if err != nil { log.Warn("query channel failed", zap.String("channel", channel), zap.Int64("nodeID", target.nodeID), zap.Error(err)) diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 36b26866b4..ea90adf597 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -79,7 +79,6 @@ func TestQueryTask_all(t *testing.T) { }, nil).Maybe() mgr := NewMockShardClientManager(t) - mgr.EXPECT().ReleaseClientRef(mock.Anything) mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() lb := NewLBPolicyImpl(mgr) diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 606e277f39..3d89adcae8 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -2111,7 +2111,6 @@ func TestSearchTask_ErrExecute(t *testing.T) { qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() mgr := NewMockShardClientManager(t) - mgr.EXPECT().ReleaseClientRef(mock.Anything) mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() lb := NewLBPolicyImpl(mgr) diff --git a/internal/proxy/task_statistic_test.go b/internal/proxy/task_statistic_test.go index 54a084b71e..e3d3786cfe 100644 --- a/internal/proxy/task_statistic_test.go +++ b/internal/proxy/task_statistic_test.go @@ -80,7 +80,6 @@ func (s *StatisticTaskSuite) SetupTest() { s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() mgr := NewMockShardClientManager(s.T()) - mgr.EXPECT().ReleaseClientRef(mock.Anything).Maybe() mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil).Maybe() s.lb = NewLBPolicyImpl(mgr)