Make proxy use roundrobin to choose replica (#17063)

Fixes: #17055

Signed-off-by: yangxuan <xuan.yang@zilliz.com>
XuanYang-cn 2022-05-17 22:35:57 +08:00 committed by GitHub
parent b37b87eb97
commit 127dd34b37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 141 additions and 55 deletions
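The change reshapes the proxy's shard-leader cache from QueryCoord's list form into a per-channel replica list and rotates each list by one position on every read, so consecutive requests start from different replicas. As orientation before the diff, a minimal, self-contained Go sketch of that rotation; replica and rotate are stand-in names for this commit's queryNode type and updateShardsWithRoundRobin helper:

package main

import "fmt"

// replica is a stand-in for the queryNode struct added in this commit.
type replica struct {
	nodeID  int64
	address string
}

// rotate moves the head of the leader list to the tail, the same
// append(leaders[1:], leaders[0]) step updateShardsWithRoundRobin
// applies per DML channel on every shard-cache read.
func rotate(leaders []replica) []replica {
	if len(leaders) <= 1 {
		return leaders
	}
	return append(leaders[1:], leaders[0])
}

func main() {
	leaders := []replica{{1, "addr1"}, {2, "addr2"}, {3, "addr3"}}
	for i := 0; i < 4; i++ {
		fmt.Printf("read %d goes to node %d first\n", i, leaders[0].nodeID)
		leaders = rotate(leaders)
	}
	// Prints nodes 1, 2, 3, 1: reads cycle through the replicas.
}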

View File

@@ -712,7 +712,7 @@ func (c *ChannelManager) Release(nodeID UniqueID, channelName string) error {
 	toReleaseChannel := c.getChannelByNodeAndName(nodeID, channelName)
 	if toReleaseChannel == nil {
-		return fmt.Errorf("fail to find matching nodID: %d with channelName: %s", nodeID, channelName)
+		return fmt.Errorf("fail to find matching nodeID: %d with channelName: %s", nodeID, channelName)
 	}
 
 	toReleaseUpdates := getReleaseOp(nodeID, toReleaseChannel)
@@ -731,7 +731,7 @@ func (c *ChannelManager) toDelete(nodeID UniqueID, channelName string) error {
 	ch := c.getChannelByNodeAndName(nodeID, channelName)
 	if ch == nil {
-		return fmt.Errorf("fail to find matching nodID: %d with channelName: %s", nodeID, channelName)
+		return fmt.Errorf("fail to find matching nodeID: %d with channelName: %s", nodeID, channelName)
 	}
 
 	if !c.isMarkedDrop(channelName) {

View File

@@ -53,7 +53,7 @@ type Cache interface {
 	GetPartitionInfo(ctx context.Context, collectionName string, partitionName string) (*partitionInfo, error)
 	// GetCollectionSchema get collection's schema.
 	GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error)
-	GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) ([]*querypb.ShardLeadersList, error)
+	GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error)
 	ClearShards(collectionName string)
 	RemoveCollection(ctx context.Context, collectionName string)
 	RemovePartition(ctx context.Context, collectionName string, partitionName string)
@@ -70,7 +70,7 @@ type collectionInfo struct {
 	collID              typeutil.UniqueID
 	schema              *schemapb.CollectionSchema
 	partInfo            map[string]*partitionInfo
-	shardLeaders        []*querypb.ShardLeadersList
+	shardLeaders        map[string][]queryNode
 	createdTimestamp    uint64
 	createdUtcTimestamp uint64
 }
@@ -528,7 +528,7 @@ func (m *MetaCache) GetCredUsernames(ctx context.Context) ([]string, error) {
 }
 
 // GetShards update cache if withCache == false
-func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) ([]*querypb.ShardLeadersList, error) {
+func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error) {
 	info, err := m.GetCollectionInfo(ctx, collectionName)
 	if err != nil {
 		return nil, err
@@ -536,7 +536,12 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
 	if withCache {
 		if len(info.shardLeaders) > 0 {
-			return info.shardLeaders, nil
+			shards := updateShardsWithRoundRobin(info.shardLeaders)
+
+			m.mu.Lock()
+			m.collInfo[collectionName].shardLeaders = shards
+			m.mu.Unlock()
+			return shards, nil
 		}
 		log.Info("no shard cache for collection, try to get shard leaders from QueryCoord",
 			zap.String("collectionName", collectionName))
@@ -557,7 +562,9 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
 		return nil, fmt.Errorf("fail to get shard leaders from QueryCoord: %s", resp.Status.Reason)
 	}
 
-	shards := resp.GetShards()
+	shards := parseShardLeaderList2QueryNode(resp.GetShards())
+	shards = updateShardsWithRoundRobin(shards)
 
 	m.mu.Lock()
 	m.collInfo[collectionName].shardLeaders = shards
@@ -566,6 +573,22 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
 	return shards, nil
 }
 
+func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]queryNode {
+	shard2QueryNodes := make(map[string][]queryNode)
+
+	for _, leaders := range shardsLeaders {
+		qns := make([]queryNode, len(leaders.GetNodeIds()))
+		for j := range qns {
+			qns[j] = queryNode{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j]}
+		}
+
+		shard2QueryNodes[leaders.GetChannelName()] = qns
+	}
+
+	return shard2QueryNodes
+}
+
 // ClearShards clear the shard leader cache of a collection
 func (m *MetaCache) ClearShards(collectionName string) {
 	log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName))

View File

@@ -344,8 +344,8 @@ func TestMetaCache_GetShards(t *testing.T) {
 		assert.NoError(t, err)
 		assert.NotEmpty(t, shards)
 		assert.Equal(t, 1, len(shards))
-		assert.Equal(t, 3, len(shards[0].GetNodeAddrs()))
-		assert.Equal(t, 3, len(shards[0].GetNodeIds()))
+		assert.Equal(t, 3, len(shards["channel-1"]))
+
 		// get from cache
 		qc.validShardLeaders = false
@@ -353,8 +353,7 @@ func TestMetaCache_GetShards(t *testing.T) {
 		assert.NoError(t, err)
 		assert.NotEmpty(t, shards)
 		assert.Equal(t, 1, len(shards))
-		assert.Equal(t, 3, len(shards[0].GetNodeAddrs()))
-		assert.Equal(t, 3, len(shards[0].GetNodeIds()))
+		assert.Equal(t, 3, len(shards["channel-1"]))
 	})
 }
@@ -387,8 +386,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
 		require.NoError(t, err)
 		require.NotEmpty(t, shards)
 		require.Equal(t, 1, len(shards))
-		require.Equal(t, 3, len(shards[0].GetNodeAddrs()))
-		require.Equal(t, 3, len(shards[0].GetNodeIds()))
+		require.Equal(t, 3, len(shards["channel-1"]))
 
 		globalMetaCache.ClearShards(collectionName)

View File

@@ -3,11 +3,11 @@ package proxy
 import (
 	"context"
 	"errors"
+	"fmt"
 
 	qnClient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
 	"github.com/milvus-io/milvus/internal/log"
-	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/types"
 	"go.uber.org/zap"
 )
@@ -15,7 +15,7 @@ import (
 type getQueryNodePolicy func(context.Context, string) (types.QueryNode, error)
 
-type pickShardPolicy func(ctx context.Context, policy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders *querypb.ShardLeadersList) error
+type pickShardPolicy func(ctx context.Context, policy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders []queryNode) error
 
 // TODO add another policy to enable the use of cache
 // defaultGetQueryNodePolicy creates a QueryNode client for every address every time
@@ -40,23 +40,45 @@ var (
 	errInvalidShardLeaders = errors.New("Invalid shard leader")
 )
 
-func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders *querypb.ShardLeadersList) error {
+type queryNode struct {
+	nodeID  UniqueID
+	address string
+}
+
+func (q queryNode) String() string {
+	return fmt.Sprintf("<NodeID: %d>", q.nodeID)
+}
+
+func updateShardsWithRoundRobin(shardsLeaders map[string][]queryNode) map[string][]queryNode {
+	for channelID, leaders := range shardsLeaders {
+		if len(leaders) <= 1 {
+			continue
+		}
+
+		shardsLeaders[channelID] = append(leaders[1:], leaders[0])
+	}
+
+	return shardsLeaders
+}
+
+func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders []queryNode) error {
 	var (
 		err     = errBegin
 		current = 0
 		qn      types.QueryNode
 	)
-	replicaNum := len(leaders.GetNodeIds())
+	replicaNum := len(leaders)
 
 	for err != nil && current < replicaNum {
-		currentID := leaders.GetNodeIds()[current]
+		currentID := leaders[current].nodeID
 		if err != errBegin {
 			log.Warn("retry with another QueryNode",
 				zap.Int("retries numbers", current),
-				zap.String("leader", leaders.GetChannelName()), zap.Int64("nodeID", currentID))
+				zap.Int64("nodeID", currentID))
 		}
 
-		qn, err = getQueryNodePolicy(ctx, leaders.GetNodeAddrs()[current])
+		qn, err = getQueryNodePolicy(ctx, leaders[current].address)
 		if err != nil {
 			log.Warn("fail to get valid QueryNode", zap.Int64("nodeID", currentID),
 				zap.Error(err))
@@ -68,7 +90,6 @@ func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy
 		err = query(currentID, qn)
 		if err != nil {
 			log.Warn("fail to Query with shard leader",
-				zap.String("leader", leaders.GetChannelName()),
 				zap.Int64("nodeID", currentID),
 				zap.Error(err))
 		}
@@ -76,9 +97,8 @@ func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy
 	}
 
 	if current == replicaNum && err != nil {
-		log.Warn("no shard leaders available for channel",
-			zap.String("channel name", leaders.GetChannelName()),
-			zap.Int64s("leaders", leaders.GetNodeIds()), zap.Error(err))
+		log.Warn("no shard leaders available",
+			zap.String("leaders", fmt.Sprintf("%v", leaders)), zap.Error(err))
 		// needs to return the error from query
 		return err
 	}
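roundRobinPolicy tells a first attempt apart from a retry with the errBegin sentinel instead of a boolean flag. A condensed, runnable sketch of that control flow; tryReplicas and attempt are stand-in names, and plain prints replace the zap logging:

package main

import (
	"errors"
	"fmt"
)

var errBegin = errors.New("begin")

// tryReplicas walks the rotated leader list until one attempt succeeds,
// mirroring the retry loop in roundRobinPolicy: err starts as the sentinel,
// so the loop runs at least once and logs only on genuine retries.
func tryReplicas(addrs []string, attempt func(addr string) error) error {
	err := errBegin
	for current := 0; err != nil && current < len(addrs); current++ {
		if err != errBegin {
			fmt.Println("retrying on", addrs[current])
		}
		err = attempt(addrs[current])
	}
	return err // nil on success, the last attempt's error otherwise
}

func main() {
	calls := 0
	err := tryReplicas([]string{"addr1", "addr2"}, func(addr string) error {
		calls++
		if addr == "addr2" {
			return nil // the second replica answers
		}
		return errors.New("replica down")
	})
	fmt.Println(err, calls) // <nil> 2
}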

View File

@@ -5,11 +5,53 @@ import (
 	"fmt"
 	"testing"
 
-	"github.com/milvus-io/milvus/internal/proto/querypb"
+	"github.com/milvus-io/milvus/internal/log"
 	"github.com/milvus-io/milvus/internal/types"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+	"go.uber.org/zap"
 )
 
+func TestUpdateShardsWithRoundRobin(t *testing.T) {
+	in := map[string][]queryNode{
+		"channel-1": {
+			{1, "addr1"},
+			{2, "addr2"},
+		},
+		"channel-2": {
+			{20, "addr20"},
+			{21, "addr21"},
+		},
+	}
+
+	out := updateShardsWithRoundRobin(in)
+
+	assert.Equal(t, int64(2), out["channel-1"][0].nodeID)
+	assert.Equal(t, "addr2", out["channel-1"][0].address)
+	assert.Equal(t, int64(21), out["channel-2"][0].nodeID)
+	assert.Equal(t, "addr21", out["channel-2"][0].address)
+
+	t.Run("check print", func(t *testing.T) {
+		qns := []queryNode{
+			{1, "addr1"},
+			{2, "addr2"},
+			{20, "addr20"},
+			{21, "addr21"},
+		}
+		res := fmt.Sprintf("list: %v", qns)
+		log.Debug("Check String func",
+			zap.Any("Any", qns),
+			zap.Any("ok", qns[0]),
+			zap.String("ok2", res),
+		)
+	})
+}
+
 func TestRoundRobinPolicy(t *testing.T) {
 	var (
 		getQueryNodePolicy = mockGetQueryNodePolicy
@@ -31,11 +73,12 @@ func TestRoundRobinPolicy(t *testing.T) {
 		t.Run(test.description, func(t *testing.T) {
 			query := (&mockQuery{isvalid: false}).query
-			leaders := &querypb.ShardLeadersList{
-				ChannelName: t.Name(),
-				NodeIds:     test.leaderIDs,
-				NodeAddrs:   make([]string, len(test.leaderIDs)),
-			}
+			leaders := make([]queryNode, 0, len(test.leaderIDs))
+			for _, ID := range test.leaderIDs {
+				leaders = append(leaders, queryNode{ID, "random-addr"})
+			}
+
 			err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
 			require.Error(t, err)
 		})
@@ -55,10 +98,10 @@ func TestRoundRobinPolicy(t *testing.T) {
 	for _, test := range allPassTests {
 		query := (&mockQuery{isvalid: true}).query
-		leaders := &querypb.ShardLeadersList{
-			ChannelName: t.Name(),
-			NodeIds:     test.leaderIDs,
-			NodeAddrs:   make([]string, len(test.leaderIDs)),
-		}
+		leaders := make([]queryNode, 0, len(test.leaderIDs))
+		for _, ID := range test.leaderIDs {
+			leaders = append(leaders, queryNode{ID, "random-addr"})
+		}
+
 		err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
 		require.NoError(t, err)
@@ -77,10 +120,10 @@ func TestRoundRobinPolicy(t *testing.T) {
 	for _, test := range passAtLast {
 		query := (&mockQuery{isvalid: true}).query
-		leaders := &querypb.ShardLeadersList{
-			ChannelName: t.Name(),
-			NodeIds:     test.leaderIDs,
-			NodeAddrs:   make([]string, len(test.leaderIDs)),
-		}
+		leaders := make([]queryNode, 0, len(test.leaderIDs))
+		for _, ID := range test.leaderIDs {
+			leaders = append(leaders, queryNode{ID, "random-addr"})
+		}
+
 		err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
 		require.NoError(t, err)

View File

@@ -244,16 +244,17 @@ func (t *queryTask) Execute(ctx context.Context) error {
 	t.resultBuf = make(chan *internalpb.RetrieveResults, len(shards))
 	t.toReduceResults = make([]*internalpb.RetrieveResults, 0, len(shards))
 	t.runningGroup, t.runningGroupCtx = errgroup.WithContext(ctx)
-	for _, shard := range shards {
-		s := shard
+	for channelID, leaders := range shards {
+		channelID := channelID
+		leaders := leaders
 		t.runningGroup.Go(func() error {
 			log.Debug("proxy starting to query one shard",
 				zap.Int64("collectionID", t.CollectionID),
 				zap.String("collection name", t.collectionName),
-				zap.String("shard channel", s.GetChannelName()),
+				zap.String("shard channel", channelID),
 				zap.Uint64("timeoutTs", t.TimeoutTimestamp))
-			err := t.queryShard(t.runningGroupCtx, s)
+			err := t.queryShard(t.runningGroupCtx, leaders, channelID)
 			if err != nil {
 				return err
 			}
@@ -344,12 +345,12 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
 	return nil
 }
 
-func (t *queryTask) queryShard(ctx context.Context, leaders *querypb.ShardLeadersList) error {
+func (t *queryTask) queryShard(ctx context.Context, leaders []queryNode, channelID string) error {
 	query := func(nodeID UniqueID, qn types.QueryNode) error {
 		req := &querypb.QueryRequest{
 			Req:           t.RetrieveRequest,
 			IsShardLeader: true,
-			DmlChannel:    leaders.GetChannelName(),
+			DmlChannel:    channelID,
 		}
 
 		result, err := qn.Query(ctx, req)
@@ -364,14 +365,14 @@ func (t *queryTask) queryShard(ctx context.Context, leaders *querypb.ShardLeader
 			return fmt.Errorf("fail to Query, QueryNode ID = %d, reason=%s", nodeID, result.GetStatus().GetReason())
 		}
 
-		log.Debug("get query result", zap.Int64("nodeID", nodeID), zap.String("channelID", leaders.GetChannelName()))
+		log.Debug("get query result", zap.Int64("nodeID", nodeID), zap.String("channelID", channelID))
 		t.resultBuf <- result
 		return nil
 	}
 
 	err := t.queryShardPolicy(t.TraceCtx(), t.getQueryNodePolicy, query, leaders)
 	if err != nil {
-		log.Warn("fail to Query to all shard leaders", zap.Int64("taskID", t.ID()), zap.Any("shard leaders", leaders.GetNodeIds()))
+		log.Warn("fail to Query to all shard leaders", zap.Int64("taskID", t.ID()), zap.Any("shard leaders", leaders))
 		return err
 	}
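Both task types fan out one goroutine per DML channel through errgroup, and the channelID := channelID / leaders := leaders copies give each closure its own variables, since Go reused range variables across loop iterations at the time of this commit. A minimal sketch of the pattern with stand-in data:

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	shards := map[string][]string{
		"channel-1": {"addr1", "addr2"},
		"channel-2": {"addr20"},
	}

	g, ctx := errgroup.WithContext(context.Background())
	for channelID, leaders := range shards {
		channelID := channelID // copy: the range variables are reused each iteration
		leaders := leaders
		g.Go(func() error {
			select {
			case <-ctx.Done(): // another shard already failed; stop early
				return ctx.Err()
			default:
			}
			fmt.Println("querying", channelID, "via", len(leaders), "replicas")
			return nil
		})
	}
	if err := g.Wait(); err != nil { // first non-nil error, if any
		fmt.Println("query failed:", err)
	}
}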

View File

@@ -258,27 +258,28 @@ func (t *searchTask) Execute(ctx context.Context) error {
 	defer tr.Elapse("done")
 
 	executeSearch := func(withCache bool) error {
-		shards, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName, t.qc)
+		shard2Leaders, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName, t.qc)
 		if err != nil {
 			return err
 		}
-		t.resultBuf = make(chan *internalpb.SearchResults, len(shards))
-		t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shards))
+		t.resultBuf = make(chan *internalpb.SearchResults, len(shard2Leaders))
+		t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shard2Leaders))
 		t.runningGroup, t.runningGroupCtx = errgroup.WithContext(ctx)
 
 		// TODO: try to merge rpc send to different shard leaders.
 		// If two shard leaders are on the same QueryNode, requests could be merged to save an RPC.
-		for _, shard := range shards {
-			s := shard
+		for channelID, leaders := range shard2Leaders {
+			channelID := channelID
+			leaders := leaders
 			t.runningGroup.Go(func() error {
 				log.Debug("proxy starting to query one shard",
 					zap.Int64("collectionID", t.CollectionID),
 					zap.String("collection name", t.collectionName),
-					zap.String("shard channel", s.GetChannelName()),
+					zap.String("shard channel", channelID),
 					zap.Uint64("timeoutTs", t.TimeoutTimestamp))
-				err := t.searchShard(t.runningGroupCtx, s)
+				err := t.searchShard(t.runningGroupCtx, leaders, channelID)
 				if err != nil {
 					return err
 				}
@@ -393,13 +394,13 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
 	return nil
 }
 
-func (t *searchTask) searchShard(ctx context.Context, leaders *querypb.ShardLeadersList) error {
+func (t *searchTask) searchShard(ctx context.Context, leaders []queryNode, channelID string) error {
 	search := func(nodeID UniqueID, qn types.QueryNode) error {
 		req := &querypb.SearchRequest{
 			Req:           t.SearchRequest,
 			IsShardLeader: true,
-			DmlChannel:    leaders.GetChannelName(),
+			DmlChannel:    channelID,
 		}
 
 		result, err := qn.Search(ctx, req)
@@ -420,7 +421,7 @@ func (t *searchTask) searchShard(ctx context.Context, leaders *querypb.ShardLead
 	err := t.searchShardPolicy(t.TraceCtx(), t.getQueryNodePolicy, search, leaders)
 	if err != nil {
-		log.Warn("fail to search to all shard leaders", zap.Any("shard leaders", leaders.GetNodeIds()))
+		log.Warn("fail to search to all shard leaders", zap.Any("shard leaders", leaders))
 		return err
 	}
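Execute wraps the fan-out in executeSearch(withCache), and GetShards is documented above to refresh the cache when withCache == false. The hunks do not show the caller of this closure, so the clear-and-retry wiring below is an assumption about the surrounding Execute body, consistent with the ClearShards helper this commit's cache exposes; all names are stand-ins:

package main

import (
	"errors"
	"fmt"
)

var shardCache map[string][]string // channel -> replica addresses

// getShards mimics MetaCache.GetShards's contract: serve from the cache when
// withCache is true and the cache is warm, otherwise refetch.
func getShards(withCache bool) (map[string][]string, error) {
	if withCache && len(shardCache) > 0 {
		return shardCache, nil
	}
	shardCache = map[string][]string{"channel-1": {"addr1"}} // pretend QueryCoord call
	return shardCache, nil
}

func execute(search func(map[string][]string) error) error {
	run := func(withCache bool) error {
		shards, err := getShards(withCache)
		if err != nil {
			return err
		}
		return search(shards)
	}
	if err := run(true); err != nil {
		shardCache = nil  // drop the stale view, as ClearShards would
		return run(false) // one retry against fresh shard leaders
	}
	return nil
}

func main() {
	attempts := 0
	err := execute(func(shards map[string][]string) error {
		attempts++
		if attempts == 1 {
			return errors.New("stale shard leader")
		}
		return nil
	})
	fmt.Println(err, attempts) // <nil> 2
}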