milvus/internal/proxy/shard_client_test.go

161 lines
4.3 KiB
Go

package proxy
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/atomic"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// TestShardClientMgr verifies that the manager builds a client for a node
// through the creator installed with SetClientCreatorFunc, and that Close
// leaves no cached clients behind.
func TestShardClientMgr(t *testing.T) {
	ctx := context.Background()
	node := nodeInfo{
		nodeID: 1,
	}

	qn := mocks.NewMockQueryNodeClient(t)
	// Unlike the other tests in this file, Close is an expectation here (no
	// .Maybe()): tearing down the manager is expected to close the client.
	qn.EXPECT().Close().Return(nil)
	creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
		return qn, nil
	}

	mgr := newShardClientMgr()
	mgr.SetClientCreatorFunc(creator)

	_, err := mgr.GetClient(ctx, node)
	assert.NoError(t, err)

	mgr.Close()
	assert.Equal(t, 0, mgr.clients.Len())
}
// TestShardClient checks the lifecycle of a single shardClient:
//   - a fresh client has no pooled connections, is not initialized and is not
//     closed;
//   - the first getClient call fills the pool to ProxyCfg.QueryNodePoolingSize;
//   - Close(false) reports false while Close(true) reports true.
func TestShardClient(t *testing.T) {
	node := nodeInfo{
		nodeID: 1,
	}

	qn := mocks.NewMockQueryNodeClient(t)
	qn.EXPECT().Close().Return(nil).Maybe()
	creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
		return qn, nil
	}

	shardClient := newShardClient(node, creator, 3*time.Second)
	assert.Len(t, shardClient.clients, 0)
	assert.False(t, shardClient.initialized.Load())
	assert.False(t, shardClient.isClosed)

	ctx := context.Background()
	_, err := shardClient.getClient(ctx)
	assert.NoError(t, err)
	assert.Len(t, shardClient.clients, paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt())

	// NOTE(review): the bool argument presumably forces the close even while
	// the client is still referenced — confirm against shardClient.Close.
	assert.False(t, shardClient.Close(false))
	assert.True(t, shardClient.Close(true))
}
// TestPurgeClient exercises the background PurgeClient loop of
// shardClientMgrImpl with purgeInterval=1s and expiredDuration=3s. It checks
// that a cached client:
//   - survives while it has been idle for less than expiredDuration;
//   - has its idle time reset by GetClient;
//   - survives past expiredDuration while its node is still returned by
//     ListShardLocation;
//   - is purged once the node disappears from the shard location.
func TestPurgeClient(t *testing.T) {
	node := nodeInfo{
		nodeID: 1,
	}

	// Flipped near the end of the test so the mocked meta cache stops
	// reporting the node in its shard location.
	returnEmptyResult := atomic.NewBool(false)
	cache := NewMockCache(t)
	cache.EXPECT().ListShardLocation().RunAndReturn(func() map[int64]nodeInfo {
		if returnEmptyResult.Load() {
			return map[int64]nodeInfo{}
		}
		return map[int64]nodeInfo{
			1: node,
		}
	})
	// NOTE(review): overwrites a package-level global (presumably consulted by
	// PurgeClient) and never restores it.
	globalMetaCache = cache

	qn := mocks.NewMockQueryNodeClient(t)
	qn.EXPECT().Close().Return(nil).Maybe()
	creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
		return qn, nil
	}

	s := &shardClientMgrImpl{
		clients:         typeutil.NewConcurrentMap[UniqueID, *shardClient](),
		clientCreator:   creator,
		closeCh:         make(chan struct{}),
		purgeInterval:   1 * time.Second,
		expiredDuration: 3 * time.Second,
	}
	go s.PurgeClient()
	defer s.Close()

	_, err := s.GetClient(context.Background(), node)
	assert.NoError(t, err)
	qnClient, ok := s.clients.Get(1)
	if !assert.True(t, ok, "client for node 1 should be cached after GetClient") {
		// Everything below dereferences qnClient; stop instead of panicking.
		t.FailNow()
	}
	assert.Greater(t, qnClient.lastActiveTs.Load(), int64(0))

	// idleNanos returns the nanoseconds elapsed since lastActiveTs, which the
	// assertions below treat as a UnixNano timestamp.
	idleNanos := func() int64 {
		return time.Now().UnixNano() - qnClient.lastActiveTs.Load()
	}

	// ~2s idle: still below expiredDuration, so the client must not be purged.
	time.Sleep(2 * time.Second)
	assert.Equal(t, 1, s.clients.Len())
	assert.GreaterOrEqual(t, idleNanos(), 2*time.Second.Nanoseconds())

	// GetClient refreshes lastActiveTs, so ~2s later the client has been idle
	// for less than expiredDuration and must still be cached.
	_, err = s.GetClient(context.Background(), node)
	assert.NoError(t, err)
	time.Sleep(2 * time.Second)
	assert.Equal(t, 1, s.clients.Len())
	assert.Less(t, idleNanos(), 3*time.Second.Nanoseconds())

	// ~4s idle: past expiredDuration, but the node is still listed by
	// ListShardLocation, so the client must not be purged yet.
	time.Sleep(2 * time.Second)
	assert.Equal(t, 1, s.clients.Len())
	assert.Greater(t, idleNanos(), 3*time.Second.Nanoseconds())

	// With the node removed from the shard location, the next purge tick must
	// drop the expired client.
	returnEmptyResult.Store(true)
	time.Sleep(2 * time.Second)
	assert.Equal(t, 0, s.clients.Len())
}
// BenchmarkShardClientMgr measures parallel GetClient calls for a single node.
// The mocked meta cache always lists that node, and the manager's PurgeClient
// loop runs in the background (purgeInterval=1s, expiredDuration=10s) for the
// duration of the benchmark.
func BenchmarkShardClientMgr(b *testing.B) {
	target := nodeInfo{nodeID: 1}

	metaCache := NewMockCache(b)
	metaCache.EXPECT().ListShardLocation().Return(map[int64]nodeInfo{1: target}).Maybe()
	globalMetaCache = metaCache

	mockQN := mocks.NewMockQueryNodeClient(b)
	mockQN.EXPECT().Close().Return(nil).Maybe()

	mgr := &shardClientMgrImpl{
		clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](),
		// Every node resolves to the same mocked QueryNode client.
		clientCreator: func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
			return mockQN, nil
		},
		closeCh:         make(chan struct{}),
		purgeInterval:   time.Second,
		expiredDuration: 10 * time.Second,
	}
	go mgr.PurgeClient()
	defer mgr.Close()

	ctx := context.Background()

	// Exclude the setup above from the measured time.
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			_, err := mgr.GetClient(ctx, target)
			assert.Nil(b, err)
		}
	})
}