mirror of https://github.com/milvus-io/milvus.git
fix: Get shard client failed by client is closed (#37729)
issue: #37718 This PR refines the shard client ref counter: decrementing the ref counter no longer releases the client, and only the shard client manager is permitted to remove a client. --------- Signed-off-by: Wei Liu <wei.liu@zilliz.com> pull/37779/head
parent
226fe900e7
commit
261212ee4a
|
@ -202,7 +202,6 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
|
|||
lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode.nodeID, workload.channel)
|
||||
return lastErr
|
||||
}
|
||||
defer lb.clientMgr.ReleaseClientRef(targetNode.nodeID)
|
||||
|
||||
err = workload.exec(ctx, targetNode.nodeID, client, workload.channel)
|
||||
if err != nil {
|
||||
|
|
|
@ -253,7 +253,6 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
|
|||
|
||||
// test execute success
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().ReleaseClientRef(mock.Anything)
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
|
||||
|
@ -292,7 +291,6 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
|
|||
|
||||
// test get client failed, and retry failed, expected success
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().ReleaseClientRef(mock.Anything)
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
|
@ -313,7 +311,6 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
|
|||
s.Error(err)
|
||||
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().ReleaseClientRef(mock.Anything)
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
|
@ -334,7 +331,6 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
|
|||
|
||||
// test exec failed, then retry success
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().ReleaseClientRef(mock.Anything)
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
|
@ -362,7 +358,6 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
|
|||
// test exec timeout
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.mgr.EXPECT().ReleaseClientRef(mock.Anything)
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
s.qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, context.Canceled).Times(1)
|
||||
|
@ -387,7 +382,6 @@ func (s *LBPolicySuite) TestExecute() {
|
|||
ctx := context.Background()
|
||||
mockErr := errors.New("mock error")
|
||||
// test all channel success
|
||||
s.mgr.EXPECT().ReleaseClientRef(mock.Anything)
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
|
||||
|
|
|
@ -260,7 +260,6 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) {
|
|||
log.RatedInfo(10, "get client failed", zap.Int64("node", node), zap.Error(err))
|
||||
return struct{}{}, nil
|
||||
}
|
||||
defer b.clientMgr.ReleaseClientRef(node)
|
||||
|
||||
resp, err := qn.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{})
|
||||
if err != nil {
|
||||
|
|
|
@ -43,7 +43,6 @@ type LookAsideBalancerSuite struct {
|
|||
|
||||
func (suite *LookAsideBalancerSuite) SetupTest() {
|
||||
suite.clientMgr = NewMockShardClientManager(suite.T())
|
||||
suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything).Maybe()
|
||||
suite.balancer = NewLookAsideBalancer(suite.clientMgr)
|
||||
suite.balancer.Start(context.Background())
|
||||
|
||||
|
@ -309,7 +308,6 @@ func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() {
|
|||
},
|
||||
}, nil).Maybe()
|
||||
suite.clientMgr.ExpectedCalls = nil
|
||||
suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything).Maybe()
|
||||
suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ni nodeInfo) (types.QueryNodeClient, error) {
|
||||
if ni.nodeID == 1 {
|
||||
return qn, nil
|
||||
|
@ -373,7 +371,6 @@ func (suite *LookAsideBalancerSuite) TestGetClientFailed() {
|
|||
|
||||
// test get shard client from client mgr return nil
|
||||
suite.clientMgr.ExpectedCalls = nil
|
||||
suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything).Maybe()
|
||||
suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("shard client not found"))
|
||||
// expected stopping the health check after failure times reaching the limit
|
||||
suite.Eventually(func() bool {
|
||||
|
@ -385,7 +382,6 @@ func (suite *LookAsideBalancerSuite) TestNodeRecover() {
|
|||
// mock qn down for a while and then recover
|
||||
qn3 := mocks.NewMockQueryNodeClient(suite.T())
|
||||
suite.clientMgr.ExpectedCalls = nil
|
||||
suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything)
|
||||
suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn3, nil)
|
||||
qn3.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{
|
||||
State: &milvuspb.ComponentInfo{
|
||||
|
@ -424,7 +420,6 @@ func (suite *LookAsideBalancerSuite) TestNodeOffline() {
|
|||
// mock qn down for a while and then recover
|
||||
qn3 := mocks.NewMockQueryNodeClient(suite.T())
|
||||
suite.clientMgr.ExpectedCalls = nil
|
||||
suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything)
|
||||
suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn3, nil)
|
||||
qn3.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{
|
||||
State: &milvuspb.ComponentInfo{
|
||||
|
|
|
@ -113,39 +113,6 @@ func (_c *MockShardClientManager_GetClient_Call) RunAndReturn(run func(context.C
|
|||
return _c
|
||||
}
|
||||
|
||||
// ReleaseClientRef provides a mock function with given fields: nodeID
|
||||
func (_m *MockShardClientManager) ReleaseClientRef(nodeID int64) {
|
||||
_m.Called(nodeID)
|
||||
}
|
||||
|
||||
// MockShardClientManager_ReleaseClientRef_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleaseClientRef'
|
||||
type MockShardClientManager_ReleaseClientRef_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ReleaseClientRef is a helper method to define mock.On call
|
||||
// - nodeID int64
|
||||
func (_e *MockShardClientManager_Expecter) ReleaseClientRef(nodeID interface{}) *MockShardClientManager_ReleaseClientRef_Call {
|
||||
return &MockShardClientManager_ReleaseClientRef_Call{Call: _e.mock.On("ReleaseClientRef", nodeID)}
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_ReleaseClientRef_Call) Run(run func(nodeID int64)) *MockShardClientManager_ReleaseClientRef_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(int64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_ReleaseClientRef_Call) Return() *MockShardClientManager_ReleaseClientRef_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_ReleaseClientRef_Call) RunAndReturn(run func(int64)) *MockShardClientManager_ReleaseClientRef_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetClientCreatorFunc provides a mock function with given fields: creator
|
||||
func (_m *MockShardClientManager) SetClientCreatorFunc(creator queryNodeCreatorFunc) {
|
||||
_m.Called(creator)
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type queryNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error)
|
||||
|
@ -32,19 +33,29 @@ var errClosed = errors.New("client is closed")
|
|||
type shardClient struct {
|
||||
sync.RWMutex
|
||||
info nodeInfo
|
||||
isClosed bool
|
||||
clients []types.QueryNodeClient
|
||||
idx atomic.Int64
|
||||
poolSize int
|
||||
pooling bool
|
||||
clients []types.QueryNodeClient
|
||||
creator queryNodeCreatorFunc
|
||||
|
||||
initialized atomic.Bool
|
||||
creator queryNodeCreatorFunc
|
||||
isClosed bool
|
||||
|
||||
refCnt *atomic.Int64
|
||||
idx atomic.Int64
|
||||
lastActiveTs *atomic.Int64
|
||||
expiredDuration time.Duration
|
||||
}
|
||||
|
||||
// newShardClient builds a shardClient for the given node. It only records
// info, creator and expiredDuration and does not invoke creator here.
// lastActiveTs is initialised to the current time.
func newShardClient(info nodeInfo, creator queryNodeCreatorFunc, expiredDuration time.Duration) *shardClient {
|
||||
return &shardClient{
|
||||
info: info,
|
||||
creator: creator,
|
||||
lastActiveTs: atomic.NewInt64(time.Now().UnixNano()),
|
||||
expiredDuration: expiredDuration,
|
||||
}
|
||||
}
|
||||
|
||||
func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, error) {
|
||||
n.lastActiveTs.Store(time.Now().UnixNano())
|
||||
if !n.initialized.Load() {
|
||||
n.Lock()
|
||||
if !n.initialized.Load() {
|
||||
|
@ -52,7 +63,6 @@ func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, err
|
|||
n.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
n.initialized.Store(true)
|
||||
}
|
||||
n.Unlock()
|
||||
}
|
||||
|
@ -66,79 +76,39 @@ func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, err
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n.IncRef()
|
||||
return client, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (n *shardClient) DecRef() bool {
|
||||
if n.refCnt.Dec() == 0 {
|
||||
n.Close()
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (n *shardClient) IncRef() {
|
||||
n.refCnt.Inc()
|
||||
}
|
||||
|
||||
func (n *shardClient) close() {
|
||||
n.isClosed = true
|
||||
|
||||
for _, client := range n.clients {
|
||||
if err := client.Close(); err != nil {
|
||||
log.Warn("close grpc client failed", zap.Error(err))
|
||||
}
|
||||
}
|
||||
n.clients = nil
|
||||
}
|
||||
|
||||
func (n *shardClient) Close() {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
n.close()
|
||||
}
|
||||
|
||||
func newShardClient(info nodeInfo, creator queryNodeCreatorFunc) (*shardClient, error) {
|
||||
num := paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt()
|
||||
if num <= 0 {
|
||||
num = 1
|
||||
}
|
||||
|
||||
return &shardClient{
|
||||
info: nodeInfo{
|
||||
nodeID: info.nodeID,
|
||||
address: info.address,
|
||||
},
|
||||
poolSize: num,
|
||||
creator: creator,
|
||||
refCnt: atomic.NewInt64(1),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (n *shardClient) initClients(ctx context.Context) error {
|
||||
clients := make([]types.QueryNodeClient, 0, n.poolSize)
|
||||
for i := 0; i < n.poolSize; i++ {
|
||||
poolSize := paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt()
|
||||
if poolSize <= 0 {
|
||||
poolSize = 1
|
||||
}
|
||||
|
||||
clients := make([]types.QueryNodeClient, 0, poolSize)
|
||||
for i := 0; i < poolSize; i++ {
|
||||
client, err := n.creator(ctx, n.info.address, n.info.nodeID)
|
||||
if err != nil {
|
||||
// Roll back already created clients
|
||||
for _, c := range clients {
|
||||
c.Close()
|
||||
}
|
||||
log.Info("failed to create client for node", zap.Int64("nodeID", n.info.nodeID), zap.Error(err))
|
||||
return errors.Wrap(err, fmt.Sprintf("create client for node=%d failed", n.info.nodeID))
|
||||
}
|
||||
clients = append(clients, client)
|
||||
}
|
||||
|
||||
n.initialized.Store(true)
|
||||
n.poolSize = poolSize
|
||||
n.clients = clients
|
||||
return nil
|
||||
}
|
||||
|
||||
// roundRobinSelectClient selects a client in a round-robin manner
|
||||
func (n *shardClient) roundRobinSelectClient() (types.QueryNodeClient, error) {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
n.RLock()
|
||||
defer n.RUnlock()
|
||||
if n.isClosed {
|
||||
return nil, errClosed
|
||||
}
|
||||
|
@ -152,23 +122,55 @@ func (n *shardClient) roundRobinSelectClient() (types.QueryNodeClient, error) {
|
|||
return nextClient, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying clients when force is true or when the client has
// expired (see isExpired), and returns whether the shard client is now closed.
// Notice: Close should only be called by the shard client manager, and once it
// returns true the client must be removed from the manager. Expected call sites:
|
||||
// 1. the client hasn't been used for a long time (force == false)
|
||||
// 2. the shard client manager is being closed (force == true)
|
||||
func (n *shardClient) Close(force bool) bool {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
// isExpired is evaluated while holding the write lock.
if force || n.isExpired() {
|
||||
n.close()
|
||||
}
|
||||
|
||||
// isClosed is also true if the client had already been closed by an earlier call.
return n.isClosed
|
||||
}
|
||||
|
||||
// isExpired reports whether more than expiredDuration has elapsed since the
// timestamp stored in lastActiveTs.
func (n *shardClient) isExpired() bool {
|
||||
return time.Now().UnixNano()-n.lastActiveTs.Load() > n.expiredDuration.Nanoseconds()
|
||||
}
|
||||
|
||||
// close marks the shard client as closed, closes every client in n.clients and
// then drops the slice. It does not take the lock itself.
func (n *shardClient) close() {
|
||||
n.isClosed = true
|
||||
|
||||
for _, client := range n.clients {
|
||||
// A failed Close is only logged so the remaining clients are still closed.
if err := client.Close(); err != nil {
|
||||
log.Warn("close grpc client failed", zap.Error(err))
|
||||
}
|
||||
}
|
||||
n.clients = nil
|
||||
}
|
||||
|
||||
// shardClientMgr is the interface for obtaining and managing per-node query node clients
|
||||
type shardClientMgr interface {
|
||||
GetClient(ctx context.Context, nodeInfo nodeInfo) (types.QueryNodeClient, error)
|
||||
ReleaseClientRef(nodeID int64)
|
||||
Close()
|
||||
SetClientCreatorFunc(creator queryNodeCreatorFunc)
|
||||
}
|
||||
|
||||
type shardClientMgrImpl struct {
|
||||
clients struct {
|
||||
sync.RWMutex
|
||||
data map[UniqueID]*shardClient
|
||||
}
|
||||
clients *typeutil.ConcurrentMap[UniqueID, *shardClient]
|
||||
clientCreator queryNodeCreatorFunc
|
||||
closeCh chan struct{}
|
||||
|
||||
closeCh chan struct{}
|
||||
purgeInterval time.Duration
|
||||
expiredDuration time.Duration
|
||||
}
|
||||
|
||||
// Default timing values used by newShardClientMgr.
const (
|
||||
// defaultPurgeInterval is the default value of shardClientMgrImpl.purgeInterval.
defaultPurgeInterval = 600 * time.Second
|
||||
// defaultExpiredDuration is the default value of shardClientMgrImpl.expiredDuration.
defaultExpiredDuration = 60 * time.Minute
|
||||
)
|
||||
|
||||
// shardClientMgrOpt provides a way to set params in shardClientMgr
|
||||
type shardClientMgrOpt func(s shardClientMgr)
|
||||
|
||||
|
@ -183,12 +185,11 @@ func defaultQueryNodeClientCreator(ctx context.Context, addr string, nodeID int6
|
|||
// newShardClientMgr creates a new shardClientMgrImpl
|
||||
func newShardClientMgr(options ...shardClientMgrOpt) *shardClientMgrImpl {
|
||||
s := &shardClientMgrImpl{
|
||||
clients: struct {
|
||||
sync.RWMutex
|
||||
data map[UniqueID]*shardClient
|
||||
}{data: make(map[UniqueID]*shardClient)},
|
||||
clientCreator: defaultQueryNodeClientCreator,
|
||||
closeCh: make(chan struct{}),
|
||||
clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](),
|
||||
clientCreator: defaultQueryNodeClientCreator,
|
||||
closeCh: make(chan struct{}),
|
||||
purgeInterval: defaultPurgeInterval,
|
||||
expiredDuration: defaultExpiredDuration,
|
||||
}
|
||||
for _, opt := range options {
|
||||
opt(s)
|
||||
|
@ -203,66 +204,44 @@ func (c *shardClientMgrImpl) SetClientCreatorFunc(creator queryNodeCreatorFunc)
|
|||
}
|
||||
|
||||
func (c *shardClientMgrImpl) GetClient(ctx context.Context, info nodeInfo) (types.QueryNodeClient, error) {
|
||||
c.clients.RLock()
|
||||
client, ok := c.clients.data[info.nodeID]
|
||||
c.clients.RUnlock()
|
||||
|
||||
if !ok {
|
||||
c.clients.Lock()
|
||||
// Check again after acquiring the lock
|
||||
client, ok = c.clients.data[info.nodeID]
|
||||
if !ok {
|
||||
// Create a new client if it doesn't exist
|
||||
newClient, err := newShardClient(info, c.clientCreator)
|
||||
if err != nil {
|
||||
c.clients.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
c.clients.data[info.nodeID] = newClient
|
||||
client = newClient
|
||||
}
|
||||
c.clients.Unlock()
|
||||
}
|
||||
|
||||
client, _ := c.clients.GetOrInsert(info.nodeID, newShardClient(info, c.clientCreator, c.expiredDuration))
|
||||
return client.getClient(ctx)
|
||||
}
|
||||
|
||||
// PurgeClient purges client if it is not used for a long time
|
||||
func (c *shardClientMgrImpl) PurgeClient() {
|
||||
ticker := time.NewTicker(600 * time.Second)
|
||||
ticker := time.NewTicker(c.purgeInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.closeCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
shardLocations := globalMetaCache.ListShardLocation()
|
||||
c.clients.Lock()
|
||||
for nodeID, client := range c.clients.data {
|
||||
if _, ok := shardLocations[nodeID]; !ok {
|
||||
client.DecRef()
|
||||
delete(c.clients.data, nodeID)
|
||||
c.clients.Range(func(key UniqueID, value *shardClient) bool {
|
||||
if _, ok := shardLocations[key]; !ok {
|
||||
// if the client has not been used for longer than expiredDuration and its node is no longer in the shard locations, remove it
|
||||
if value.isExpired() {
|
||||
closed := value.Close(false)
|
||||
if closed {
|
||||
c.clients.Remove(key)
|
||||
log.Info("remove idle node client", zap.Int64("nodeID", key))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
c.clients.Unlock()
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *shardClientMgrImpl) ReleaseClientRef(nodeID int64) {
|
||||
c.clients.RLock()
|
||||
defer c.clients.RUnlock()
|
||||
if client, ok := c.clients.data[nodeID]; ok {
|
||||
client.DecRef()
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes and removes all clients
|
||||
func (c *shardClientMgrImpl) Close() {
|
||||
c.clients.Lock()
|
||||
defer c.clients.Unlock()
|
||||
close(c.closeCh)
|
||||
for _, s := range c.clients.data {
|
||||
s.Close()
|
||||
}
|
||||
c.clients.data = make(map[UniqueID]*shardClient)
|
||||
c.clients.Range(func(key UniqueID, value *shardClient) bool {
|
||||
value.Close(true)
|
||||
c.clients.Remove(key)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
|
|
@ -3,12 +3,15 @@ package proxy
|
|||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
func TestShardClientMgr(t *testing.T) {
|
||||
|
@ -28,10 +31,8 @@ func TestShardClientMgr(t *testing.T) {
|
|||
_, err := mgr.GetClient(ctx, nodeInfo)
|
||||
assert.Nil(t, err)
|
||||
|
||||
mgr.ReleaseClientRef(1)
|
||||
assert.Equal(t, len(mgr.clients.data), 1)
|
||||
mgr.Close()
|
||||
assert.Equal(t, len(mgr.clients.data), 0)
|
||||
assert.Equal(t, mgr.clients.Len(), 0)
|
||||
}
|
||||
|
||||
func TestShardClient(t *testing.T) {
|
||||
|
@ -40,27 +41,120 @@ func TestShardClient(t *testing.T) {
|
|||
}
|
||||
|
||||
qn := mocks.NewMockQueryNodeClient(t)
|
||||
qn.EXPECT().Close().Return(nil)
|
||||
qn.EXPECT().Close().Return(nil).Maybe()
|
||||
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
|
||||
return qn, nil
|
||||
}
|
||||
shardClient, err := newShardClient(nodeInfo, creator)
|
||||
assert.Nil(t, err)
|
||||
shardClient := newShardClient(nodeInfo, creator, 3*time.Second)
|
||||
assert.Equal(t, len(shardClient.clients), 0)
|
||||
assert.Equal(t, int64(1), shardClient.refCnt.Load())
|
||||
assert.Equal(t, false, shardClient.initialized.Load())
|
||||
assert.Equal(t, false, shardClient.isClosed)
|
||||
|
||||
ctx := context.Background()
|
||||
_, err = shardClient.getClient(ctx)
|
||||
_, err := shardClient.getClient(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, len(shardClient.clients), paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt())
|
||||
assert.Equal(t, int64(2), shardClient.refCnt.Load())
|
||||
assert.Equal(t, true, shardClient.initialized.Load())
|
||||
|
||||
shardClient.DecRef()
|
||||
assert.Equal(t, int64(1), shardClient.refCnt.Load())
|
||||
|
||||
shardClient.DecRef()
|
||||
assert.Equal(t, int64(0), shardClient.refCnt.Load())
|
||||
assert.Equal(t, true, shardClient.isClosed)
|
||||
// test close
|
||||
closed := shardClient.Close(false)
|
||||
assert.False(t, closed)
|
||||
closed = shardClient.Close(true)
|
||||
assert.True(t, closed)
|
||||
}
|
||||
|
||||
func TestPurgeClient(t *testing.T) {
|
||||
node := nodeInfo{
|
||||
nodeID: 1,
|
||||
}
|
||||
|
||||
returnEmptyResult := atomic.NewBool(false)
|
||||
|
||||
cache := NewMockCache(t)
|
||||
cache.EXPECT().ListShardLocation().RunAndReturn(func() map[int64]nodeInfo {
|
||||
if returnEmptyResult.Load() {
|
||||
return map[int64]nodeInfo{}
|
||||
}
|
||||
return map[int64]nodeInfo{
|
||||
1: node,
|
||||
}
|
||||
})
|
||||
globalMetaCache = cache
|
||||
|
||||
qn := mocks.NewMockQueryNodeClient(t)
|
||||
qn.EXPECT().Close().Return(nil).Maybe()
|
||||
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
|
||||
return qn, nil
|
||||
}
|
||||
|
||||
s := &shardClientMgrImpl{
|
||||
clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](),
|
||||
clientCreator: creator,
|
||||
closeCh: make(chan struct{}),
|
||||
purgeInterval: 1 * time.Second,
|
||||
expiredDuration: 3 * time.Second,
|
||||
}
|
||||
|
||||
go s.PurgeClient()
|
||||
defer s.Close()
|
||||
_, err := s.GetClient(context.Background(), node)
|
||||
assert.Nil(t, err)
|
||||
qnClient, ok := s.clients.Get(1)
|
||||
assert.True(t, ok)
|
||||
assert.True(t, qnClient.lastActiveTs.Load() > 0)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
// the client should not be purged before expiredDuration has elapsed
|
||||
assert.Equal(t, s.clients.Len(), 1)
|
||||
assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() >= 2*time.Second.Nanoseconds())
|
||||
|
||||
_, err = s.GetClient(context.Background(), node)
|
||||
assert.Nil(t, err)
|
||||
time.Sleep(2 * time.Second)
|
||||
// GetClient should refresh lastActiveTs, expected client should not be purged
|
||||
assert.Equal(t, s.clients.Len(), 1)
|
||||
assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() < 3*time.Second.Nanoseconds())
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
// the client has reached expiredDuration but its node is still listed in the shard locations, so it should not be purged
|
||||
assert.Equal(t, s.clients.Len(), 1)
|
||||
assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() > 3*time.Second.Nanoseconds())
|
||||
|
||||
returnEmptyResult.Store(true)
|
||||
time.Sleep(2 * time.Second)
|
||||
// remove client from shard location, expected client should be purged
|
||||
assert.Equal(t, s.clients.Len(), 0)
|
||||
}
|
||||
|
||||
func BenchmarkShardClientMgr(b *testing.B) {
|
||||
node := nodeInfo{
|
||||
nodeID: 1,
|
||||
}
|
||||
cache := NewMockCache(b)
|
||||
cache.EXPECT().ListShardLocation().Return(map[int64]nodeInfo{
|
||||
1: node,
|
||||
}).Maybe()
|
||||
globalMetaCache = cache
|
||||
qn := mocks.NewMockQueryNodeClient(b)
|
||||
qn.EXPECT().Close().Return(nil).Maybe()
|
||||
|
||||
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
|
||||
return qn, nil
|
||||
}
|
||||
s := &shardClientMgrImpl{
|
||||
clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](),
|
||||
clientCreator: creator,
|
||||
closeCh: make(chan struct{}),
|
||||
purgeInterval: 1 * time.Second,
|
||||
expiredDuration: 10 * time.Second,
|
||||
}
|
||||
go s.PurgeClient()
|
||||
defer s.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_, err := s.GetClient(context.Background(), node)
|
||||
assert.Nil(b, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -39,7 +39,6 @@ func RoundRobinPolicy(
|
|||
combineErr = merr.Combine(combineErr, err)
|
||||
continue
|
||||
}
|
||||
defer mgr.ReleaseClientRef(target.nodeID)
|
||||
err = query(ctx, target.nodeID, qn, channel)
|
||||
if err != nil {
|
||||
log.Warn("query channel failed", zap.String("channel", channel), zap.Int64("nodeID", target.nodeID), zap.Error(err))
|
||||
|
|
|
@ -79,7 +79,6 @@ func TestQueryTask_all(t *testing.T) {
|
|||
}, nil).Maybe()
|
||||
|
||||
mgr := NewMockShardClientManager(t)
|
||||
mgr.EXPECT().ReleaseClientRef(mock.Anything)
|
||||
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
|
||||
lb := NewLBPolicyImpl(mgr)
|
||||
|
||||
|
|
|
@ -2111,7 +2111,6 @@ func TestSearchTask_ErrExecute(t *testing.T) {
|
|||
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
|
||||
|
||||
mgr := NewMockShardClientManager(t)
|
||||
mgr.EXPECT().ReleaseClientRef(mock.Anything)
|
||||
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
|
||||
lb := NewLBPolicyImpl(mgr)
|
||||
|
||||
|
|
|
@ -80,7 +80,6 @@ func (s *StatisticTaskSuite) SetupTest() {
|
|||
|
||||
s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
mgr := NewMockShardClientManager(s.T())
|
||||
mgr.EXPECT().ReleaseClientRef(mock.Anything).Maybe()
|
||||
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil).Maybe()
|
||||
s.lb = NewLBPolicyImpl(mgr)
|
||||
|
||||
|
|
Loading…
Reference in New Issue