mirror of https://github.com/milvus-io/milvus.git
Add cache of grpc client of ShardLeader in proxy (#17301)
Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
parent 5698bf4236
commit 2763efc9b0
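For orientation, here is a minimal sketch (not part of the patch) of how the shard-client cache introduced by this change is meant to be used inside the proxy package, based only on the APIs visible in the diff below; the mock creator, channel name, and node ID are illustrative:

```go
// Illustrative sketch only: wiring a shardClientMgr the way the new tests do.
// The creator below is a stand-in for the default qnClient.NewClient dialer.
mgr := newShardClientMgr(withShardClientCreator(
	func(ctx context.Context, addr string) (types.QueryNode, error) {
		return &mock.QueryNodeClient{}, nil
	}))
defer mgr.Close()

// MetaCache keeps the manager's per-node ref-counts in sync with its shard-leader cache.
leaders := map[string][]nodeInfo{"channel-1": {{nodeID: 1, address: "querynode-1:21123"}}}
_ = mgr.UpdateShardLeaders(nil, leaders)

// Search/Query tasks then reuse the cached grpc client instead of dialing per request.
qn, err := mgr.GetClient(context.Background(), UniqueID(1))
_, _ = qn, err
```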
@@ -24,7 +24,8 @@ func TestValidAuth(t *testing.T) {
// normal metadata
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
mgr := newShardClientMgr()
err := InitMetaCache(rootCoord, queryCoord, mgr)
assert.Nil(t, err)
res = validAuth(ctx, []string{crypto.Base64Encode("mockUser:mockPass")})
assert.True(t, res)

@@ -52,7 +53,8 @@ func TestAuthenticationInterceptor(t *testing.T) {
// mock metacache
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err = InitMetaCache(rootCoord, queryCoord)
mgr := newShardClientMgr()
err = InitMetaCache(rootCoord, queryCoord, mgr)
assert.Nil(t, err)
// with invalid metadata
md := metadata.Pairs("xxx", "yyy")
@@ -2410,10 +2410,10 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
},
ReqID: Params.ProxyCfg.GetNodeID(),
},
request: request,
qc: node.queryCoord,
tr: timerecord.NewTimeRecorder("search"),
getQueryNodePolicy: defaultGetQueryNodePolicy,
request: request,
qc: node.queryCoord,
tr: timerecord.NewTimeRecorder("search"),
shardMgr: node.shardMgr,
}

travelTs := request.TravelTimestamp

@@ -2649,10 +2649,10 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
},
ReqID: Params.ProxyCfg.GetNodeID(),
},
request: request,
qc: node.queryCoord,
getQueryNodePolicy: defaultGetQueryNodePolicy,
queryShardPolicy: roundRobinPolicy,
request: request,
qc: node.queryCoord,
queryShardPolicy: roundRobinPolicy,
shardMgr: node.shardMgr,
}

method := "Query"

@@ -3062,8 +3062,8 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
qc: node.queryCoord,
ids: ids.IdArray,

getQueryNodePolicy: defaultGetQueryNodePolicy,
queryShardPolicy: roundRobinPolicy,
queryShardPolicy: roundRobinPolicy,
shardMgr: node.shardMgr,
}

items := []zapcore.Field{
@@ -55,7 +55,7 @@ type Cache interface {
GetPartitionInfo(ctx context.Context, collectionName string, partitionName string) (*partitionInfo, error)
// GetCollectionSchema get collection's schema.
GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error)
GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error)
GetShards(ctx context.Context, withCache bool, collectionName string) (map[string][]nodeInfo, error)
ClearShards(collectionName string)
RemoveCollection(ctx context.Context, collectionName string)
RemoveCollectionsByID(ctx context.Context, collectionID UniqueID)

@@ -73,7 +73,7 @@ type collectionInfo struct {
collID typeutil.UniqueID
schema *schemapb.CollectionSchema
partInfo map[string]*partitionInfo
shardLeaders map[string][]queryNode
shardLeaders map[string][]nodeInfo
leaderMutex sync.Mutex
createdTimestamp uint64
createdUtcTimestamp uint64

@@ -82,10 +82,10 @@ type collectionInfo struct {

// CloneShardLeaders returns a copy of shard leaders
// leaderMutex shall be accuired before invoking this method
func (c *collectionInfo) CloneShardLeaders() map[string][]queryNode {
m := make(map[string][]queryNode)
func (c *collectionInfo) CloneShardLeaders() map[string][]nodeInfo {
m := make(map[string][]nodeInfo)
for channel, leaders := range c.shardLeaders {
l := make([]queryNode, len(leaders))
l := make([]nodeInfo, len(leaders))
copy(l, leaders)
m[channel] = l
}

@@ -111,15 +111,16 @@ type MetaCache struct {
credUsernameList []string // no need initialize when NewMetaCache
mu sync.RWMutex
credMut sync.RWMutex
shardMgr *shardClientMgr
}

// globalMetaCache is singleton instance of Cache
var globalMetaCache Cache

// InitMetaCache initializes globalMetaCache
func InitMetaCache(rootCoord types.RootCoord, queryCoord types.QueryCoord) error {
func InitMetaCache(rootCoord types.RootCoord, queryCoord types.QueryCoord, shardMgr *shardClientMgr) error {
var err error
globalMetaCache, err = NewMetaCache(rootCoord, queryCoord)
globalMetaCache, err = NewMetaCache(rootCoord, queryCoord, shardMgr)
if err != nil {
return err
}

@@ -127,12 +128,13 @@ func InitMetaCache(rootCoord types.RootCoord, queryCoord types.QueryCoord) error
}

// NewMetaCache creates a MetaCache with provided RootCoord and QueryNode
func NewMetaCache(rootCoord types.RootCoord, queryCoord types.QueryCoord) (*MetaCache, error) {
func NewMetaCache(rootCoord types.RootCoord, queryCoord types.QueryCoord, shardMgr *shardClientMgr) (*MetaCache, error) {
return &MetaCache{
rootCoord: rootCoord,
queryCoord: queryCoord,
collInfo: map[string]*collectionInfo{},
credMap: map[string]*internalpb.CredentialInfo{},
shardMgr: shardMgr,
}, nil
}

@@ -583,7 +585,7 @@ func (m *MetaCache) GetCredUsernames(ctx context.Context) ([]string, error) {
}

// GetShards update cache if withCache == false
func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error) {
func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string) (map[string][]nodeInfo, error) {
info, err := m.GetCollectionInfo(ctx, collectionName)
if err != nil {
return nil, err

@@ -601,7 +603,6 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
log.Info("no shard cache for collection, try to get shard leaders from QueryCoord",
zap.String("collectionName", collectionName))
}

req := &querypb.GetShardLeadersRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_GetShardLeaders,

@@ -615,7 +616,7 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
childCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
err = retry.Do(childCtx, func() error {
resp, err = qc.GetShardLeaders(ctx, req)
resp, err = m.queryCoord.GetShardLeaders(ctx, req)
if err != nil {
return retry.Unrecoverable(err)
}

@@ -629,31 +630,38 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
return fmt.Errorf("fail to get shard leaders from QueryCoord: %s", resp.Status.Reason)
})
if err != nil {
return nil, fmt.Errorf("GetShardLeaders timeout, error: %s", err.Error())
return nil, err
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return nil, fmt.Errorf("fail to get shard leaders from QueryCoord: %s", resp.Status.Reason)
}

shards := parseShardLeaderList2QueryNode(resp.GetShards())

// manipulate info in map, get map returns a copy of the information
m.mu.RLock()
defer m.mu.RUnlock()
info = m.collInfo[collectionName]
// lock leader
info.leaderMutex.Lock()
defer info.leaderMutex.Unlock()
oldShards := info.shardLeaders
info.shardLeaders = shards
info.leaderMutex.Unlock()
m.mu.RUnlock()

return info.CloneShardLeaders(), nil
// update refcnt in shardClientMgr
ret := info.CloneShardLeaders()
_ = m.shardMgr.UpdateShardLeaders(oldShards, ret)
return ret, nil
}

func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]queryNode {
shard2QueryNodes := make(map[string][]queryNode)
func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]nodeInfo {
shard2QueryNodes := make(map[string][]nodeInfo)

for _, leaders := range shardsLeaders {
qns := make([]queryNode, len(leaders.GetNodeIds()))
qns := make([]nodeInfo, len(leaders.GetNodeIds()))

for j := range qns {
qns[j] = queryNode{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j]}
qns[j] = nodeInfo{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j]}
}

shard2QueryNodes[leaders.GetChannelName()] = qns

@@ -666,12 +674,14 @@ func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) m
func (m *MetaCache) ClearShards(collectionName string) {
log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName))
m.mu.Lock()
defer m.mu.Unlock()
_, ok := m.collInfo[collectionName]

if !ok {
return
//var ret map[string][]nodeInfo
info, ok := m.collInfo[collectionName]
if ok {
m.collInfo[collectionName].shardLeaders = nil
}
m.mu.Unlock()
// delete refcnt in shardClientMgr
if ok {
_ = m.shardMgr.UpdateShardLeaders(info.shardLeaders, nil)
}

m.collInfo[collectionName].shardLeaders = nil
}
@@ -198,7 +198,8 @@ func TestMetaCache_GetCollection(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
mgr := newShardClientMgr()
err := InitMetaCache(rootCoord, queryCoord, mgr)
assert.Nil(t, err)

id, err := globalMetaCache.GetCollectionID(ctx, "collection1")

@@ -245,7 +246,8 @@ func TestMetaCache_GetCollectionFailure(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
mgr := newShardClientMgr()
err := InitMetaCache(rootCoord, queryCoord, mgr)
assert.Nil(t, err)
rootCoord.Error = true

@@ -275,7 +277,8 @@ func TestMetaCache_GetNonExistCollection(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
mgr := newShardClientMgr()
err := InitMetaCache(rootCoord, queryCoord, mgr)
assert.Nil(t, err)

id, err := globalMetaCache.GetCollectionID(ctx, "collection3")

@@ -290,7 +293,8 @@ func TestMetaCache_GetPartitionID(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
mgr := newShardClientMgr()
err := InitMetaCache(rootCoord, queryCoord, mgr)
assert.Nil(t, err)

id, err := globalMetaCache.GetPartitionID(ctx, "collection1", "par1")

@@ -311,7 +315,8 @@ func TestMetaCache_GetPartitionError(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
mgr := newShardClientMgr()
err := InitMetaCache(rootCoord, queryCoord, mgr)
assert.Nil(t, err)

// Test the case where ShowPartitionsResponse is not aligned

@@ -340,35 +345,35 @@ func TestMetaCache_GetPartitionError(t *testing.T) {

func TestMetaCache_GetShards(t *testing.T) {
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
qc := NewQueryCoordMock()
shardMgr := newShardClientMgr()
err := InitMetaCache(rootCoord, qc, shardMgr)
require.Nil(t, err)

var (
ctx = context.TODO()
collectionName = "collection1"
qc = NewQueryCoordMock()
)
qc.Init()
qc.Start()
defer qc.Stop()

t.Run("No collection in meta cache", func(t *testing.T) {
shards, err := globalMetaCache.GetShards(ctx, true, "non-exists", qc)
shards, err := globalMetaCache.GetShards(ctx, true, "non-exists")
assert.Error(t, err)
assert.Empty(t, shards)
})

t.Run("without shardLeaders in collection info invalid shardLeaders", func(t *testing.T) {
qc.validShardLeaders = false
shards, err := globalMetaCache.GetShards(ctx, false, collectionName, qc)
shards, err := globalMetaCache.GetShards(ctx, false, collectionName)
assert.Error(t, err)
assert.Empty(t, shards)
})

t.Run("without shardLeaders in collection info", func(t *testing.T) {
qc.validShardLeaders = true
shards, err := globalMetaCache.GetShards(ctx, true, collectionName, qc)
shards, err := globalMetaCache.GetShards(ctx, true, collectionName)
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))

@@ -377,7 +382,7 @@ func TestMetaCache_GetShards(t *testing.T) {

// get from cache
qc.validShardLeaders = false
shards, err = globalMetaCache.GetShards(ctx, true, collectionName, qc)
shards, err = globalMetaCache.GetShards(ctx, true, collectionName)
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))

@@ -387,14 +392,14 @@ func TestMetaCache_GetShards(t *testing.T) {

func TestMetaCache_ClearShards(t *testing.T) {
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
qc := NewQueryCoordMock()
mgr := newShardClientMgr()
err := InitMetaCache(rootCoord, qc, mgr)
require.Nil(t, err)

var (
ctx = context.TODO()
collectionName = "collection1"
qc = NewQueryCoordMock()
)
qc.Init()
qc.Start()

@@ -411,7 +416,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
t.Run("Clear valid collection valid cache", func(t *testing.T) {

qc.validShardLeaders = true
shards, err := globalMetaCache.GetShards(ctx, true, collectionName, qc)
shards, err := globalMetaCache.GetShards(ctx, true, collectionName)
require.NoError(t, err)
require.NotEmpty(t, shards)
require.Equal(t, 1, len(shards))

@@ -420,7 +425,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
globalMetaCache.ClearShards(collectionName)

qc.validShardLeaders = false
shards, err = globalMetaCache.GetShards(ctx, true, collectionName, qc)
shards, err = globalMetaCache.GetShards(ctx, true, collectionName)
assert.Error(t, err)
assert.Empty(t, shards)
})

@@ -431,7 +436,8 @@ func TestMetaCache_LoadCache(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
mgr := newShardClientMgr()
err := InitMetaCache(rootCoord, queryCoord, mgr)
assert.Nil(t, err)

t.Run("test IsCollectionLoaded", func(t *testing.T) {

@@ -474,7 +480,8 @@ func TestMetaCache_RemoveCollection(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
err := InitMetaCache(rootCoord, queryCoord)
shardMgr := newShardClientMgr()
err := InitMetaCache(rootCoord, queryCoord, shardMgr)
assert.Nil(t, err)

info, err := globalMetaCache.GetCollectionInfo(ctx, "collection1")
@@ -89,7 +89,8 @@ type Proxy struct {

metricsCacheManager *metricsinfo.MetricsCacheManager

session *sessionutil.Session
session *sessionutil.Session
shardMgr *shardClientMgr

factory dependency.Factory

@@ -110,6 +111,7 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
cancel: cancel,
factory: factory,
searchResultCh: make(chan *internalpb.SearchResults, n),
shardMgr: newShardClientMgr(),
}
node.UpdateStateCode(internalpb.StateCode_Abnormal)
logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load()))

@@ -219,7 +221,7 @@ func (node *Proxy) Init() error {
log.Debug("create metrics cache manager done", zap.String("role", typeutil.ProxyRole))

log.Debug("init meta cache", zap.String("role", typeutil.ProxyRole))
if err := InitMetaCache(node.rootCoord, node.queryCoord); err != nil {
if err := InitMetaCache(node.rootCoord, node.queryCoord, node.shardMgr); err != nil {
log.Warn("failed to init meta cache", zap.Error(err), zap.String("role", typeutil.ProxyRole))
return err
}

@@ -378,6 +380,10 @@ func (node *Proxy) Stop() error {

node.session.Revoke(time.Second)

if node.shardMgr != nil {
node.shardMgr.Close()
}

// https://github.com/milvus-io/milvus/issues/12282
node.UpdateStateCode(internalpb.StateCode_Abnormal)
@@ -3036,7 +3036,8 @@ func TestProxy_Import(t *testing.T) {
qc := NewQueryCoordMock()
qc.Start()
defer qc.Stop()
err := InitMetaCache(rc, qc)
shardMgr := newShardClientMgr()
err := InitMetaCache(rc, qc, shardMgr)
assert.NoError(t, err)
rc.CreateCollection(context.TODO(), &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{
@@ -0,0 +1,206 @@
package proxy

import (
"context"
"errors"
"fmt"
"sync"

qnClient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
"github.com/milvus-io/milvus/internal/types"
)

type queryNodeCreatorFunc func(ctx context.Context, addr string) (types.QueryNode, error)

type nodeInfo struct {
nodeID UniqueID
address string
}

func (n nodeInfo) String() string {
return fmt.Sprintf("<NodeID: %d>", n.nodeID)
}

var errClosed = errors.New("client is closed")

type shardClient struct {
sync.RWMutex
info nodeInfo
client types.QueryNode
isClosed bool
refCnt int
}

func (n *shardClient) getClient(ctx context.Context) (types.QueryNode, error) {
n.RLock()
defer n.RUnlock()
if n.isClosed {
return nil, errClosed
}
return n.client, nil
}

func (n *shardClient) inc() {
n.Lock()
defer n.Unlock()
if n.isClosed {
return
}
n.refCnt++
}

func (n *shardClient) close() {
n.isClosed = true
n.refCnt = 0
if n.client != nil {
n.client.Stop()
n.client = nil
}
}

func (n *shardClient) dec() bool {
n.Lock()
defer n.Unlock()
if n.isClosed {
return true
}
if n.refCnt > 0 {
n.refCnt--
}
if n.refCnt == 0 {
n.close()
}
return n.refCnt == 0
}

func (n *shardClient) Close() {
n.Lock()
defer n.Unlock()
n.close()
}

func newShardClient(info *nodeInfo, client types.QueryNode) *shardClient {
ret := &shardClient{
info: nodeInfo{
nodeID: info.nodeID,
address: info.address,
},
client: client,
refCnt: 1,
}
return ret
}

type shardClientMgr struct {
clients struct {
sync.RWMutex
data map[UniqueID]*shardClient
}
clientCreator queryNodeCreatorFunc
}

// SessionOpt provides a way to set params in SessionManager
type shardClientMgrOpt func(s *shardClientMgr)

func withShardClientCreator(creator queryNodeCreatorFunc) shardClientMgrOpt {
return func(s *shardClientMgr) { s.clientCreator = creator }
}

func defaultShardClientCreator(ctx context.Context, addr string) (types.QueryNode, error) {
return qnClient.NewClient(ctx, addr)
}

// NewShardClientMgr creates a new shardClientMgr
func newShardClientMgr(options ...shardClientMgrOpt) *shardClientMgr {
s := &shardClientMgr{
clients: struct {
sync.RWMutex
data map[UniqueID]*shardClient
}{data: make(map[UniqueID]*shardClient)},
clientCreator: defaultShardClientCreator,
}
for _, opt := range options {
opt(s)
}
return s
}

// Warning this method may modify parameter `oldLeaders`
func (c *shardClientMgr) UpdateShardLeaders(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo) error {
oldLocalMap := make(map[UniqueID]*nodeInfo)
for _, nodes := range oldLeaders {
for i := range nodes {
n := &nodes[i]
_, ok := oldLocalMap[n.nodeID]
if !ok {
oldLocalMap[n.nodeID] = n
}
}
}
newLocalMap := make(map[UniqueID]*nodeInfo)

for _, nodes := range newLeaders {
for i := range nodes {
n := &nodes[i]
_, ok := oldLocalMap[n.nodeID]
if !ok {
_, ok2 := newLocalMap[n.nodeID]
if !ok2 {
newLocalMap[n.nodeID] = n
}
}
delete(oldLocalMap, n.nodeID)
}
}
c.clients.Lock()
defer c.clients.Unlock()

for _, node := range newLocalMap {
client, ok := c.clients.data[node.nodeID]
if ok {
client.inc()
} else {
// context.Background() is useless
// TODO QueryNode NewClient remove ctx parameter
// TODO Remove Init && Start interface in QueryNode client
if c.clientCreator == nil {
return fmt.Errorf("clientCreator function is nil")
}
shardClient, err := c.clientCreator(context.Background(), node.address)
if err != nil {
return err
}
client := newShardClient(node, shardClient)
c.clients.data[node.nodeID] = client
}
}
for _, node := range oldLocalMap {
client, ok := c.clients.data[node.nodeID]
if ok && client.dec() {
delete(c.clients.data, node.nodeID)
}
}
return nil
}

func (c *shardClientMgr) GetClient(ctx context.Context, nodeID UniqueID) (types.QueryNode, error) {
c.clients.RLock()
client, ok := c.clients.data[nodeID]
c.clients.RUnlock()

if !ok {
return nil, fmt.Errorf("can not find client of node %d", nodeID)
}
return client.getClient(ctx)
}

// Close release clients
func (c *shardClientMgr) Close() {
c.clients.Lock()
defer c.clients.Unlock()

for _, s := range c.clients.data {
s.Close()
}
c.clients.data = make(map[UniqueID]*shardClient)
}
@@ -0,0 +1,100 @@
package proxy

import (
"context"
"testing"

"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/mock"
"github.com/stretchr/testify/assert"
)

func genShardLeaderInfo(channel string, leaderIDs []UniqueID) map[string][]nodeInfo {
leaders := make(map[string][]nodeInfo)
nodeInfos := make([]nodeInfo, len(leaderIDs))
for i, id := range leaderIDs {
nodeInfos[i] = nodeInfo{
nodeID: id,
address: "fake",
}
}
leaders[channel] = nodeInfos
return leaders
}

func TestShardClientMgr_UpdateShardLeaders_CreatorNil(t *testing.T) {
mgr := newShardClientMgr(withShardClientCreator(nil))
mgr.clientCreator = nil
leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3})
err := mgr.UpdateShardLeaders(nil, leaders)
assert.Error(t, err)
}

func TestShardClientMgr_UpdateShardLeaders_Empty(t *testing.T) {
mockCreator := func(ctx context.Context, addr string) (types.QueryNode, error) {
return &mock.QueryNodeClient{}, nil
}
mgr := newShardClientMgr(withShardClientCreator(mockCreator))

_, err := mgr.GetClient(context.Background(), UniqueID(1))
assert.Error(t, err)

err = mgr.UpdateShardLeaders(nil, nil)
assert.NoError(t, err)
_, err = mgr.GetClient(context.Background(), UniqueID(1))
assert.Error(t, err)

leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3})
err = mgr.UpdateShardLeaders(leaders, nil)
assert.NoError(t, err)
}

func TestShardClientMgr_UpdateShardLeaders_NonEmpty(t *testing.T) {
mgr := newShardClientMgr()
leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3})
err := mgr.UpdateShardLeaders(nil, leaders)
assert.NoError(t, err)

_, err = mgr.GetClient(context.Background(), UniqueID(1))
assert.NoError(t, err)

newLeaders := genShardLeaderInfo("c1", []UniqueID{2, 3})
err = mgr.UpdateShardLeaders(leaders, newLeaders)
assert.NoError(t, err)

_, err = mgr.GetClient(context.Background(), UniqueID(1))
assert.Error(t, err)
}

func TestShardClientMgr_UpdateShardLeaders_Ref(t *testing.T) {
mgr := newShardClientMgr()
leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3})

for i := 0; i < 2; i++ {
err := mgr.UpdateShardLeaders(nil, leaders)
assert.NoError(t, err)
}

partLeaders := genShardLeaderInfo("c1", []UniqueID{1})

_, err := mgr.GetClient(context.Background(), UniqueID(1))
assert.NoError(t, err)

err = mgr.UpdateShardLeaders(partLeaders, nil)
assert.NoError(t, err)

_, err = mgr.GetClient(context.Background(), UniqueID(1))
assert.NoError(t, err)

err = mgr.UpdateShardLeaders(partLeaders, nil)
assert.NoError(t, err)

_, err = mgr.GetClient(context.Background(), UniqueID(1))
assert.Error(t, err)

_, err = mgr.GetClient(context.Background(), UniqueID(2))
assert.NoError(t, err)

_, err = mgr.GetClient(context.Background(), UniqueID(3))
assert.NoError(t, err)
}
@@ -70,8 +70,9 @@ func TestGetIndexStateTask_Execute(t *testing.T) {
collectionID: collectionID,
}

shardMgr := newShardClientMgr()
// failed to get collection id.
_ = InitMetaCache(rootCoord, queryCoord)
_ = InitMetaCache(rootCoord, queryCoord, shardMgr)
assert.Error(t, gist.Execute(ctx))
rootCoord.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
return &milvuspb.DescribeCollectionResponse{
@@ -5,51 +5,20 @@ import (
"errors"
"fmt"

qnClient "github.com/milvus-io/milvus/internal/distributed/querynode/client"

"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/types"

"go.uber.org/zap"
)

type getQueryNodePolicy func(context.Context, string) (types.QueryNode, error)

type pickShardPolicy func(ctx context.Context, policy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders []queryNode) error

// TODO add another policy to enbale the use of cache
// defaultGetQueryNodePolicy creates QueryNode client for every address everytime
func defaultGetQueryNodePolicy(ctx context.Context, address string) (types.QueryNode, error) {
qn, err := qnClient.NewClient(ctx, address)
if err != nil {
return nil, err
}

if err := qn.Init(); err != nil {
return nil, err
}

if err := qn.Start(); err != nil {
return nil, err
}
return qn, nil
}
type pickShardPolicy func(ctx context.Context, mgr *shardClientMgr, query func(UniqueID, types.QueryNode) error, leaders []nodeInfo) error

var (
errBegin = errors.New("begin error")
errInvalidShardLeaders = errors.New("Invalid shard leader")
)

type queryNode struct {
nodeID UniqueID
address string
}

func (q queryNode) String() string {
return fmt.Sprintf("<NodeID: %d>", q.nodeID)
}

func updateShardsWithRoundRobin(shardsLeaders map[string][]queryNode) {
func updateShardsWithRoundRobin(shardsLeaders map[string][]nodeInfo) {
for channelID, leaders := range shardsLeaders {
if len(leaders) <= 1 {
continue

@@ -59,7 +28,7 @@ func updateShardsWithRoundRobin(shardsLeaders map[string][]queryNode) {
}
}

func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders []queryNode) error {
func roundRobinPolicy(ctx context.Context, mgr *shardClientMgr, query func(UniqueID, types.QueryNode) error, leaders []nodeInfo) error {
var (
err = errBegin
current = 0

@@ -75,7 +44,7 @@ func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy
zap.Int64("nodeID", currentID))
}

qn, err = getQueryNodePolicy(ctx, leaders[current].address)
qn, err = mgr.GetClient(ctx, leaders[current].nodeID)
if err != nil {
log.Warn("fail to get valid QueryNode", zap.Int64("nodeID", currentID),
zap.Error(err))

@@ -83,7 +52,6 @@ func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy
continue
}

defer qn.Stop()
err = query(currentID, qn)
if err != nil {
log.Warn("fail to Query with shard leader",
@@ -11,11 +11,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/milvus-io/milvus/internal/util/mock"

"go.uber.org/zap"
)

func TestUpdateShardsWithRoundRobin(t *testing.T) {
list := map[string][]queryNode{
list := map[string][]nodeInfo{
"channel-1": {
{1, "addr1"},
{2, "addr2"},

@@ -34,7 +36,7 @@ func TestUpdateShardsWithRoundRobin(t *testing.T) {
assert.Equal(t, "addr21", list["channel-2"][0].address)

t.Run("check print", func(t *testing.T) {
qns := []queryNode{
qns := []nodeInfo{
{1, "addr1"},
{2, "addr2"},
{20, "addr20"},

@@ -54,10 +56,16 @@ func TestUpdateShardsWithRoundRobin(t *testing.T) {

func TestRoundRobinPolicy(t *testing.T) {
var (
getQueryNodePolicy = mockGetQueryNodePolicy
ctx = context.TODO()
ctx = context.TODO()
)

mockCreator := func(ctx context.Context, addr string) (types.QueryNode, error) {
return &mock.QueryNodeClient{}, nil
}

mgr := newShardClientMgr(withShardClientCreator(mockCreator))
dummyLeaders := genShardLeaderInfo("c1", []UniqueID{-1, 1, 2, 3})
mgr.UpdateShardLeaders(nil, dummyLeaders)
t.Run("All fails", func(t *testing.T) {
allFailTests := []struct {
leaderIDs []UniqueID

@@ -73,13 +81,13 @@ func TestRoundRobinPolicy(t *testing.T) {
t.Run(test.description, func(t *testing.T) {
query := (&mockQuery{isvalid: false}).query

leaders := make([]queryNode, 0, len(test.leaderIDs))
leaders := make([]nodeInfo, 0, len(test.leaderIDs))
for _, ID := range test.leaderIDs {
leaders = append(leaders, queryNode{ID, "random-addr"})
leaders = append(leaders, nodeInfo{ID, "random-addr"})
}

err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
err := roundRobinPolicy(ctx, mgr, query, leaders)
require.Error(t, err)
})
}

@@ -98,12 +106,12 @@ func TestRoundRobinPolicy(t *testing.T) {

for _, test := range allPassTests {
query := (&mockQuery{isvalid: true}).query
leaders := make([]queryNode, 0, len(test.leaderIDs))
leaders := make([]nodeInfo, 0, len(test.leaderIDs))
for _, ID := range test.leaderIDs {
leaders = append(leaders, queryNode{ID, "random-addr"})
leaders = append(leaders, nodeInfo{ID, "random-addr"})
}
err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
err := roundRobinPolicy(ctx, mgr, query, leaders)
require.NoError(t, err)
}
})

@@ -120,18 +128,18 @@ func TestRoundRobinPolicy(t *testing.T) {

for _, test := range passAtLast {
query := (&mockQuery{isvalid: true}).query
leaders := make([]queryNode, 0, len(test.leaderIDs))
leaders := make([]nodeInfo, 0, len(test.leaderIDs))
for _, ID := range test.leaderIDs {
leaders = append(leaders, queryNode{ID, "random-addr"})
leaders = append(leaders, nodeInfo{ID, "random-addr"})
}
err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
err := roundRobinPolicy(ctx, mgr, query, leaders)
require.NoError(t, err)
}
})
}

func mockGetQueryNodePolicy(ctx context.Context, address string) (types.QueryNode, error) {
func mockQueryNodeCreator(ctx context.Context, address string) (types.QueryNode, error) {
return &QueryNodeMock{address: address}, nil
}
@@ -50,15 +50,11 @@ type queryTask struct {
runningGroup *errgroup.Group
runningGroupCtx context.Context

getQueryNodePolicy getQueryNodePolicy
queryShardPolicy pickShardPolicy
queryShardPolicy pickShardPolicy
shardMgr *shardClientMgr
}

func (t *queryTask) PreExecute(ctx context.Context) error {
if t.getQueryNodePolicy == nil {
t.getQueryNodePolicy = defaultGetQueryNodePolicy
}

if t.queryShardPolicy == nil {
t.queryShardPolicy = roundRobinPolicy
}

@@ -240,11 +236,10 @@ func (t *queryTask) Execute(ctx context.Context) error {
defer tr.Elapse("done")

executeQuery := func(withCache bool) error {
shards, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName, t.qc)
shards, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName)
if err != nil {
return err
}

t.resultBuf = make(chan *internalpb.RetrieveResults, len(shards))
t.toReduceResults = make([]*internalpb.RetrieveResults, 0, len(shards))
t.runningGroup, t.runningGroupCtx = errgroup.WithContext(ctx)

@@ -353,7 +348,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
return nil
}

func (t *queryTask) queryShard(ctx context.Context, leaders []queryNode, channelID string) error {
func (t *queryTask) queryShard(ctx context.Context, leaders []nodeInfo, channelID string) error {
query := func(nodeID UniqueID, qn types.QueryNode) error {
req := &querypb.QueryRequest{
Req: t.RetrieveRequest,

@@ -378,7 +373,7 @@ func (t *queryTask) queryShard(ctx context.Context, leaders []queryNode, channel
return nil
}

err := t.queryShardPolicy(t.TraceCtx(), t.getQueryNodePolicy, query, leaders)
err := t.queryShardPolicy(t.TraceCtx(), t.shardMgr, query, leaders)
if err != nil {
log.Warn("fail to Query to all shard leaders", zap.Int64("taskID", t.ID()), zap.Any("shard leaders", leaders))
return err
@@ -11,13 +11,12 @@ import (
"github.com/stretchr/testify/require"

"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/types"

"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/types"

"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/typeutil"

@@ -41,16 +40,18 @@ func TestQueryTask_all(t *testing.T) {
hitNum = 10
)

mockGetQueryNodePolicy := func(ctx context.Context, address string) (types.QueryNode, error) {
mockCreator := func(ctx context.Context, address string) (types.QueryNode, error) {
return qn, nil
}

mgr := newShardClientMgr(withShardClientCreator(mockCreator))

rc.Start()
defer rc.Stop()
qc.Start()
defer qc.Stop()

err = InitMetaCache(rc, qc)
err = InitMetaCache(rc, qc, mgr)
assert.NoError(t, err)

fieldName2Types := map[string]schemapb.DataType{

@@ -125,8 +126,8 @@ func TestQueryTask_all(t *testing.T) {
},
qc: qc,

getQueryNodePolicy: mockGetQueryNodePolicy,
queryShardPolicy: roundRobinPolicy,
queryShardPolicy: roundRobinPolicy,
shardMgr: mgr,
}
for i := 0; i < len(fieldName2Types); i++ {
task.RetrieveRequest.OutputFieldsId[i] = int64(common.StartOfUserFieldID + i)

@@ -181,7 +182,8 @@ func TestCheckIfLoaded(t *testing.T) {
err = rc.Start()
defer rc.Stop()
require.NoError(t, err)
err = InitMetaCache(rc, qc)
mgr := newShardClientMgr()
err = InitMetaCache(rc, qc, mgr)
require.NoError(t, err)

err = qc.Start()
@@ -50,18 +50,14 @@ type searchTask struct {
runningGroup *errgroup.Group
runningGroupCtx context.Context

getQueryNodePolicy getQueryNodePolicy
searchShardPolicy pickShardPolicy
searchShardPolicy pickShardPolicy
shardMgr *shardClientMgr
}

func (t *searchTask) PreExecute(ctx context.Context) error {
sp, ctx := trace.StartSpanFromContextWithOperationName(t.TraceCtx(), "Proxy-Search-PreExecute")
defer sp.Finish()

if t.getQueryNodePolicy == nil {
t.getQueryNodePolicy = defaultGetQueryNodePolicy
}

if t.searchShardPolicy == nil {
t.searchShardPolicy = roundRobinPolicy
}

@@ -278,11 +274,10 @@ func (t *searchTask) Execute(ctx context.Context) error {
defer tr.Elapse("done")

executeSearch := func(withCache bool) error {
shard2Leaders, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName, t.qc)
shard2Leaders, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName)
if err != nil {
return err
}

t.resultBuf = make(chan *internalpb.SearchResults, len(shard2Leaders))
t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shard2Leaders))
t.runningGroup, t.runningGroupCtx = errgroup.WithContext(ctx)

@@ -299,28 +294,23 @@ func (t *searchTask) Execute(ctx context.Context) error {
zap.String("shard channel", channelID),
zap.Uint64("timeoutTs", t.TimeoutTimestamp))

err := t.searchShard(t.runningGroupCtx, leaders, channelID)
if err != nil {
return err
}
return nil
return t.searchShard(t.runningGroupCtx, leaders, channelID)
})
}

err = t.runningGroup.Wait()
return err
}

err := executeSearch(WithCache)
if err == errInvalidShardLeaders {
log.Warn("invalid shard leaders from cache, updating shardleader caches and retry search")
if err == errInvalidShardLeaders || funcutil.IsGrpcErr(err) {
log.Warn("first search failed, updating shardleader caches and retry search", zap.Error(err))
return executeSearch(WithoutCache)
}
if err != nil {
return fmt.Errorf("fail to search on all shard leaders, err=%s", err.Error())
return fmt.Errorf("fail to search on all shard leaders, err=%w", err)
}

log.Info("Search Execute done.",
log.Debug("Search Execute done.",
zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "search"))
return nil
}

@@ -415,7 +405,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
return nil
}

func (t *searchTask) searchShard(ctx context.Context, leaders []queryNode, channelID string) error {
func (t *searchTask) searchShard(ctx context.Context, leaders []nodeInfo, channelID string) error {

search := func(nodeID UniqueID, qn types.QueryNode) error {
req := &querypb.SearchRequest{

@@ -423,9 +413,11 @@ func (t *searchTask) searchShard(ctx context.Context, leaders []queryNode, chann
DmlChannel: channelID,
Scope: querypb.DataScope_All,
}

result, err := qn.Search(ctx, req)
if err != nil || result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
if err != nil {
return err
}
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Warn("QueryNode search returns error", zap.Int64("nodeID", nodeID),
zap.Error(err))
return errInvalidShardLeaders

@@ -435,12 +427,11 @@ func (t *searchTask) searchShard(ctx context.Context, leaders []queryNode, chann
zap.String("reason", result.GetStatus().GetReason()))
return fmt.Errorf("fail to Search, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason())
}

t.resultBuf <- result
return nil
}

err := t.searchShardPolicy(t.TraceCtx(), t.getQueryNodePolicy, search, leaders)
err := t.searchShardPolicy(t.TraceCtx(), t.shardMgr, search, leaders)
if err != nil {
log.Warn("fail to search to all shard leaders", zap.Any("shard leaders", leaders))
return err
@@ -124,7 +124,8 @@ func TestSearchTask_PreExecute(t *testing.T) {
err = rc.Start()
defer rc.Stop()
require.NoError(t, err)
err = InitMetaCache(rc, qc)
mgr := newShardClientMgr()
err = InitMetaCache(rc, qc, mgr)
require.NoError(t, err)

err = qc.Start()

@@ -423,7 +424,8 @@ func TestSearchTaskV2_Execute(t *testing.T) {
err = rc.Start()
require.NoError(t, err)
defer rc.Stop()
err = InitMetaCache(rc, qc)
mgr := newShardClientMgr()
err = InitMetaCache(rc, qc, mgr)
require.NoError(t, err)

err = qc.Start()
@@ -1094,7 +1094,8 @@ func TestDropCollectionTask(t *testing.T) {
qc.Start()
defer qc.Stop()
ctx := context.Background()
InitMetaCache(rc, qc)
mgr := newShardClientMgr()
InitMetaCache(rc, qc, mgr)

master := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()

@@ -1181,7 +1182,8 @@ func TestHasCollectionTask(t *testing.T) {
qc.Start()
defer qc.Stop()
ctx := context.Background()
InitMetaCache(rc, qc)
mgr := newShardClientMgr()
InitMetaCache(rc, qc, mgr)
prefix := "TestHasCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()

@@ -1266,7 +1268,8 @@ func TestDescribeCollectionTask(t *testing.T) {
qc.Start()
defer qc.Stop()
ctx := context.Background()
InitMetaCache(rc, qc)
mgr := newShardClientMgr()
InitMetaCache(rc, qc, mgr)
prefix := "TestDescribeCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()

@@ -1328,7 +1331,8 @@ func TestDescribeCollectionTask_ShardsNum1(t *testing.T) {
qc.Start()
defer qc.Stop()
ctx := context.Background()
InitMetaCache(rc, qc)
mgr := newShardClientMgr()
InitMetaCache(rc, qc, mgr)
prefix := "TestDescribeCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()

@@ -1392,7 +1396,8 @@ func TestDescribeCollectionTask_ShardsNum2(t *testing.T) {
qc.Start()
defer qc.Stop()
ctx := context.Background()
InitMetaCache(rc, qc)
mgr := newShardClientMgr()
InitMetaCache(rc, qc, mgr)
prefix := "TestDescribeCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()

@@ -1658,7 +1663,8 @@ func TestTask_Int64PrimaryKey(t *testing.T) {

ctx := context.Background()

err = InitMetaCache(rc, qc)
mgr := newShardClientMgr()
err = InitMetaCache(rc, qc, mgr)
assert.NoError(t, err)

shardsNum := int32(2)

@@ -1911,7 +1917,8 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {

ctx := context.Background()

err = InitMetaCache(rc, qc)
mgr := newShardClientMgr()
err = InitMetaCache(rc, qc, mgr)
assert.NoError(t, err)

shardsNum := int32(2)
@@ -1,82 +0,0 @@
package querynode

import (
"context"

"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/types"
)

type mockProxy struct {
}

func (m *mockProxy) Init() error {
return nil
}

func (m *mockProxy) Start() error {
return nil
}

func (m *mockProxy) Stop() error {
return nil
}

func (m *mockProxy) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
return nil, nil
}

func (m *mockProxy) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return nil, nil
}

func (m *mockProxy) Register() error {
return nil
}

func (m *mockProxy) InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) {
return nil, nil
}

func (m *mockProxy) ReleaseDQLMessageStream(ctx context.Context, in *proxypb.ReleaseDQLMessageStreamRequest) (*commonpb.Status, error) {
return nil, nil
}

func (m *mockProxy) SendSearchResult(ctx context.Context, req *internalpb.SearchResults) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, nil
}

func (m *mockProxy) SendRetrieveResult(ctx context.Context, req *internalpb.RetrieveResults) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, nil
}

func (m *mockProxy) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) {
return nil, nil
}

func (m *mockProxy) UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) (*commonpb.Status, error) {
return nil, nil
}

func (m *mockProxy) ClearCredUsersCache(ctx context.Context, request *internalpb.ClearCredUsersCacheRequest) (*commonpb.Status, error) {
return nil, nil
}

func newMockProxy() types.Proxy {
return &mockProxy{}
}

func mockProxyCreator() proxyCreatorFunc {
return func(ctx context.Context, addr string) (types.Proxy, error) {
return newMockProxy(), nil
}
}
@@ -1,75 +0,0 @@
package querynode

import (
"context"
"errors"
"fmt"
"sync"

"github.com/milvus-io/milvus/internal/types"
)

var errDisposed = errors.New("client is disposed")

type NodeInfo struct {
NodeID int64
Address string
}

type proxyCreatorFunc func(ctx context.Context, addr string) (types.Proxy, error)

type Session struct {
sync.Mutex
info *NodeInfo
client types.Proxy
clientCreator proxyCreatorFunc
isDisposed bool
}

func NewSession(info *NodeInfo, creator proxyCreatorFunc) *Session {
return &Session{
info: info,
clientCreator: creator,
}
}

func (n *Session) GetOrCreateClient(ctx context.Context) (types.Proxy, error) {
n.Lock()
defer n.Unlock()

if n.isDisposed {
return nil, errDisposed
}

if n.client != nil {
return n.client, nil
}

if n.clientCreator == nil {
return nil, fmt.Errorf("unable to create client for %s because of a nil client creator", n.info.Address)
}

err := n.initClient(ctx)
return n.client, err
}

func (n *Session) initClient(ctx context.Context) (err error) {
if n.client, err = n.clientCreator(ctx, n.info.Address); err != nil {
return
}
if err = n.client.Init(); err != nil {
return
}
return n.client.Start()
}

func (n *Session) Dispose() {
n.Lock()
defer n.Unlock()

if n.client != nil {
n.client.Stop()
n.client = nil
}
n.isDisposed = true
}
@@ -1,152 +0,0 @@
package querynode

import (
"context"
"fmt"
"sync"

"github.com/milvus-io/milvus/internal/util/funcutil"

"go.uber.org/zap"

"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/internalpb"

grpcproxyclient "github.com/milvus-io/milvus/internal/distributed/proxy/client"

"github.com/milvus-io/milvus/internal/types"
)

type SessionManager struct {
sessions struct {
sync.RWMutex
data map[int64]*Session
}
// sessions sync.Map // UniqueID -> Session
sessionCreator proxyCreatorFunc
}

// SessionOpt provides a way to set params in SessionManager
type SessionOpt func(c *SessionManager)

func withSessionCreator(creator proxyCreatorFunc) SessionOpt {
return func(c *SessionManager) { c.sessionCreator = creator }
}

func defaultSessionCreator() proxyCreatorFunc {
return func(ctx context.Context, addr string) (types.Proxy, error) {
return grpcproxyclient.NewClient(ctx, addr)
}
}

// NewSessionManager creates a new SessionManager
func NewSessionManager(options ...SessionOpt) *SessionManager {
m := &SessionManager{
sessions: struct {
sync.RWMutex
data map[int64]*Session
}{data: make(map[int64]*Session)},
sessionCreator: defaultSessionCreator(),
}
for _, opt := range options {
opt(m)
}
return m
}

// AddSession creates a new session
func (c *SessionManager) AddSession(node *NodeInfo) {
c.sessions.Lock()
defer c.sessions.Unlock()
log.Info("add proxy session", zap.Int64("node", node.NodeID))
session := NewSession(node, c.sessionCreator)
c.sessions.data[node.NodeID] = session
}

func (c *SessionManager) Startup(nodes []*NodeInfo) {
for _, node := range nodes {
c.AddSession(node)
}
}

// DeleteSession removes the node session
func (c *SessionManager) DeleteSession(node *NodeInfo) {
c.sessions.Lock()
defer c.sessions.Unlock()
log.Info("delete proxy session", zap.Int64("node", node.NodeID))
if session, ok := c.sessions.data[node.NodeID]; ok {
session.Dispose()
delete(c.sessions.data, node.NodeID)
}
}

// GetSessions gets all node sessions
func (c *SessionManager) GetSessions() []*Session {
c.sessions.RLock()
defer c.sessions.RUnlock()

ret := make([]*Session, 0, len(c.sessions.data))
for _, s := range c.sessions.data {
ret = append(ret, s)
}
return ret
}

func (c *SessionManager) SendSearchResult(ctx context.Context, nodeID UniqueID, result *internalpb.SearchResults) error {
cli, err := c.getClient(ctx, nodeID)
if err != nil {
log.Warn("failed to send search result, cannot get client", zap.Int64("nodeID", nodeID), zap.Error(err))
return err
}

resp, err := cli.SendSearchResult(ctx, result)
if err := funcutil.VerifyResponse(resp, err); err != nil {
log.Warn("failed to send search result", zap.Int64("node", nodeID), zap.Error(err))
return err
}

log.Debug("success to send search result", zap.Int64("node", nodeID), zap.Any("base", result.Base))

return nil
}

func (c *SessionManager) SendRetrieveResult(ctx context.Context, nodeID UniqueID, result *internalpb.RetrieveResults) error {
cli, err := c.getClient(ctx, nodeID)
if err != nil {
log.Warn("failed to send retrieve result, cannot get client", zap.Int64("nodeID", nodeID), zap.Error(err))
return err
}

resp, err := cli.SendRetrieveResult(ctx, result)
if err := funcutil.VerifyResponse(resp, err); err != nil {
log.Warn("failed to send retrieve result", zap.Int64("node", nodeID), zap.Error(err))
return err
}

log.Debug("success to send retrieve result", zap.Int64("node", nodeID), zap.Any("base", result.Base))

return nil
}

func (c *SessionManager) getClient(ctx context.Context, nodeID int64) (types.Proxy, error) {
c.sessions.RLock()
session, ok := c.sessions.data[nodeID]
c.sessions.RUnlock()

if !ok {
return nil, fmt.Errorf("can not find session of node %d", nodeID)
}

return session.GetOrCreateClient(ctx)
}

// Close release sessions
func (c *SessionManager) Close() {
c.sessions.Lock()
defer c.sessions.Unlock()

for _, s := range c.sessions.data {
s.Dispose()
}
c.sessions.data = nil
}
@@ -111,9 +111,8 @@ type QueryNode struct {
    factory   dependency.Factory
    scheduler *taskScheduler

    session        *sessionutil.Session
    eventCh        <-chan *sessionutil.SessionEvent
    sessionManager *SessionManager
    session *sessionutil.Session
    eventCh <-chan *sessionutil.SessionEvent

    vectorStorage storage.ChunkManager
    cacheStorage  storage.ChunkManager

@@ -203,74 +202,6 @@ func (node *QueryNode) InitSegcore() {
    C.SegcoreSetIndexSliceSize(cIndexSliceSize)
}

func (node *QueryNode) initServiceDiscovery() error {
    if node.session == nil {
        return errors.New("session is nil")
    }

    sessions, rev, err := node.session.GetSessions(typeutil.ProxyRole)
    if err != nil {
        log.Warn("QueryNode failed to init service discovery", zap.Error(err))
        return err
    }
    log.Info("QueryNode success to get Proxy sessions", zap.Any("sessions", sessions))

    nodes := make([]*NodeInfo, 0, len(sessions))
    for _, session := range sessions {
        info := &NodeInfo{
            NodeID:  session.ServerID,
            Address: session.Address,
        }
        nodes = append(nodes, info)
    }

    node.sessionManager.Startup(nodes)

    node.eventCh = node.session.WatchServices(typeutil.ProxyRole, rev+1, nil)
    return nil
}

func (node *QueryNode) watchService(ctx context.Context) {
    defer node.wg.Done()
    for {
        select {
        case <-ctx.Done():
            log.Info("watch service shutdown")
            return
        case event, ok := <-node.eventCh:
            if !ok {
                // ErrCompacted is handled inside SessionWatcher
                log.Error("Session Watcher channel closed", zap.Int64("server id", node.session.ServerID))
                // need to call stop in separate goroutine
                go node.Stop()
                if node.session.TriggerKill {
                    if p, err := os.FindProcess(os.Getpid()); err == nil {
                        p.Signal(syscall.SIGINT)
                    }
                }
                return
            }
            node.handleSessionEvent(ctx, event)
        }
    }
}

func (node *QueryNode) handleSessionEvent(ctx context.Context, event *sessionutil.SessionEvent) {
    info := &NodeInfo{
        NodeID:  event.Session.ServerID,
        Address: event.Session.Address,
    }
    switch event.EventType {
    case sessionutil.SessionAddEvent:
        node.sessionManager.AddSession(info)
    case sessionutil.SessionDelEvent:
        node.sessionManager.DeleteSession(info)
    default:
        log.Warn("receive unknown service event type",
            zap.Any("type", event.EventType))
    }
}

// Init function init historical and streaming module to manage segments
func (node *QueryNode) Init() error {
    var initError error = nil
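
The watchService loop removed above follows a common etcd-session watch pattern: block on the context and the event channel, treat a closed channel as fatal (the watcher cannot recover), and dispatch add/delete events to a handler. Below is a stripped-down sketch of that select loop; the event and handler types are hypothetical stand-ins for sessionutil, and only the loop shape is taken from the code above.

package main

import (
    "context"
    "fmt"
)

// event and handle are placeholders for sessionutil.SessionEvent handling;
// the shape of the watch loop is what matters here.
type event struct {
    add    bool
    nodeID int64
}

func handle(e event) {
    if e.add {
        fmt.Println("add session", e.nodeID)
    } else {
        fmt.Println("delete session", e.nodeID)
    }
}

// watchLoop mirrors QueryNode.watchService: exit on ctx.Done, treat a closed
// channel as fatal (real code stops the node and may raise SIGINT), otherwise
// dispatch the event to the handler.
func watchLoop(ctx context.Context, ch <-chan event, onFatal func()) {
    for {
        select {
        case <-ctx.Done():
            fmt.Println("watch loop shutdown")
            return
        case e, ok := <-ch:
            if !ok {
                onFatal()
                return
            }
            handle(e)
        }
    }
}

func main() {
    ch := make(chan event, 2)
    ch <- event{add: true, nodeID: 1}
    ch <- event{add: false, nodeID: 1}
    close(ch)
    watchLoop(context.Background(), ch, func() { fmt.Println("watcher closed, shutting down") })
}
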
@@ -315,9 +246,6 @@ func (node *QueryNode) Init() error {

        node.InitSegcore()

        // TODO: add session creator to node
        node.sessionManager = NewSessionManager(withSessionCreator(defaultSessionCreator()))

        log.Info("query node init successfully",
            zap.Any("queryNodeID", Params.QueryNodeCfg.GetNodeID()),
            zap.Any("IP", Params.QueryNodeCfg.QueryNodeIP),

@@ -337,14 +265,6 @@ func (node *QueryNode) Start() error {
    go node.watchChangeInfo()
    //go node.statsService.start()

    // watch proxy
    if err := node.initServiceDiscovery(); err != nil {
        return err
    }

    node.wg.Add(1)
    go node.watchService(node.queryNodeLoopCtx)

    // create shardClusterService for shardLeader functions.
    node.ShardClusterService = newShardClusterService(node.etcdCli, node.session, node)
    // create shard-level query service
@@ -22,10 +22,8 @@ import (
    "math/rand"
    "net/url"
    "os"
    "os/signal"
    "strconv"
    "sync"
    "syscall"
    "testing"
    "time"

@@ -40,7 +38,6 @@ import (
    "github.com/milvus-io/milvus/internal/proto/querypb"
    "github.com/milvus-io/milvus/internal/types"
    "github.com/milvus-io/milvus/internal/util/etcd"
    "github.com/milvus-io/milvus/internal/util/sessionutil"
)

var embedetcdServer *embed.Etcd

@@ -331,83 +328,6 @@ func TestQueryNode_watchChangeInfo(t *testing.T) {
    wg.Wait()
}

func TestQueryNode_watchService(t *testing.T) {
    t.Run("watch channel closed", func(t *testing.T) {
        ech := make(chan *sessionutil.SessionEvent)
        qn := &QueryNode{
            session: &sessionutil.Session{
                TriggerKill: true,
                ServerID:    0,
            },
            wg:                  sync.WaitGroup{},
            eventCh:             ech,
            queryNodeLoopCancel: func() {},
        }
        flag := false
        closed := false

        sigDone := make(chan struct{}, 1)
        sigQuit := make(chan struct{}, 1)
        sc := make(chan os.Signal, 1)
        signal.Notify(sc, syscall.SIGINT)

        defer signal.Reset(syscall.SIGINT)

        qn.wg.Add(1)

        go func() {
            qn.watchService(context.Background())
            flag = true
            sigDone <- struct{}{}
        }()
        go func() {
            <-sc
            closed = true
            sigQuit <- struct{}{}
        }()

        close(ech)
        <-sigDone
        <-sigQuit
        assert.True(t, flag)
        assert.True(t, closed)
    })

    t.Run("context done", func(t *testing.T) {
        ech := make(chan *sessionutil.SessionEvent)
        qn := &QueryNode{
            session: &sessionutil.Session{
                TriggerKill: true,
                ServerID:    0,
            },
            wg:      sync.WaitGroup{},
            eventCh: ech,
        }
        flag := false

        sigDone := make(chan struct{}, 1)
        sc := make(chan os.Signal, 1)
        signal.Notify(sc, syscall.SIGINT)

        defer signal.Reset(syscall.SIGINT)

        qn.wg.Add(1)

        ctx, cancel := context.WithCancel(context.Background())

        go func() {
            qn.watchService(ctx)
            flag = true
            sigDone <- struct{}{}
        }()

        assert.False(t, flag)
        cancel()
        <-sigDone
        assert.True(t, flag)
    })
}

func TestQueryNode_validateChangeChannel(t *testing.T) {

    type testCase struct {
@@ -39,6 +39,8 @@ import (
    "github.com/milvus-io/milvus/internal/proto/schemapb"
    "github.com/milvus-io/milvus/internal/types"
    "github.com/milvus-io/milvus/internal/util/retry"

    grpcStatus "google.golang.org/grpc/status"
)

// CheckGrpcReady wait for context timeout, or wait 100ms then send nil to targetCh

@@ -352,3 +354,11 @@ func ReadBinary(endian binary.ByteOrder, bs []byte, receiver interface{}) error
    buf := bytes.NewReader(bs)
    return binary.Read(buf, endian, receiver)
}

func IsGrpcErr(err error) bool {
    if err == nil {
        return false
    }
    _, ok := grpcStatus.FromError(err)
    return ok
}
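
IsGrpcErr reports whether an error carries a gRPC status, which lets callers separate transport-level failures (worth resetting the connection and retrying) from local errors such as a cancelled context. A small illustrative caller is sketched below; classify and isGrpcErr are names invented for this example, and only the grpc status/codes packages are assumed.

package main

import (
    "context"
    "errors"
    "fmt"

    grpcCodes "google.golang.org/grpc/codes"
    grpcStatus "google.golang.org/grpc/status"
)

// isGrpcErr repeats the same check as funcutil.IsGrpcErr so the snippet
// compiles on its own.
func isGrpcErr(err error) bool {
    if err == nil {
        return false
    }
    _, ok := grpcStatus.FromError(err)
    return ok
}

// classify shows how a caller might route errors based on the check.
func classify(err error) string {
    switch {
    case err == nil:
        return "ok"
    case errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded):
        return "local context error: do not retry"
    case isGrpcErr(err):
        return "grpc status error: reset connection and retry"
    default:
        return "plain error: return to caller"
    }
}

func main() {
    fmt.Println(classify(nil))
    fmt.Println(classify(context.Canceled))
    fmt.Println(classify(grpcStatus.Error(grpcCodes.Unavailable, "node down")))
    fmt.Println(classify(errors.New("parse failure")))
}
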
@@ -32,6 +32,8 @@ import (
    "github.com/milvus-io/milvus/internal/proto/internalpb"
    "github.com/milvus-io/milvus/internal/proto/milvuspb"
    "github.com/stretchr/testify/assert"
    grpcCodes "google.golang.org/grpc/codes"
    grpcStatus "google.golang.org/grpc/status"
)

type MockComponent struct {

@@ -463,3 +465,28 @@ func Test_ReadBinary(t *testing.T) {
    assert.NoError(t, ReadBinary(endian, bs, &fs))
    assert.ElementsMatch(t, []float32{0, 0}, fs)
}

func TestIsGrpcErr(t *testing.T) {
    var err1 error
    assert.False(t, IsGrpcErr(err1))

    err1 = errors.New("error")
    assert.False(t, IsGrpcErr(err1))

    bgCtx := context.Background()
    ctx1, cancel1 := context.WithCancel(bgCtx)
    cancel1()
    assert.False(t, IsGrpcErr(ctx1.Err()))

    timeout := 20 * time.Millisecond
    ctx1, cancel1 = context.WithTimeout(bgCtx, timeout)
    time.Sleep(timeout * 2)
    assert.False(t, IsGrpcErr(ctx1.Err()))
    cancel1()

    err1 = grpcStatus.Error(grpcCodes.Canceled, "test")
    assert.True(t, IsGrpcErr(err1))

    err1 = grpcStatus.Error(grpcCodes.Unavailable, "test")
    assert.True(t, IsGrpcErr(err1))
}
@@ -203,17 +203,14 @@ func (c *ClientBase) callOnce(ctx context.Context, caller func(client interface{
        return ret, nil
    }

    // status.Error(codes.Canceled, context.Canceled.Error()
    // if err2 == context.Canceled || err2 == context.DeadlineExceeded {
    //	return nil, err2
    // }

    if !funcutil.CheckCtxValid(ctx) {
        return nil, err2
    }

    if !funcutil.IsGrpcErr(err2) {
        log.Debug("ClientBase:isNotGrpcErr", zap.Error(err2))
        return nil, err2
    }
    log.Debug(c.GetRole()+" ClientBase grpc error, start to reset connection", zap.Error(err2))

    c.resetConnection(client)
    return ret, err2
}
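
Taken together, the callOnce change retries only when it can help: if the caller's context is already invalid the error is returned as-is, if the error is not a gRPC status error a reconnect will not change anything, and only genuine gRPC failures trigger resetConnection so the next attempt re-dials. The condensed sketch below folds that decision order into a single retry loop for illustration; retryCall, doRPC and reset are invented stand-ins, not ClientBase's real API, which splits the retry across Call/ReCall.

package main

import (
    "context"
    "fmt"

    grpcCodes "google.golang.org/grpc/codes"
    grpcStatus "google.golang.org/grpc/status"
)

// retryCall attempts doRPC up to `attempts` times, resetting the connection
// (via reset) only for genuine gRPC status errors while the context is valid.
func retryCall(ctx context.Context, attempts int, doRPC func() error, reset func()) error {
    var lastErr error
    for i := 0; i < attempts; i++ {
        lastErr = doRPC()
        if lastErr == nil {
            return nil
        }
        // 1. Caller's context is gone: retrying cannot succeed.
        if ctx.Err() != nil {
            return lastErr
        }
        // 2. Not a gRPC status error: a reconnect will not change anything.
        if _, ok := grpcStatus.FromError(lastErr); !ok {
            return lastErr
        }
        // 3. Genuine gRPC failure: drop the cached connection and try again.
        reset()
    }
    return lastErr
}

func main() {
    calls := 0
    err := retryCall(context.Background(), 3,
        func() error {
            calls++
            if calls < 3 {
                return grpcStatus.Error(grpcCodes.Unavailable, "transient")
            }
            return nil
        },
        func() { fmt.Println("resetting connection") },
    )
    fmt.Println("calls:", calls, "err:", err)
}
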