milvus/internal/proxy/shard_client.go

223 lines
4.9 KiB
Go

package proxy
import (
"context"
"fmt"
"sync"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/registry"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log"
)
// queryNodeCreatorFunc builds a QueryNode client for the node identified by
// nodeID and reachable at addr. The shard client manager calls it when it
// needs a client for a node it has not cached yet.
type queryNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error)
// nodeInfo identifies a node by its ID and address.
type nodeInfo struct {
	nodeID  UniqueID // node identifier; used as the key of the manager's client cache
	address string   // handed to the client creator when a client for this node is built
}
// String renders the node as "<NodeID: N>". The address is intentionally not
// part of the output.
func (n nodeInfo) String() string {
	const format = "<NodeID: %d>"
	return fmt.Sprintf(format, n.nodeID)
}
// errClosed is returned by shardClient.getClient once the client has been closed.
var errClosed = errors.New("client is closed")
// shardClient is a reference-counted wrapper around a QueryNode client.
// Its methods take the embedded RWMutex before touching the fields below;
// the unexported close helper expects the caller to already hold the lock.
type shardClient struct {
	sync.RWMutex
	info     nodeInfo              // identity of the node this client talks to
	client   types.QueryNodeClient // underlying client; set to nil by close
	isClosed bool                  // set by close; nothing in this file clears it again
	refCnt   int                   // reference count; starts at 1, and dropping to 0 in dec closes the client
}
// getClient returns the wrapped QueryNode client, or errClosed if the shard
// client has already been closed. ctx is accepted but not used here.
func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, error) {
	n.RLock()
	defer n.RUnlock()

	if !n.isClosed {
		return n.client, nil
	}
	return nil, errClosed
}
// inc adds one reference to the client. It is a no-op on a closed client, so
// a closed client is never revived by a late reference.
func (n *shardClient) inc() {
	n.Lock()
	defer n.Unlock()

	if !n.isClosed {
		n.refCnt++
	}
}
// close marks the client closed, zeroes its reference count and closes the
// underlying client if one is still attached. A failure to close the
// underlying client is only logged. It does not lock; callers in this file
// (dec, Close) hold the write lock when invoking it.
func (n *shardClient) close() {
	n.isClosed = true
	n.refCnt = 0

	if n.client == nil {
		return
	}
	err := n.client.Close()
	if err != nil {
		log.Warn("close grpc client failed", zap.Error(err))
	}
	// Drop the reference so a repeated close does not close it twice.
	n.client = nil
}
// dec releases one reference and reports whether the client is closed
// afterwards. When the count reaches zero the client is closed as part of the
// call; an already-closed client reports true without further work.
func (n *shardClient) dec() bool {
	n.Lock()
	defer n.Unlock()

	if n.isClosed {
		return true
	}
	if n.refCnt > 0 {
		n.refCnt--
	}
	if n.refCnt != 0 {
		// Still referenced elsewhere; keep the client open.
		return false
	}
	n.close()
	return n.refCnt == 0
}
// Close closes the client regardless of its current reference count. It takes
// the write lock and delegates to close.
func (n *shardClient) Close() {
	n.Lock()
	defer n.Unlock()
	n.close()
}
// newShardClient wraps client for the node described by info. The node info
// is copied, so the result does not alias the caller's nodeInfo, and the
// reference count starts at 1 on behalf of the caller.
func newShardClient(info *nodeInfo, client types.QueryNodeClient) *shardClient {
	sc := &shardClient{
		client: client,
		refCnt: 1,
	}
	sc.info.nodeID = info.nodeID
	sc.info.address = info.address
	return sc
}
// shardClientMgr manages QueryNode clients keyed by node ID.
type shardClientMgr interface {
	// GetClient returns the client for nodeID, or an error when none is available.
	GetClient(ctx context.Context, nodeID UniqueID) (types.QueryNodeClient, error)
	// UpdateShardLeaders reconciles the managed clients with a change from
	// oldLeaders to newLeaders.
	UpdateShardLeaders(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo) error
	// Close releases every managed client.
	Close()
	// SetClientCreatorFunc sets the function used to build clients for new nodes.
	SetClientCreatorFunc(creator queryNodeCreatorFunc)
}
// shardClientMgrImpl is the shardClientMgr implementation in this file. It
// keeps one reference-counted shardClient per node ID.
type shardClientMgrImpl struct {
	// clients is the node ID -> client cache; data is accessed under the
	// embedded RWMutex by the methods in this file.
	clients struct {
		sync.RWMutex
		data map[UniqueID]*shardClient
	}
	// clientCreator builds a client for a node that is not cached yet.
	clientCreator queryNodeCreatorFunc
}
// shardClientMgrOpt is a functional option that configures a shardClientMgr;
// newShardClientMgr applies the options it is given in order.
type shardClientMgrOpt func(s shardClientMgr)
// withShardClientCreator returns an option that installs creator as the
// manager's client factory through SetClientCreatorFunc.
func withShardClientCreator(creator queryNodeCreatorFunc) shardClientMgrOpt {
	apply := func(mgr shardClientMgr) {
		mgr.SetClientCreatorFunc(creator)
	}
	return apply
}
// defaultQueryNodeClientCreator is the creator used when none is configured:
// it resolves the QueryNode client through the in-memory registry resolver.
func defaultQueryNodeClientCreator(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
	resolver := registry.GetInMemoryResolver()
	return resolver.ResolveQueryNode(ctx, addr, nodeID)
}
// newShardClientMgr creates a shardClientMgrImpl with an empty client cache
// and defaultQueryNodeClientCreator as its creator, then applies options in
// the order given.
func newShardClientMgr(options ...shardClientMgrOpt) *shardClientMgrImpl {
	mgr := &shardClientMgrImpl{clientCreator: defaultQueryNodeClientCreator}
	mgr.clients.data = make(map[UniqueID]*shardClient)

	for i := range options {
		options[i](mgr)
	}
	return mgr
}
// SetClientCreatorFunc replaces the function used to build clients for nodes
// that are not cached yet. The field is written without taking any lock.
func (c *shardClientMgrImpl) SetClientCreatorFunc(creator queryNodeCreatorFunc) {
	c.clientCreator = creator
}
// UpdateShardLeaders reconciles the client cache with a change of shard
// leaders from oldLeaders to newLeaders. Both maps are keyed by a string
// (presumably the shard/channel name - confirm with callers) and list the
// nodes serving that key.
//
//   - A node present only in newLeaders gains one reference: an already cached
//     client is inc'ed, otherwise a client is created through clientCreator.
//   - A node present only in oldLeaders loses one reference and is evicted
//     from the cache once its client reports closed.
//   - A node present in both is left untouched, no matter how many shards it
//     appears under on either side.
//
// Neither argument is modified. An error is returned when a client has to be
// created and clientCreator is nil or fails; in that case additions processed
// before the failure stay in place and no removals are performed.
func (c *shardClientMgrImpl) UpdateShardLeaders(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo) error {
	// Build both node sets completely before diffing them. Deciding
	// membership while mutating the old set made a retained node that leads
	// several shards look "new" on its second occurrence, which leaked a
	// reference on every update.
	oldNodes := distinctShardNodes(oldLeaders)
	newNodes := distinctShardNodes(newLeaders)

	c.clients.Lock()
	defer c.clients.Unlock()

	for nodeID, node := range newNodes {
		if _, retained := oldNodes[nodeID]; retained {
			continue
		}
		if client, ok := c.clients.data[nodeID]; ok {
			client.inc()
			continue
		}
		// context.Background() is useless
		// TODO QueryNode NewClient remove ctx parameter
		// TODO Remove Init && Start interface in QueryNode client
		if c.clientCreator == nil {
			return fmt.Errorf("clientCreator function is nil")
		}
		qnClient, err := c.clientCreator(context.Background(), node.address, node.nodeID)
		if err != nil {
			return err
		}
		c.clients.data[nodeID] = newShardClient(node, qnClient)
	}

	for nodeID := range oldNodes {
		if _, retained := newNodes[nodeID]; retained {
			continue
		}
		// dec reports true once the client is closed, at which point the
		// cache entry is stale and must go.
		if client, ok := c.clients.data[nodeID]; ok && client.dec() {
			delete(c.clients.data, nodeID)
		}
	}
	return nil
}

// distinctShardNodes flattens a leaders mapping into one entry per node ID.
// When a node is listed more than once, the first entry encountered is kept.
// The returned pointers refer to elements of the input slices.
func distinctShardNodes(leaders map[string][]nodeInfo) map[UniqueID]*nodeInfo {
	distinct := make(map[UniqueID]*nodeInfo)
	for _, nodes := range leaders {
		for i := range nodes {
			node := &nodes[i]
			if _, seen := distinct[node.nodeID]; !seen {
				distinct[node.nodeID] = node
			}
		}
	}
	return distinct
}
// GetClient returns the QueryNode client cached for nodeID. It fails when no
// client is cached for that node, and forwards the shard client's own error
// when the cached client has already been closed. The cache lock is released
// before the shard client is queried.
func (c *shardClientMgrImpl) GetClient(ctx context.Context, nodeID UniqueID) (types.QueryNodeClient, error) {
	c.clients.RLock()
	cached, found := c.clients.data[nodeID]
	c.clients.RUnlock()

	if found {
		return cached.getClient(ctx)
	}
	return nil, fmt.Errorf("can not find client of node %d", nodeID)
}
// Close closes every cached client and replaces the cache with a fresh empty
// map, all under the cache write lock.
func (c *shardClientMgrImpl) Close() {
	c.clients.Lock()
	defer c.clients.Unlock()

	for _, cached := range c.clients.data {
		cached.Close()
	}
	c.clients.data = map[UniqueID]*shardClient{}
}