Add cache of grpc client of ShardLeader in proxy (#17301)

Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
zhenshan.cao 2022-06-02 12:16:03 +08:00 committed by GitHub
parent 5698bf4236
commit 2763efc9b0
24 changed files with 509 additions and 638 deletions
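
The proxy previously dialed a fresh QueryNode gRPC client for every shard it searched or queried (the removed defaultGetQueryNodePolicy below notes that it "creates QueryNode client for every address everytime"). This commit replaces that with a shardClientMgr that caches one client per shard-leader node ID and reference-counts it as the shard-leader map changes. The following is a minimal, self-contained sketch of that bookkeeping, not the Milvus code itself: the type names are invented for illustration, and the real manager dials a QueryNode gRPC client and calls Stop() where noted.

// Illustrative sketch of the reference-counted client cache behind
// shardClientMgr.UpdateShardLeaders: one entry per node ID, one reference per
// shard-leader list that mentions the node.
package main

import "fmt"

type nodeID = int64

type cachedClient struct {
	addr   string
	refCnt int
}

type clientCache struct {
	clients map[nodeID]*cachedClient
}

// update diffs the old and new shard-leader maps: node IDs that appear only in
// the new map gain a reference (dialing a client if needed); node IDs that
// appear only in the old map lose one and are dropped at zero.
func (c *clientCache) update(oldLeaders, newLeaders map[string][]nodeID, addrs map[nodeID]string) {
	oldSet := map[nodeID]bool{}
	for _, nodes := range oldLeaders {
		for _, n := range nodes {
			oldSet[n] = true
		}
	}
	added := map[nodeID]bool{}
	for _, nodes := range newLeaders {
		for _, n := range nodes {
			if !oldSet[n] {
				added[n] = true
			}
			delete(oldSet, n) // present in both maps: refcount unchanged
		}
	}
	for n := range added {
		if cl, ok := c.clients[n]; ok {
			cl.refCnt++
		} else {
			// the real manager dials qnClient.NewClient(ctx, addr) here
			c.clients[n] = &cachedClient{addr: addrs[n], refCnt: 1}
		}
	}
	for n := range oldSet {
		if cl, ok := c.clients[n]; ok {
			cl.refCnt--
			if cl.refCnt == 0 {
				delete(c.clients, n) // the real manager also calls client.Stop()
			}
		}
	}
}

func main() {
	cache := &clientCache{clients: map[nodeID]*cachedClient{}}
	addrs := map[nodeID]string{1: "qn1:21123", 2: "qn2:21123", 3: "qn3:21123"}

	// first GetShards: leaders are nodes 1 and 2, both get a cached client
	cache.update(nil, map[string][]nodeID{"ch1": {1, 2}}, addrs)
	// leaders change: node 2 is replaced by node 3, so 2 is released and 3 is dialed
	cache.update(map[string][]nodeID{"ch1": {1, 2}}, map[string][]nodeID{"ch1": {1, 3}}, addrs)

	for id, cl := range cache.clients {
		fmt.Printf("node %d -> %s (refs=%d)\n", id, cl.addr, cl.refCnt)
	}
}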

View File

@ -24,7 +24,8 @@ func TestValidAuth(t *testing.T) {
// normal metadata
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
mgr := newShardClientMgr()
err := InitMetaCache(rootCoord, queryCoord, mgr)
assert.Nil(t, err)
res = validAuth(ctx, []string{crypto.Base64Encode("mockUser:mockPass")})
assert.True(t, res)
@ -52,7 +53,8 @@ func TestAuthenticationInterceptor(t *testing.T) {
// mock metacache
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err = InitMetaCache(rootCoord, queryCoord)
mgr := newShardClientMgr()
err = InitMetaCache(rootCoord, queryCoord, mgr)
assert.Nil(t, err)
// with invalid metadata
md := metadata.Pairs("xxx", "yyy")

View File

@ -2410,10 +2410,10 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
},
ReqID: Params.ProxyCfg.GetNodeID(),
},
request: request,
qc: node.queryCoord,
tr: timerecord.NewTimeRecorder("search"),
getQueryNodePolicy: defaultGetQueryNodePolicy,
request: request,
qc: node.queryCoord,
tr: timerecord.NewTimeRecorder("search"),
shardMgr: node.shardMgr,
}
travelTs := request.TravelTimestamp
@ -2649,10 +2649,10 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
},
ReqID: Params.ProxyCfg.GetNodeID(),
},
request: request,
qc: node.queryCoord,
getQueryNodePolicy: defaultGetQueryNodePolicy,
queryShardPolicy: roundRobinPolicy,
request: request,
qc: node.queryCoord,
queryShardPolicy: roundRobinPolicy,
shardMgr: node.shardMgr,
}
method := "Query"
@ -3062,8 +3062,8 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
qc: node.queryCoord,
ids: ids.IdArray,
getQueryNodePolicy: defaultGetQueryNodePolicy,
queryShardPolicy: roundRobinPolicy,
queryShardPolicy: roundRobinPolicy,
shardMgr: node.shardMgr,
}
items := []zapcore.Field{

View File

@ -55,7 +55,7 @@ type Cache interface {
GetPartitionInfo(ctx context.Context, collectionName string, partitionName string) (*partitionInfo, error)
// GetCollectionSchema get collection's schema.
GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error)
GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error)
GetShards(ctx context.Context, withCache bool, collectionName string) (map[string][]nodeInfo, error)
ClearShards(collectionName string)
RemoveCollection(ctx context.Context, collectionName string)
RemoveCollectionsByID(ctx context.Context, collectionID UniqueID)
@ -73,7 +73,7 @@ type collectionInfo struct {
collID typeutil.UniqueID
schema *schemapb.CollectionSchema
partInfo map[string]*partitionInfo
shardLeaders map[string][]queryNode
shardLeaders map[string][]nodeInfo
leaderMutex sync.Mutex
createdTimestamp uint64
createdUtcTimestamp uint64
@ -82,10 +82,10 @@ type collectionInfo struct {
// CloneShardLeaders returns a copy of shard leaders
// leaderMutex shall be acquired before invoking this method
func (c *collectionInfo) CloneShardLeaders() map[string][]queryNode {
m := make(map[string][]queryNode)
func (c *collectionInfo) CloneShardLeaders() map[string][]nodeInfo {
m := make(map[string][]nodeInfo)
for channel, leaders := range c.shardLeaders {
l := make([]queryNode, len(leaders))
l := make([]nodeInfo, len(leaders))
copy(l, leaders)
m[channel] = l
}
@ -111,15 +111,16 @@ type MetaCache struct {
credUsernameList []string // no need initialize when NewMetaCache
mu sync.RWMutex
credMut sync.RWMutex
shardMgr *shardClientMgr
}
// globalMetaCache is singleton instance of Cache
var globalMetaCache Cache
// InitMetaCache initializes globalMetaCache
func InitMetaCache(rootCoord types.RootCoord, queryCoord types.QueryCoord) error {
func InitMetaCache(rootCoord types.RootCoord, queryCoord types.QueryCoord, shardMgr *shardClientMgr) error {
var err error
globalMetaCache, err = NewMetaCache(rootCoord, queryCoord)
globalMetaCache, err = NewMetaCache(rootCoord, queryCoord, shardMgr)
if err != nil {
return err
}
@ -127,12 +128,13 @@ func InitMetaCache(rootCoord types.RootCoord, queryCoord types.QueryCoord) error
}
// NewMetaCache creates a MetaCache with provided RootCoord, QueryCoord and shardClientMgr
func NewMetaCache(rootCoord types.RootCoord, queryCoord types.QueryCoord) (*MetaCache, error) {
func NewMetaCache(rootCoord types.RootCoord, queryCoord types.QueryCoord, shardMgr *shardClientMgr) (*MetaCache, error) {
return &MetaCache{
rootCoord: rootCoord,
queryCoord: queryCoord,
collInfo: map[string]*collectionInfo{},
credMap: map[string]*internalpb.CredentialInfo{},
shardMgr: shardMgr,
}, nil
}
@ -583,7 +585,7 @@ func (m *MetaCache) GetCredUsernames(ctx context.Context) ([]string, error) {
}
// GetShards updates the shard leader cache if withCache == false
func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error) {
func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string) (map[string][]nodeInfo, error) {
info, err := m.GetCollectionInfo(ctx, collectionName)
if err != nil {
return nil, err
@ -601,7 +603,6 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
log.Info("no shard cache for collection, try to get shard leaders from QueryCoord",
zap.String("collectionName", collectionName))
}
req := &querypb.GetShardLeadersRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_GetShardLeaders,
@ -615,7 +616,7 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
childCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
err = retry.Do(childCtx, func() error {
resp, err = qc.GetShardLeaders(ctx, req)
resp, err = m.queryCoord.GetShardLeaders(ctx, req)
if err != nil {
return retry.Unrecoverable(err)
}
@ -629,31 +630,38 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
return fmt.Errorf("fail to get shard leaders from QueryCoord: %s", resp.Status.Reason)
})
if err != nil {
return nil, fmt.Errorf("GetShardLeaders timeout, error: %s", err.Error())
return nil, err
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return nil, fmt.Errorf("fail to get shard leaders from QueryCoord: %s", resp.Status.Reason)
}
shards := parseShardLeaderList2QueryNode(resp.GetShards())
// update the info stored in the map directly, since GetCollectionInfo returns a copy
m.mu.RLock()
defer m.mu.RUnlock()
info = m.collInfo[collectionName]
// lock leader
info.leaderMutex.Lock()
defer info.leaderMutex.Unlock()
oldShards := info.shardLeaders
info.shardLeaders = shards
info.leaderMutex.Unlock()
m.mu.RUnlock()
return info.CloneShardLeaders(), nil
// update refcnt in shardClientMgr
ret := info.CloneShardLeaders()
_ = m.shardMgr.UpdateShardLeaders(oldShards, ret)
return ret, nil
}
func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]queryNode {
shard2QueryNodes := make(map[string][]queryNode)
func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]nodeInfo {
shard2QueryNodes := make(map[string][]nodeInfo)
for _, leaders := range shardsLeaders {
qns := make([]queryNode, len(leaders.GetNodeIds()))
qns := make([]nodeInfo, len(leaders.GetNodeIds()))
for j := range qns {
qns[j] = queryNode{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j]}
qns[j] = nodeInfo{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j]}
}
shard2QueryNodes[leaders.GetChannelName()] = qns
@ -666,12 +674,14 @@ func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) m
func (m *MetaCache) ClearShards(collectionName string) {
log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName))
m.mu.Lock()
defer m.mu.Unlock()
_, ok := m.collInfo[collectionName]
if !ok {
return
//var ret map[string][]nodeInfo
info, ok := m.collInfo[collectionName]
if ok {
m.collInfo[collectionName].shardLeaders = nil
}
m.mu.Unlock()
// delete refcnt in shardClientMgr
if ok {
_ = m.shardMgr.UpdateShardLeaders(info.shardLeaders, nil)
}
m.collInfo[collectionName].shardLeaders = nil
}

View File

@ -198,7 +198,8 @@ func TestMetaCache_GetCollection(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
mgr := newShardClientMgr()
err := InitMetaCache(rootCoord, queryCoord, mgr)
assert.Nil(t, err)
id, err := globalMetaCache.GetCollectionID(ctx, "collection1")
@ -245,7 +246,8 @@ func TestMetaCache_GetCollectionFailure(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
mgr := newShardClientMgr()
err := InitMetaCache(rootCoord, queryCoord, mgr)
assert.Nil(t, err)
rootCoord.Error = true
@ -275,7 +277,8 @@ func TestMetaCache_GetNonExistCollection(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
mgr := newShardClientMgr()
err := InitMetaCache(rootCoord, queryCoord, mgr)
assert.Nil(t, err)
id, err := globalMetaCache.GetCollectionID(ctx, "collection3")
@ -290,7 +293,8 @@ func TestMetaCache_GetPartitionID(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
mgr := newShardClientMgr()
err := InitMetaCache(rootCoord, queryCoord, mgr)
assert.Nil(t, err)
id, err := globalMetaCache.GetPartitionID(ctx, "collection1", "par1")
@ -311,7 +315,8 @@ func TestMetaCache_GetPartitionError(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
mgr := newShardClientMgr()
err := InitMetaCache(rootCoord, queryCoord, mgr)
assert.Nil(t, err)
// Test the case where ShowPartitionsResponse is not aligned
@ -340,35 +345,35 @@ func TestMetaCache_GetPartitionError(t *testing.T) {
func TestMetaCache_GetShards(t *testing.T) {
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
qc := NewQueryCoordMock()
shardMgr := newShardClientMgr()
err := InitMetaCache(rootCoord, qc, shardMgr)
require.Nil(t, err)
var (
ctx = context.TODO()
collectionName = "collection1"
qc = NewQueryCoordMock()
)
qc.Init()
qc.Start()
defer qc.Stop()
t.Run("No collection in meta cache", func(t *testing.T) {
shards, err := globalMetaCache.GetShards(ctx, true, "non-exists", qc)
shards, err := globalMetaCache.GetShards(ctx, true, "non-exists")
assert.Error(t, err)
assert.Empty(t, shards)
})
t.Run("without shardLeaders in collection info invalid shardLeaders", func(t *testing.T) {
qc.validShardLeaders = false
shards, err := globalMetaCache.GetShards(ctx, false, collectionName, qc)
shards, err := globalMetaCache.GetShards(ctx, false, collectionName)
assert.Error(t, err)
assert.Empty(t, shards)
})
t.Run("without shardLeaders in collection info", func(t *testing.T) {
qc.validShardLeaders = true
shards, err := globalMetaCache.GetShards(ctx, true, collectionName, qc)
shards, err := globalMetaCache.GetShards(ctx, true, collectionName)
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))
@ -377,7 +382,7 @@ func TestMetaCache_GetShards(t *testing.T) {
// get from cache
qc.validShardLeaders = false
shards, err = globalMetaCache.GetShards(ctx, true, collectionName, qc)
shards, err = globalMetaCache.GetShards(ctx, true, collectionName)
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))
@ -387,14 +392,14 @@ func TestMetaCache_GetShards(t *testing.T) {
func TestMetaCache_ClearShards(t *testing.T) {
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
qc := NewQueryCoordMock()
mgr := newShardClientMgr()
err := InitMetaCache(rootCoord, qc, mgr)
require.Nil(t, err)
var (
ctx = context.TODO()
collectionName = "collection1"
qc = NewQueryCoordMock()
)
qc.Init()
qc.Start()
@ -411,7 +416,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
t.Run("Clear valid collection valid cache", func(t *testing.T) {
qc.validShardLeaders = true
shards, err := globalMetaCache.GetShards(ctx, true, collectionName, qc)
shards, err := globalMetaCache.GetShards(ctx, true, collectionName)
require.NoError(t, err)
require.NotEmpty(t, shards)
require.Equal(t, 1, len(shards))
@ -420,7 +425,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
globalMetaCache.ClearShards(collectionName)
qc.validShardLeaders = false
shards, err = globalMetaCache.GetShards(ctx, true, collectionName, qc)
shards, err = globalMetaCache.GetShards(ctx, true, collectionName)
assert.Error(t, err)
assert.Empty(t, shards)
})
@ -431,7 +436,8 @@ func TestMetaCache_LoadCache(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
mgr := newShardClientMgr()
err := InitMetaCache(rootCoord, queryCoord, mgr)
assert.Nil(t, err)
t.Run("test IsCollectionLoaded", func(t *testing.T) {
@ -474,7 +480,8 @@ func TestMetaCache_RemoveCollection(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
shardMgr := newShardClientMgr()
err := InitMetaCache(rootCoord, queryCoord, shardMgr)
assert.Nil(t, err)
info, err := globalMetaCache.GetCollectionInfo(ctx, "collection1")

View File

@ -89,7 +89,8 @@ type Proxy struct {
metricsCacheManager *metricsinfo.MetricsCacheManager
session *sessionutil.Session
session *sessionutil.Session
shardMgr *shardClientMgr
factory dependency.Factory
@ -110,6 +111,7 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
cancel: cancel,
factory: factory,
searchResultCh: make(chan *internalpb.SearchResults, n),
shardMgr: newShardClientMgr(),
}
node.UpdateStateCode(internalpb.StateCode_Abnormal)
logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load()))
@ -219,7 +221,7 @@ func (node *Proxy) Init() error {
log.Debug("create metrics cache manager done", zap.String("role", typeutil.ProxyRole))
log.Debug("init meta cache", zap.String("role", typeutil.ProxyRole))
if err := InitMetaCache(node.rootCoord, node.queryCoord); err != nil {
if err := InitMetaCache(node.rootCoord, node.queryCoord, node.shardMgr); err != nil {
log.Warn("failed to init meta cache", zap.Error(err), zap.String("role", typeutil.ProxyRole))
return err
}
@ -378,6 +380,10 @@ func (node *Proxy) Stop() error {
node.session.Revoke(time.Second)
if node.shardMgr != nil {
node.shardMgr.Close()
}
// https://github.com/milvus-io/milvus/issues/12282
node.UpdateStateCode(internalpb.StateCode_Abnormal)

View File

@ -3036,7 +3036,8 @@ func TestProxy_Import(t *testing.T) {
qc := NewQueryCoordMock()
qc.Start()
defer qc.Stop()
err := InitMetaCache(rc, qc)
shardMgr := newShardClientMgr()
err := InitMetaCache(rc, qc, shardMgr)
assert.NoError(t, err)
rc.CreateCollection(context.TODO(), &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{

View File

@ -0,0 +1,206 @@
package proxy
import (
"context"
"errors"
"fmt"
"sync"
qnClient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
"github.com/milvus-io/milvus/internal/types"
)
type queryNodeCreatorFunc func(ctx context.Context, addr string) (types.QueryNode, error)
type nodeInfo struct {
nodeID UniqueID
address string
}
func (n nodeInfo) String() string {
return fmt.Sprintf("<NodeID: %d>", n.nodeID)
}
var errClosed = errors.New("client is closed")
type shardClient struct {
sync.RWMutex
info nodeInfo
client types.QueryNode
isClosed bool
refCnt int
}
func (n *shardClient) getClient(ctx context.Context) (types.QueryNode, error) {
n.RLock()
defer n.RUnlock()
if n.isClosed {
return nil, errClosed
}
return n.client, nil
}
func (n *shardClient) inc() {
n.Lock()
defer n.Unlock()
if n.isClosed {
return
}
n.refCnt++
}
func (n *shardClient) close() {
n.isClosed = true
n.refCnt = 0
if n.client != nil {
n.client.Stop()
n.client = nil
}
}
func (n *shardClient) dec() bool {
n.Lock()
defer n.Unlock()
if n.isClosed {
return true
}
if n.refCnt > 0 {
n.refCnt--
}
if n.refCnt == 0 {
n.close()
}
return n.refCnt == 0
}
func (n *shardClient) Close() {
n.Lock()
defer n.Unlock()
n.close()
}
func newShardClient(info *nodeInfo, client types.QueryNode) *shardClient {
ret := &shardClient{
info: nodeInfo{
nodeID: info.nodeID,
address: info.address,
},
client: client,
refCnt: 1,
}
return ret
}
type shardClientMgr struct {
clients struct {
sync.RWMutex
data map[UniqueID]*shardClient
}
clientCreator queryNodeCreatorFunc
}
// shardClientMgrOpt provides a way to set params in shardClientMgr
type shardClientMgrOpt func(s *shardClientMgr)
func withShardClientCreator(creator queryNodeCreatorFunc) shardClientMgrOpt {
return func(s *shardClientMgr) { s.clientCreator = creator }
}
func defaultShardClientCreator(ctx context.Context, addr string) (types.QueryNode, error) {
return qnClient.NewClient(ctx, addr)
}
// newShardClientMgr creates a new shardClientMgr
func newShardClientMgr(options ...shardClientMgrOpt) *shardClientMgr {
s := &shardClientMgr{
clients: struct {
sync.RWMutex
data map[UniqueID]*shardClient
}{data: make(map[UniqueID]*shardClient)},
clientCreator: defaultShardClientCreator,
}
for _, opt := range options {
opt(s)
}
return s
}
// Warning: this method may modify the parameter `oldLeaders`
func (c *shardClientMgr) UpdateShardLeaders(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo) error {
oldLocalMap := make(map[UniqueID]*nodeInfo)
for _, nodes := range oldLeaders {
for i := range nodes {
n := &nodes[i]
_, ok := oldLocalMap[n.nodeID]
if !ok {
oldLocalMap[n.nodeID] = n
}
}
}
newLocalMap := make(map[UniqueID]*nodeInfo)
for _, nodes := range newLeaders {
for i := range nodes {
n := &nodes[i]
_, ok := oldLocalMap[n.nodeID]
if !ok {
_, ok2 := newLocalMap[n.nodeID]
if !ok2 {
newLocalMap[n.nodeID] = n
}
}
delete(oldLocalMap, n.nodeID)
}
}
c.clients.Lock()
defer c.clients.Unlock()
for _, node := range newLocalMap {
client, ok := c.clients.data[node.nodeID]
if ok {
client.inc()
} else {
// the ctx parameter is effectively unused here, so context.Background() is passed
// TODO QueryNode NewClient remove ctx parameter
// TODO Remove Init && Start interface in QueryNode client
if c.clientCreator == nil {
return fmt.Errorf("clientCreator function is nil")
}
shardClient, err := c.clientCreator(context.Background(), node.address)
if err != nil {
return err
}
client := newShardClient(node, shardClient)
c.clients.data[node.nodeID] = client
}
}
for _, node := range oldLocalMap {
client, ok := c.clients.data[node.nodeID]
if ok && client.dec() {
delete(c.clients.data, node.nodeID)
}
}
return nil
}
func (c *shardClientMgr) GetClient(ctx context.Context, nodeID UniqueID) (types.QueryNode, error) {
c.clients.RLock()
client, ok := c.clients.data[nodeID]
c.clients.RUnlock()
if !ok {
return nil, fmt.Errorf("can not find client of node %d", nodeID)
}
return client.getClient(ctx)
}
// Close releases all cached clients
func (c *shardClientMgr) Close() {
c.clients.Lock()
defer c.clients.Unlock()
for _, s := range c.clients.data {
s.Close()
}
c.clients.data = make(map[UniqueID]*shardClient)
}

View File

@ -0,0 +1,100 @@
package proxy
import (
"context"
"testing"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/mock"
"github.com/stretchr/testify/assert"
)
func genShardLeaderInfo(channel string, leaderIDs []UniqueID) map[string][]nodeInfo {
leaders := make(map[string][]nodeInfo)
nodeInfos := make([]nodeInfo, len(leaderIDs))
for i, id := range leaderIDs {
nodeInfos[i] = nodeInfo{
nodeID: id,
address: "fake",
}
}
leaders[channel] = nodeInfos
return leaders
}
func TestShardClientMgr_UpdateShardLeaders_CreatorNil(t *testing.T) {
mgr := newShardClientMgr(withShardClientCreator(nil))
mgr.clientCreator = nil
leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3})
err := mgr.UpdateShardLeaders(nil, leaders)
assert.Error(t, err)
}
func TestShardClientMgr_UpdateShardLeaders_Empty(t *testing.T) {
mockCreator := func(ctx context.Context, addr string) (types.QueryNode, error) {
return &mock.QueryNodeClient{}, nil
}
mgr := newShardClientMgr(withShardClientCreator(mockCreator))
_, err := mgr.GetClient(context.Background(), UniqueID(1))
assert.Error(t, err)
err = mgr.UpdateShardLeaders(nil, nil)
assert.NoError(t, err)
_, err = mgr.GetClient(context.Background(), UniqueID(1))
assert.Error(t, err)
leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3})
err = mgr.UpdateShardLeaders(leaders, nil)
assert.NoError(t, err)
}
func TestShardClientMgr_UpdateShardLeaders_NonEmpty(t *testing.T) {
mgr := newShardClientMgr()
leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3})
err := mgr.UpdateShardLeaders(nil, leaders)
assert.NoError(t, err)
_, err = mgr.GetClient(context.Background(), UniqueID(1))
assert.NoError(t, err)
newLeaders := genShardLeaderInfo("c1", []UniqueID{2, 3})
err = mgr.UpdateShardLeaders(leaders, newLeaders)
assert.NoError(t, err)
_, err = mgr.GetClient(context.Background(), UniqueID(1))
assert.Error(t, err)
}
func TestShardClientMgr_UpdateShardLeaders_Ref(t *testing.T) {
mgr := newShardClientMgr()
leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3})
for i := 0; i < 2; i++ {
err := mgr.UpdateShardLeaders(nil, leaders)
assert.NoError(t, err)
}
partLeaders := genShardLeaderInfo("c1", []UniqueID{1})
_, err := mgr.GetClient(context.Background(), UniqueID(1))
assert.NoError(t, err)
err = mgr.UpdateShardLeaders(partLeaders, nil)
assert.NoError(t, err)
_, err = mgr.GetClient(context.Background(), UniqueID(1))
assert.NoError(t, err)
err = mgr.UpdateShardLeaders(partLeaders, nil)
assert.NoError(t, err)
_, err = mgr.GetClient(context.Background(), UniqueID(1))
assert.Error(t, err)
_, err = mgr.GetClient(context.Background(), UniqueID(2))
assert.NoError(t, err)
_, err = mgr.GetClient(context.Background(), UniqueID(3))
assert.NoError(t, err)
}

View File

@ -70,8 +70,9 @@ func TestGetIndexStateTask_Execute(t *testing.T) {
collectionID: collectionID,
}
shardMgr := newShardClientMgr()
// failed to get collection id.
_ = InitMetaCache(rootCoord, queryCoord)
_ = InitMetaCache(rootCoord, queryCoord, shardMgr)
assert.Error(t, gist.Execute(ctx))
rootCoord.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
return &milvuspb.DescribeCollectionResponse{

View File

@ -5,51 +5,20 @@ import (
"errors"
"fmt"
qnClient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/types"
"go.uber.org/zap"
)
type getQueryNodePolicy func(context.Context, string) (types.QueryNode, error)
type pickShardPolicy func(ctx context.Context, policy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders []queryNode) error
// TODO add another policy to enbale the use of cache
// defaultGetQueryNodePolicy creates QueryNode client for every address everytime
func defaultGetQueryNodePolicy(ctx context.Context, address string) (types.QueryNode, error) {
qn, err := qnClient.NewClient(ctx, address)
if err != nil {
return nil, err
}
if err := qn.Init(); err != nil {
return nil, err
}
if err := qn.Start(); err != nil {
return nil, err
}
return qn, nil
}
type pickShardPolicy func(ctx context.Context, mgr *shardClientMgr, query func(UniqueID, types.QueryNode) error, leaders []nodeInfo) error
var (
errBegin = errors.New("begin error")
errInvalidShardLeaders = errors.New("Invalid shard leader")
)
type queryNode struct {
nodeID UniqueID
address string
}
func (q queryNode) String() string {
return fmt.Sprintf("<NodeID: %d>", q.nodeID)
}
func updateShardsWithRoundRobin(shardsLeaders map[string][]queryNode) {
func updateShardsWithRoundRobin(shardsLeaders map[string][]nodeInfo) {
for channelID, leaders := range shardsLeaders {
if len(leaders) <= 1 {
continue
@ -59,7 +28,7 @@ func updateShardsWithRoundRobin(shardsLeaders map[string][]queryNode) {
}
}
func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders []queryNode) error {
func roundRobinPolicy(ctx context.Context, mgr *shardClientMgr, query func(UniqueID, types.QueryNode) error, leaders []nodeInfo) error {
var (
err = errBegin
current = 0
@ -75,7 +44,7 @@ func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy
zap.Int64("nodeID", currentID))
}
qn, err = getQueryNodePolicy(ctx, leaders[current].address)
qn, err = mgr.GetClient(ctx, leaders[current].nodeID)
if err != nil {
log.Warn("fail to get valid QueryNode", zap.Int64("nodeID", currentID),
zap.Error(err))
@ -83,7 +52,6 @@ func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy
continue
}
defer qn.Stop()
err = query(currentID, qn)
if err != nil {
log.Warn("fail to Query with shard leader",

View File

@ -11,11 +11,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus/internal/util/mock"
"go.uber.org/zap"
)
func TestUpdateShardsWithRoundRobin(t *testing.T) {
list := map[string][]queryNode{
list := map[string][]nodeInfo{
"channel-1": {
{1, "addr1"},
{2, "addr2"},
@ -34,7 +36,7 @@ func TestUpdateShardsWithRoundRobin(t *testing.T) {
assert.Equal(t, "addr21", list["channel-2"][0].address)
t.Run("check print", func(t *testing.T) {
qns := []queryNode{
qns := []nodeInfo{
{1, "addr1"},
{2, "addr2"},
{20, "addr20"},
@ -54,10 +56,16 @@ func TestUpdateShardsWithRoundRobin(t *testing.T) {
func TestRoundRobinPolicy(t *testing.T) {
var (
getQueryNodePolicy = mockGetQueryNodePolicy
ctx = context.TODO()
ctx = context.TODO()
)
mockCreator := func(ctx context.Context, addr string) (types.QueryNode, error) {
return &mock.QueryNodeClient{}, nil
}
mgr := newShardClientMgr(withShardClientCreator(mockCreator))
dummyLeaders := genShardLeaderInfo("c1", []UniqueID{-1, 1, 2, 3})
mgr.UpdateShardLeaders(nil, dummyLeaders)
t.Run("All fails", func(t *testing.T) {
allFailTests := []struct {
leaderIDs []UniqueID
@ -73,13 +81,13 @@ func TestRoundRobinPolicy(t *testing.T) {
t.Run(test.description, func(t *testing.T) {
query := (&mockQuery{isvalid: false}).query
leaders := make([]queryNode, 0, len(test.leaderIDs))
leaders := make([]nodeInfo, 0, len(test.leaderIDs))
for _, ID := range test.leaderIDs {
leaders = append(leaders, queryNode{ID, "random-addr"})
leaders = append(leaders, nodeInfo{ID, "random-addr"})
}
err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
err := roundRobinPolicy(ctx, mgr, query, leaders)
require.Error(t, err)
})
}
@ -98,12 +106,12 @@ func TestRoundRobinPolicy(t *testing.T) {
for _, test := range allPassTests {
query := (&mockQuery{isvalid: true}).query
leaders := make([]queryNode, 0, len(test.leaderIDs))
leaders := make([]nodeInfo, 0, len(test.leaderIDs))
for _, ID := range test.leaderIDs {
leaders = append(leaders, queryNode{ID, "random-addr"})
leaders = append(leaders, nodeInfo{ID, "random-addr"})
}
err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
err := roundRobinPolicy(ctx, mgr, query, leaders)
require.NoError(t, err)
}
})
@ -120,18 +128,18 @@ func TestRoundRobinPolicy(t *testing.T) {
for _, test := range passAtLast {
query := (&mockQuery{isvalid: true}).query
leaders := make([]queryNode, 0, len(test.leaderIDs))
leaders := make([]nodeInfo, 0, len(test.leaderIDs))
for _, ID := range test.leaderIDs {
leaders = append(leaders, queryNode{ID, "random-addr"})
leaders = append(leaders, nodeInfo{ID, "random-addr"})
}
err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
err := roundRobinPolicy(ctx, mgr, query, leaders)
require.NoError(t, err)
}
})
}
func mockGetQueryNodePolicy(ctx context.Context, address string) (types.QueryNode, error) {
func mockQueryNodeCreator(ctx context.Context, address string) (types.QueryNode, error) {
return &QueryNodeMock{address: address}, nil
}
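
With the manager in place, roundRobinPolicy above takes a *shardClientMgr and fetches a cached client by node ID via mgr.GetClient instead of dialing each leader's address, and the old defer qn.Stop() disappears because cached clients are owned by the manager and stopped only when their reference count reaches zero. A standalone sketch of that retry loop follows, assuming a plain map as the cache; none of these names are from the Milvus code.

// Illustrative round-robin over cached clients: try each shard leader in turn
// and stop at the first successful call.
package main

import (
	"errors"
	"fmt"
)

type fakeClient struct{ id int64 }

func (c fakeClient) call() error {
	if c.id == 1 { // pretend node 1 is unreachable
		return errors.New("node 1 unreachable")
	}
	return nil
}

func roundRobin(cache map[int64]fakeClient, leaders []int64, query func(fakeClient) error) error {
	err := errors.New("no shard leader answered")
	for _, id := range leaders {
		c, ok := cache[id]
		if !ok {
			continue // no cached client for this leader, try the next one
		}
		if err = query(c); err == nil {
			return nil
		}
	}
	return err
}

func main() {
	cache := map[int64]fakeClient{1: {1}, 2: {2}}
	err := roundRobin(cache, []int64{1, 2}, func(c fakeClient) error { return c.call() })
	fmt.Println("query succeeded:", err == nil) // node 1 fails, node 2 answers
}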

View File

@ -50,15 +50,11 @@ type queryTask struct {
runningGroup *errgroup.Group
runningGroupCtx context.Context
getQueryNodePolicy getQueryNodePolicy
queryShardPolicy pickShardPolicy
queryShardPolicy pickShardPolicy
shardMgr *shardClientMgr
}
func (t *queryTask) PreExecute(ctx context.Context) error {
if t.getQueryNodePolicy == nil {
t.getQueryNodePolicy = defaultGetQueryNodePolicy
}
if t.queryShardPolicy == nil {
t.queryShardPolicy = roundRobinPolicy
}
@ -240,11 +236,10 @@ func (t *queryTask) Execute(ctx context.Context) error {
defer tr.Elapse("done")
executeQuery := func(withCache bool) error {
shards, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName, t.qc)
shards, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName)
if err != nil {
return err
}
t.resultBuf = make(chan *internalpb.RetrieveResults, len(shards))
t.toReduceResults = make([]*internalpb.RetrieveResults, 0, len(shards))
t.runningGroup, t.runningGroupCtx = errgroup.WithContext(ctx)
@ -353,7 +348,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
return nil
}
func (t *queryTask) queryShard(ctx context.Context, leaders []queryNode, channelID string) error {
func (t *queryTask) queryShard(ctx context.Context, leaders []nodeInfo, channelID string) error {
query := func(nodeID UniqueID, qn types.QueryNode) error {
req := &querypb.QueryRequest{
Req: t.RetrieveRequest,
@ -378,7 +373,7 @@ func (t *queryTask) queryShard(ctx context.Context, leaders []queryNode, channel
return nil
}
err := t.queryShardPolicy(t.TraceCtx(), t.getQueryNodePolicy, query, leaders)
err := t.queryShardPolicy(t.TraceCtx(), t.shardMgr, query, leaders)
if err != nil {
log.Warn("fail to Query to all shard leaders", zap.Int64("taskID", t.ID()), zap.Any("shard leaders", leaders))
return err

View File

@ -11,13 +11,12 @@ import (
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
@ -41,16 +40,18 @@ func TestQueryTask_all(t *testing.T) {
hitNum = 10
)
mockGetQueryNodePolicy := func(ctx context.Context, address string) (types.QueryNode, error) {
mockCreator := func(ctx context.Context, address string) (types.QueryNode, error) {
return qn, nil
}
mgr := newShardClientMgr(withShardClientCreator(mockCreator))
rc.Start()
defer rc.Stop()
qc.Start()
defer qc.Stop()
err = InitMetaCache(rc, qc)
err = InitMetaCache(rc, qc, mgr)
assert.NoError(t, err)
fieldName2Types := map[string]schemapb.DataType{
@ -125,8 +126,8 @@ func TestQueryTask_all(t *testing.T) {
},
qc: qc,
getQueryNodePolicy: mockGetQueryNodePolicy,
queryShardPolicy: roundRobinPolicy,
queryShardPolicy: roundRobinPolicy,
shardMgr: mgr,
}
for i := 0; i < len(fieldName2Types); i++ {
task.RetrieveRequest.OutputFieldsId[i] = int64(common.StartOfUserFieldID + i)
@ -181,7 +182,8 @@ func TestCheckIfLoaded(t *testing.T) {
err = rc.Start()
defer rc.Stop()
require.NoError(t, err)
err = InitMetaCache(rc, qc)
mgr := newShardClientMgr()
err = InitMetaCache(rc, qc, mgr)
require.NoError(t, err)
err = qc.Start()

View File

@ -50,18 +50,14 @@ type searchTask struct {
runningGroup *errgroup.Group
runningGroupCtx context.Context
getQueryNodePolicy getQueryNodePolicy
searchShardPolicy pickShardPolicy
searchShardPolicy pickShardPolicy
shardMgr *shardClientMgr
}
func (t *searchTask) PreExecute(ctx context.Context) error {
sp, ctx := trace.StartSpanFromContextWithOperationName(t.TraceCtx(), "Proxy-Search-PreExecute")
defer sp.Finish()
if t.getQueryNodePolicy == nil {
t.getQueryNodePolicy = defaultGetQueryNodePolicy
}
if t.searchShardPolicy == nil {
t.searchShardPolicy = roundRobinPolicy
}
@ -278,11 +274,10 @@ func (t *searchTask) Execute(ctx context.Context) error {
defer tr.Elapse("done")
executeSearch := func(withCache bool) error {
shard2Leaders, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName, t.qc)
shard2Leaders, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName)
if err != nil {
return err
}
t.resultBuf = make(chan *internalpb.SearchResults, len(shard2Leaders))
t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shard2Leaders))
t.runningGroup, t.runningGroupCtx = errgroup.WithContext(ctx)
@ -299,28 +294,23 @@ func (t *searchTask) Execute(ctx context.Context) error {
zap.String("shard channel", channelID),
zap.Uint64("timeoutTs", t.TimeoutTimestamp))
err := t.searchShard(t.runningGroupCtx, leaders, channelID)
if err != nil {
return err
}
return nil
return t.searchShard(t.runningGroupCtx, leaders, channelID)
})
}
err = t.runningGroup.Wait()
return err
}
err := executeSearch(WithCache)
if err == errInvalidShardLeaders {
log.Warn("invalid shard leaders from cache, updating shardleader caches and retry search")
if err == errInvalidShardLeaders || funcutil.IsGrpcErr(err) {
log.Warn("first search failed, updating shardleader caches and retry search", zap.Error(err))
return executeSearch(WithoutCache)
}
if err != nil {
return fmt.Errorf("fail to search on all shard leaders, err=%s", err.Error())
return fmt.Errorf("fail to search on all shard leaders, err=%w", err)
}
log.Info("Search Execute done.",
log.Debug("Search Execute done.",
zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "search"))
return nil
}
@ -415,7 +405,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
return nil
}
func (t *searchTask) searchShard(ctx context.Context, leaders []queryNode, channelID string) error {
func (t *searchTask) searchShard(ctx context.Context, leaders []nodeInfo, channelID string) error {
search := func(nodeID UniqueID, qn types.QueryNode) error {
req := &querypb.SearchRequest{
@ -423,9 +413,11 @@ func (t *searchTask) searchShard(ctx context.Context, leaders []queryNode, chann
DmlChannel: channelID,
Scope: querypb.DataScope_All,
}
result, err := qn.Search(ctx, req)
if err != nil || result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
if err != nil {
return err
}
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Warn("QueryNode search returns error", zap.Int64("nodeID", nodeID),
zap.Error(err))
return errInvalidShardLeaders
@ -435,12 +427,11 @@ func (t *searchTask) searchShard(ctx context.Context, leaders []queryNode, chann
zap.String("reason", result.GetStatus().GetReason()))
return fmt.Errorf("fail to Search, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason())
}
t.resultBuf <- result
return nil
}
err := t.searchShardPolicy(t.TraceCtx(), t.getQueryNodePolicy, search, leaders)
err := t.searchShardPolicy(t.TraceCtx(), t.shardMgr, search, leaders)
if err != nil {
log.Warn("fail to search to all shard leaders", zap.Any("shard leaders", leaders))
return err

View File

@ -124,7 +124,8 @@ func TestSearchTask_PreExecute(t *testing.T) {
err = rc.Start()
defer rc.Stop()
require.NoError(t, err)
err = InitMetaCache(rc, qc)
mgr := newShardClientMgr()
err = InitMetaCache(rc, qc, mgr)
require.NoError(t, err)
err = qc.Start()
@ -423,7 +424,8 @@ func TestSearchTaskV2_Execute(t *testing.T) {
err = rc.Start()
require.NoError(t, err)
defer rc.Stop()
err = InitMetaCache(rc, qc)
mgr := newShardClientMgr()
err = InitMetaCache(rc, qc, mgr)
require.NoError(t, err)
err = qc.Start()

View File

@ -1094,7 +1094,8 @@ func TestDropCollectionTask(t *testing.T) {
qc.Start()
defer qc.Stop()
ctx := context.Background()
InitMetaCache(rc, qc)
mgr := newShardClientMgr()
InitMetaCache(rc, qc, mgr)
master := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
@ -1181,7 +1182,8 @@ func TestHasCollectionTask(t *testing.T) {
qc.Start()
defer qc.Stop()
ctx := context.Background()
InitMetaCache(rc, qc)
mgr := newShardClientMgr()
InitMetaCache(rc, qc, mgr)
prefix := "TestHasCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
@ -1266,7 +1268,8 @@ func TestDescribeCollectionTask(t *testing.T) {
qc.Start()
defer qc.Stop()
ctx := context.Background()
InitMetaCache(rc, qc)
mgr := newShardClientMgr()
InitMetaCache(rc, qc, mgr)
prefix := "TestDescribeCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
@ -1328,7 +1331,8 @@ func TestDescribeCollectionTask_ShardsNum1(t *testing.T) {
qc.Start()
defer qc.Stop()
ctx := context.Background()
InitMetaCache(rc, qc)
mgr := newShardClientMgr()
InitMetaCache(rc, qc, mgr)
prefix := "TestDescribeCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
@ -1392,7 +1396,8 @@ func TestDescribeCollectionTask_ShardsNum2(t *testing.T) {
qc.Start()
defer qc.Stop()
ctx := context.Background()
InitMetaCache(rc, qc)
mgr := newShardClientMgr()
InitMetaCache(rc, qc, mgr)
prefix := "TestDescribeCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
@ -1658,7 +1663,8 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
ctx := context.Background()
err = InitMetaCache(rc, qc)
mgr := newShardClientMgr()
err = InitMetaCache(rc, qc, mgr)
assert.NoError(t, err)
shardsNum := int32(2)
@ -1911,7 +1917,8 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
ctx := context.Background()
err = InitMetaCache(rc, qc)
mgr := newShardClientMgr()
err = InitMetaCache(rc, qc, mgr)
assert.NoError(t, err)
shardsNum := int32(2)

View File

@ -1,82 +0,0 @@
package querynode
import (
"context"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/types"
)
type mockProxy struct {
}
func (m *mockProxy) Init() error {
return nil
}
func (m *mockProxy) Start() error {
return nil
}
func (m *mockProxy) Stop() error {
return nil
}
func (m *mockProxy) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
return nil, nil
}
func (m *mockProxy) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return nil, nil
}
func (m *mockProxy) Register() error {
return nil
}
func (m *mockProxy) InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) {
return nil, nil
}
func (m *mockProxy) ReleaseDQLMessageStream(ctx context.Context, in *proxypb.ReleaseDQLMessageStreamRequest) (*commonpb.Status, error) {
return nil, nil
}
func (m *mockProxy) SendSearchResult(ctx context.Context, req *internalpb.SearchResults) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, nil
}
func (m *mockProxy) SendRetrieveResult(ctx context.Context, req *internalpb.RetrieveResults) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, nil
}
func (m *mockProxy) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) {
return nil, nil
}
func (m *mockProxy) UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) (*commonpb.Status, error) {
return nil, nil
}
func (m *mockProxy) ClearCredUsersCache(ctx context.Context, request *internalpb.ClearCredUsersCacheRequest) (*commonpb.Status, error) {
return nil, nil
}
func newMockProxy() types.Proxy {
return &mockProxy{}
}
func mockProxyCreator() proxyCreatorFunc {
return func(ctx context.Context, addr string) (types.Proxy, error) {
return newMockProxy(), nil
}
}

View File

@ -1,75 +0,0 @@
package querynode
import (
"context"
"errors"
"fmt"
"sync"
"github.com/milvus-io/milvus/internal/types"
)
var errDisposed = errors.New("client is disposed")
type NodeInfo struct {
NodeID int64
Address string
}
type proxyCreatorFunc func(ctx context.Context, addr string) (types.Proxy, error)
type Session struct {
sync.Mutex
info *NodeInfo
client types.Proxy
clientCreator proxyCreatorFunc
isDisposed bool
}
func NewSession(info *NodeInfo, creator proxyCreatorFunc) *Session {
return &Session{
info: info,
clientCreator: creator,
}
}
func (n *Session) GetOrCreateClient(ctx context.Context) (types.Proxy, error) {
n.Lock()
defer n.Unlock()
if n.isDisposed {
return nil, errDisposed
}
if n.client != nil {
return n.client, nil
}
if n.clientCreator == nil {
return nil, fmt.Errorf("unable to create client for %s because of a nil client creator", n.info.Address)
}
err := n.initClient(ctx)
return n.client, err
}
func (n *Session) initClient(ctx context.Context) (err error) {
if n.client, err = n.clientCreator(ctx, n.info.Address); err != nil {
return
}
if err = n.client.Init(); err != nil {
return
}
return n.client.Start()
}
func (n *Session) Dispose() {
n.Lock()
defer n.Unlock()
if n.client != nil {
n.client.Stop()
n.client = nil
}
n.isDisposed = true
}

View File

@ -1,152 +0,0 @@
package querynode
import (
"context"
"fmt"
"sync"
"github.com/milvus-io/milvus/internal/util/funcutil"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/internalpb"
grpcproxyclient "github.com/milvus-io/milvus/internal/distributed/proxy/client"
"github.com/milvus-io/milvus/internal/types"
)
type SessionManager struct {
sessions struct {
sync.RWMutex
data map[int64]*Session
}
// sessions sync.Map // UniqueID -> Session
sessionCreator proxyCreatorFunc
}
// SessionOpt provides a way to set params in SessionManager
type SessionOpt func(c *SessionManager)
func withSessionCreator(creator proxyCreatorFunc) SessionOpt {
return func(c *SessionManager) { c.sessionCreator = creator }
}
func defaultSessionCreator() proxyCreatorFunc {
return func(ctx context.Context, addr string) (types.Proxy, error) {
return grpcproxyclient.NewClient(ctx, addr)
}
}
// NewSessionManager creates a new SessionManager
func NewSessionManager(options ...SessionOpt) *SessionManager {
m := &SessionManager{
sessions: struct {
sync.RWMutex
data map[int64]*Session
}{data: make(map[int64]*Session)},
sessionCreator: defaultSessionCreator(),
}
for _, opt := range options {
opt(m)
}
return m
}
// AddSession creates a new session
func (c *SessionManager) AddSession(node *NodeInfo) {
c.sessions.Lock()
defer c.sessions.Unlock()
log.Info("add proxy session", zap.Int64("node", node.NodeID))
session := NewSession(node, c.sessionCreator)
c.sessions.data[node.NodeID] = session
}
func (c *SessionManager) Startup(nodes []*NodeInfo) {
for _, node := range nodes {
c.AddSession(node)
}
}
// DeleteSession removes the node session
func (c *SessionManager) DeleteSession(node *NodeInfo) {
c.sessions.Lock()
defer c.sessions.Unlock()
log.Info("delete proxy session", zap.Int64("node", node.NodeID))
if session, ok := c.sessions.data[node.NodeID]; ok {
session.Dispose()
delete(c.sessions.data, node.NodeID)
}
}
// GetSessions gets all node sessions
func (c *SessionManager) GetSessions() []*Session {
c.sessions.RLock()
defer c.sessions.RUnlock()
ret := make([]*Session, 0, len(c.sessions.data))
for _, s := range c.sessions.data {
ret = append(ret, s)
}
return ret
}
func (c *SessionManager) SendSearchResult(ctx context.Context, nodeID UniqueID, result *internalpb.SearchResults) error {
cli, err := c.getClient(ctx, nodeID)
if err != nil {
log.Warn("failed to send search result, cannot get client", zap.Int64("nodeID", nodeID), zap.Error(err))
return err
}
resp, err := cli.SendSearchResult(ctx, result)
if err := funcutil.VerifyResponse(resp, err); err != nil {
log.Warn("failed to send search result", zap.Int64("node", nodeID), zap.Error(err))
return err
}
log.Debug("success to send search result", zap.Int64("node", nodeID), zap.Any("base", result.Base))
return nil
}
func (c *SessionManager) SendRetrieveResult(ctx context.Context, nodeID UniqueID, result *internalpb.RetrieveResults) error {
cli, err := c.getClient(ctx, nodeID)
if err != nil {
log.Warn("failed to send retrieve result, cannot get client", zap.Int64("nodeID", nodeID), zap.Error(err))
return err
}
resp, err := cli.SendRetrieveResult(ctx, result)
if err := funcutil.VerifyResponse(resp, err); err != nil {
log.Warn("failed to send retrieve result", zap.Int64("node", nodeID), zap.Error(err))
return err
}
log.Debug("success to send retrieve result", zap.Int64("node", nodeID), zap.Any("base", result.Base))
return nil
}
func (c *SessionManager) getClient(ctx context.Context, nodeID int64) (types.Proxy, error) {
c.sessions.RLock()
session, ok := c.sessions.data[nodeID]
c.sessions.RUnlock()
if !ok {
return nil, fmt.Errorf("can not find session of node %d", nodeID)
}
return session.GetOrCreateClient(ctx)
}
// Close release sessions
func (c *SessionManager) Close() {
c.sessions.Lock()
defer c.sessions.Unlock()
for _, s := range c.sessions.data {
s.Dispose()
}
c.sessions.data = nil
}

View File

@ -111,9 +111,8 @@ type QueryNode struct {
factory dependency.Factory
scheduler *taskScheduler
session *sessionutil.Session
eventCh <-chan *sessionutil.SessionEvent
sessionManager *SessionManager
session *sessionutil.Session
eventCh <-chan *sessionutil.SessionEvent
vectorStorage storage.ChunkManager
cacheStorage storage.ChunkManager
@ -203,74 +202,6 @@ func (node *QueryNode) InitSegcore() {
C.SegcoreSetIndexSliceSize(cIndexSliceSize)
}
func (node *QueryNode) initServiceDiscovery() error {
if node.session == nil {
return errors.New("session is nil")
}
sessions, rev, err := node.session.GetSessions(typeutil.ProxyRole)
if err != nil {
log.Warn("QueryNode failed to init service discovery", zap.Error(err))
return err
}
log.Info("QueryNode success to get Proxy sessions", zap.Any("sessions", sessions))
nodes := make([]*NodeInfo, 0, len(sessions))
for _, session := range sessions {
info := &NodeInfo{
NodeID: session.ServerID,
Address: session.Address,
}
nodes = append(nodes, info)
}
node.sessionManager.Startup(nodes)
node.eventCh = node.session.WatchServices(typeutil.ProxyRole, rev+1, nil)
return nil
}
func (node *QueryNode) watchService(ctx context.Context) {
defer node.wg.Done()
for {
select {
case <-ctx.Done():
log.Info("watch service shutdown")
return
case event, ok := <-node.eventCh:
if !ok {
// ErrCompacted is handled inside SessionWatcher
log.Error("Session Watcher channel closed", zap.Int64("server id", node.session.ServerID))
// need to call stop in separate goroutine
go node.Stop()
if node.session.TriggerKill {
if p, err := os.FindProcess(os.Getpid()); err == nil {
p.Signal(syscall.SIGINT)
}
}
return
}
node.handleSessionEvent(ctx, event)
}
}
}
func (node *QueryNode) handleSessionEvent(ctx context.Context, event *sessionutil.SessionEvent) {
info := &NodeInfo{
NodeID: event.Session.ServerID,
Address: event.Session.Address,
}
switch event.EventType {
case sessionutil.SessionAddEvent:
node.sessionManager.AddSession(info)
case sessionutil.SessionDelEvent:
node.sessionManager.DeleteSession(info)
default:
log.Warn("receive unknown service event type",
zap.Any("type", event.EventType))
}
}
// Init function init historical and streaming module to manage segments
func (node *QueryNode) Init() error {
var initError error = nil
@ -315,9 +246,6 @@ func (node *QueryNode) Init() error {
node.InitSegcore()
// TODO: add session creator to node
node.sessionManager = NewSessionManager(withSessionCreator(defaultSessionCreator()))
log.Info("query node init successfully",
zap.Any("queryNodeID", Params.QueryNodeCfg.GetNodeID()),
zap.Any("IP", Params.QueryNodeCfg.QueryNodeIP),
@ -337,14 +265,6 @@ func (node *QueryNode) Start() error {
go node.watchChangeInfo()
//go node.statsService.start()
// watch proxy
if err := node.initServiceDiscovery(); err != nil {
return err
}
node.wg.Add(1)
go node.watchService(node.queryNodeLoopCtx)
// create shardClusterService for shardLeader functions.
node.ShardClusterService = newShardClusterService(node.etcdCli, node.session, node)
// create shard-level query service

View File

@ -22,10 +22,8 @@ import (
"math/rand"
"net/url"
"os"
"os/signal"
"strconv"
"sync"
"syscall"
"testing"
"time"
@ -40,7 +38,6 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/sessionutil"
)
var embedetcdServer *embed.Etcd
@ -331,83 +328,6 @@ func TestQueryNode_watchChangeInfo(t *testing.T) {
wg.Wait()
}
func TestQueryNode_watchService(t *testing.T) {
t.Run("watch channel closed", func(t *testing.T) {
ech := make(chan *sessionutil.SessionEvent)
qn := &QueryNode{
session: &sessionutil.Session{
TriggerKill: true,
ServerID: 0,
},
wg: sync.WaitGroup{},
eventCh: ech,
queryNodeLoopCancel: func() {},
}
flag := false
closed := false
sigDone := make(chan struct{}, 1)
sigQuit := make(chan struct{}, 1)
sc := make(chan os.Signal, 1)
signal.Notify(sc, syscall.SIGINT)
defer signal.Reset(syscall.SIGINT)
qn.wg.Add(1)
go func() {
qn.watchService(context.Background())
flag = true
sigDone <- struct{}{}
}()
go func() {
<-sc
closed = true
sigQuit <- struct{}{}
}()
close(ech)
<-sigDone
<-sigQuit
assert.True(t, flag)
assert.True(t, closed)
})
t.Run("context done", func(t *testing.T) {
ech := make(chan *sessionutil.SessionEvent)
qn := &QueryNode{
session: &sessionutil.Session{
TriggerKill: true,
ServerID: 0,
},
wg: sync.WaitGroup{},
eventCh: ech,
}
flag := false
sigDone := make(chan struct{}, 1)
sc := make(chan os.Signal, 1)
signal.Notify(sc, syscall.SIGINT)
defer signal.Reset(syscall.SIGINT)
qn.wg.Add(1)
ctx, cancel := context.WithCancel(context.Background())
go func() {
qn.watchService(ctx)
flag = true
sigDone <- struct{}{}
}()
assert.False(t, flag)
cancel()
<-sigDone
assert.True(t, flag)
})
}
func TestQueryNode_validateChangeChannel(t *testing.T) {
type testCase struct {

View File

@ -39,6 +39,8 @@ import (
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/retry"
grpcStatus "google.golang.org/grpc/status"
)
// CheckGrpcReady wait for context timeout, or wait 100ms then send nil to targetCh
@ -352,3 +354,11 @@ func ReadBinary(endian binary.ByteOrder, bs []byte, receiver interface{}) error
buf := bytes.NewReader(bs)
return binary.Read(buf, endian, receiver)
}
func IsGrpcErr(err error) bool {
if err == nil {
return false
}
_, ok := grpcStatus.FromError(err)
return ok
}
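
IsGrpcErr backs the widened retry condition in searchTask.Execute above: a failed RPC on a cached client surfaces as a gRPC status error, which is now treated, like errInvalidShardLeaders, as a reason to refresh the shard-leader cache and search again, and grpcclient.ClientBase below uses the same check to decide whether to reset a broken connection. A standalone sketch of the classification, mirroring the helper (the wrapper name isGrpcStatusErr is illustrative):

package main

import (
	"context"
	"errors"
	"fmt"

	grpcCodes "google.golang.org/grpc/codes"
	grpcStatus "google.golang.org/grpc/status"
)

// isGrpcStatusErr reports whether err carries a gRPC status, as funcutil.IsGrpcErr does.
func isGrpcStatusErr(err error) bool {
	if err == nil {
		return false
	}
	_, ok := grpcStatus.FromError(err)
	return ok
}

func main() {
	fmt.Println(isGrpcStatusErr(nil))                                           // false
	fmt.Println(isGrpcStatusErr(errors.New("plain error")))                     // false
	fmt.Println(isGrpcStatusErr(context.Canceled))                              // false: not a gRPC status
	fmt.Println(isGrpcStatusErr(grpcStatus.Error(grpcCodes.Unavailable, "qn"))) // true: triggers cache refresh and retry
}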

View File

@ -32,6 +32,8 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/stretchr/testify/assert"
grpcCodes "google.golang.org/grpc/codes"
grpcStatus "google.golang.org/grpc/status"
)
type MockComponent struct {
@ -463,3 +465,28 @@ func Test_ReadBinary(t *testing.T) {
assert.NoError(t, ReadBinary(endian, bs, &fs))
assert.ElementsMatch(t, []float32{0, 0}, fs)
}
func TestIsGrpcErr(t *testing.T) {
var err1 error
assert.False(t, IsGrpcErr(err1))
err1 = errors.New("error")
assert.False(t, IsGrpcErr(err1))
bgCtx := context.Background()
ctx1, cancel1 := context.WithCancel(bgCtx)
cancel1()
assert.False(t, IsGrpcErr(ctx1.Err()))
timeout := 20 * time.Millisecond
ctx1, cancel1 = context.WithTimeout(bgCtx, timeout)
time.Sleep(timeout * 2)
assert.False(t, IsGrpcErr(ctx1.Err()))
cancel1()
err1 = grpcStatus.Error(grpcCodes.Canceled, "test")
assert.True(t, IsGrpcErr(err1))
err1 = grpcStatus.Error(grpcCodes.Unavailable, "test")
assert.True(t, IsGrpcErr(err1))
}

View File

@ -203,17 +203,14 @@ func (c *ClientBase) callOnce(ctx context.Context, caller func(client interface{
return ret, nil
}
// status.Error(codes.Canceled, context.Canceled.Error()
// if err2 == context.Canceled || err2 == context.DeadlineExceeded {
// return nil, err2
// }
if !funcutil.CheckCtxValid(ctx) {
return nil, err2
}
if !funcutil.IsGrpcErr(err2) {
log.Debug("ClientBase:isNotGrpcErr", zap.Error(err2))
return nil, err2
}
log.Debug(c.GetRole()+" ClientBase grpc error, start to reset connection", zap.Error(err2))
c.resetConnection(client)
return ret, err2
}