mirror of https://github.com/milvus-io/milvus.git
Make proxy use roundrobin to choose replica (#17063)
Fixes: #17055
Signed-off-by: yangxuan <xuan.yang@zilliz.com>
pull/17071/head
parent b37b87eb97
commit 127dd34b37
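Before the diff, a brief orientation with a hedged sketch: the commit replaces the proxy's `[]*querypb.ShardLeadersList` view of shard leaders with a `map[string][]queryNode` (DML channel -> replica query nodes), rotates each channel's leader list on every cache lookup, and retries a request against the leaders in order until one succeeds. The self-contained Go sketch below illustrates that scheme with simplified types; `rotate` and `tryLeaders` are illustrative stand-ins for the `updateShardsWithRoundRobin` and `roundRobinPolicy` functions shown in the diff, not the actual Milvus code.

package main

import (
	"errors"
	"fmt"
)

// queryNode mirrors the struct this commit adds: a shard leader's ID and address.
type queryNode struct {
	nodeID  int64
	address string
}

// rotate moves the first leader of every channel to the back, so the next request
// starts from a different replica (the role updateShardsWithRoundRobin plays below).
func rotate(shards map[string][]queryNode) map[string][]queryNode {
	for ch, leaders := range shards {
		if len(leaders) <= 1 {
			continue
		}
		shards[ch] = append(leaders[1:], leaders[0])
	}
	return shards
}

// tryLeaders walks the (already rotated) leader list in order and stops at the first
// replica that answers, mirroring roundRobinPolicy's retry loop in the diff.
func tryLeaders(leaders []queryNode, query func(queryNode) error) error {
	err := errors.New("no shard leaders available")
	for _, leader := range leaders {
		if err = query(leader); err == nil {
			return nil
		}
	}
	return err // every replica failed; surface the last error
}

func main() {
	shards := map[string][]queryNode{
		"channel-1": {{1, "addr1"}, {2, "addr2"}, {3, "addr3"}},
	}
	shards = rotate(shards) // "channel-1" now starts with node 2
	_ = tryLeaders(shards["channel-1"], func(q queryNode) error {
		fmt.Printf("querying node %d at %s\n", q.nodeID, q.address)
		return nil // pretend the first replica answers
	})
}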
@@ -712,7 +712,7 @@ func (c *ChannelManager) Release(nodeID UniqueID, channelName string) error {
 	toReleaseChannel := c.getChannelByNodeAndName(nodeID, channelName)
 	if toReleaseChannel == nil {
-		return fmt.Errorf("fail to find matching nodID: %d with channelName: %s", nodeID, channelName)
+		return fmt.Errorf("fail to find matching nodeID: %d with channelName: %s", nodeID, channelName)
 	}

 	toReleaseUpdates := getReleaseOp(nodeID, toReleaseChannel)

@@ -731,7 +731,7 @@ func (c *ChannelManager) toDelete(nodeID UniqueID, channelName string) error {
 	ch := c.getChannelByNodeAndName(nodeID, channelName)
 	if ch == nil {
-		return fmt.Errorf("fail to find matching nodID: %d with channelName: %s", nodeID, channelName)
+		return fmt.Errorf("fail to find matching nodeID: %d with channelName: %s", nodeID, channelName)
 	}

 	if !c.isMarkedDrop(channelName) {

@@ -53,7 +53,7 @@ type Cache interface {
 	GetPartitionInfo(ctx context.Context, collectionName string, partitionName string) (*partitionInfo, error)
 	// GetCollectionSchema get collection's schema.
 	GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error)
-	GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) ([]*querypb.ShardLeadersList, error)
+	GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error)
 	ClearShards(collectionName string)
 	RemoveCollection(ctx context.Context, collectionName string)
 	RemovePartition(ctx context.Context, collectionName string, partitionName string)

@@ -70,7 +70,7 @@ type collectionInfo struct {
 	collID              typeutil.UniqueID
 	schema              *schemapb.CollectionSchema
 	partInfo            map[string]*partitionInfo
-	shardLeaders        []*querypb.ShardLeadersList
+	shardLeaders        map[string][]queryNode
 	createdTimestamp    uint64
 	createdUtcTimestamp uint64
 }

@@ -528,7 +528,7 @@ func (m *MetaCache) GetCredUsernames(ctx context.Context) ([]string, error) {
 }

 // GetShards update cache if withCache == false
-func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) ([]*querypb.ShardLeadersList, error) {
+func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error) {
 	info, err := m.GetCollectionInfo(ctx, collectionName)
 	if err != nil {
 		return nil, err

@@ -536,7 +536,12 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
 	if withCache {
 		if len(info.shardLeaders) > 0 {
-			return info.shardLeaders, nil
+			shards := updateShardsWithRoundRobin(info.shardLeaders)
+
+			m.mu.Lock()
+			m.collInfo[collectionName].shardLeaders = shards
+			m.mu.Unlock()
+			return shards, nil
 		}
 		log.Info("no shard cache for collection, try to get shard leaders from QueryCoord",
 			zap.String("collectionName", collectionName))

@@ -557,7 +562,9 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
 		return nil, fmt.Errorf("fail to get shard leaders from QueryCoord: %s", resp.Status.Reason)
 	}

-	shards := resp.GetShards()
+	shards := parseShardLeaderList2QueryNode(resp.GetShards())
+
+	shards = updateShardsWithRoundRobin(shards)

 	m.mu.Lock()
 	m.collInfo[collectionName].shardLeaders = shards

@@ -566,6 +573,22 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
 	return shards, nil
 }

+func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]queryNode {
+	shard2QueryNodes := make(map[string][]queryNode)
+
+	for _, leaders := range shardsLeaders {
+		qns := make([]queryNode, len(leaders.GetNodeIds()))
+
+		for j := range qns {
+			qns[j] = queryNode{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j]}
+		}
+
+		shard2QueryNodes[leaders.GetChannelName()] = qns
+	}
+
+	return shard2QueryNodes
+}
+
 // ClearShards clear the shard leader cache of a collection
 func (m *MetaCache) ClearShards(collectionName string) {
 	log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName))

@@ -344,8 +344,8 @@ func TestMetaCache_GetShards(t *testing.T) {
 		assert.NoError(t, err)
 		assert.NotEmpty(t, shards)
 		assert.Equal(t, 1, len(shards))
-		assert.Equal(t, 3, len(shards[0].GetNodeAddrs()))
-		assert.Equal(t, 3, len(shards[0].GetNodeIds()))
+		assert.Equal(t, 3, len(shards["channel-1"]))

 		// get from cache
 		qc.validShardLeaders = false

@@ -353,8 +353,7 @@
 		assert.NoError(t, err)
 		assert.NotEmpty(t, shards)
 		assert.Equal(t, 1, len(shards))
-		assert.Equal(t, 3, len(shards[0].GetNodeAddrs()))
-		assert.Equal(t, 3, len(shards[0].GetNodeIds()))
+		assert.Equal(t, 3, len(shards["channel-1"]))
 	})
 }

@@ -387,8 +386,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
 	require.NoError(t, err)
 	require.NotEmpty(t, shards)
 	require.Equal(t, 1, len(shards))
-	require.Equal(t, 3, len(shards[0].GetNodeAddrs()))
-	require.Equal(t, 3, len(shards[0].GetNodeIds()))
+	require.Equal(t, 3, len(shards["channel-1"]))

 	globalMetaCache.ClearShards(collectionName)

@@ -3,11 +3,11 @@ package proxy
 import (
 	"context"
 	"errors"
+	"fmt"

 	qnClient "github.com/milvus-io/milvus/internal/distributed/querynode/client"

 	"github.com/milvus-io/milvus/internal/log"
-	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/types"

 	"go.uber.org/zap"

@@ -15,7 +15,7 @@ import (

 type getQueryNodePolicy func(context.Context, string) (types.QueryNode, error)

-type pickShardPolicy func(ctx context.Context, policy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders *querypb.ShardLeadersList) error
+type pickShardPolicy func(ctx context.Context, policy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders []queryNode) error

 // TODO add another policy to enbale the use of cache
 // defaultGetQueryNodePolicy creates QueryNode client for every address everytime

@@ -40,23 +40,45 @@ var (
 	errInvalidShardLeaders = errors.New("Invalid shard leader")
 )

-func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders *querypb.ShardLeadersList) error {
+type queryNode struct {
+	nodeID  UniqueID
+	address string
+}
+
+func (q queryNode) String() string {
+	return fmt.Sprintf("<NodeID: %d>", q.nodeID)
+}
+
+func updateShardsWithRoundRobin(shardsLeaders map[string][]queryNode) map[string][]queryNode {
+
+	for channelID, leaders := range shardsLeaders {
+		if len(leaders) <= 1 {
+			continue
+		}
+
+		shardsLeaders[channelID] = append(leaders[1:], leaders[0])
+	}
+
+	return shardsLeaders
+}
+
+func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders []queryNode) error {
 	var (
 		err     = errBegin
 		current = 0
 		qn      types.QueryNode
 	)
-	replicaNum := len(leaders.GetNodeIds())
+	replicaNum := len(leaders)

 	for err != nil && current < replicaNum {
-		currentID := leaders.GetNodeIds()[current]
+		currentID := leaders[current].nodeID
 		if err != errBegin {
 			log.Warn("retry with another QueryNode",
 				zap.Int("retries numbers", current),
-				zap.String("leader", leaders.GetChannelName()), zap.Int64("nodeID", currentID))
+				zap.Int64("nodeID", currentID))
 		}

-		qn, err = getQueryNodePolicy(ctx, leaders.GetNodeAddrs()[current])
+		qn, err = getQueryNodePolicy(ctx, leaders[current].address)
 		if err != nil {
 			log.Warn("fail to get valid QueryNode", zap.Int64("nodeID", currentID),
 				zap.Error(err))

@@ -68,7 +90,6 @@ func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy
 		err = query(currentID, qn)
 		if err != nil {
 			log.Warn("fail to Query with shard leader",
-				zap.String("leader", leaders.GetChannelName()),
 				zap.Int64("nodeID", currentID),
 				zap.Error(err))
 		}

@@ -76,9 +97,8 @@ func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy
 	}

 	if current == replicaNum && err != nil {
-		log.Warn("no shard leaders available for channel",
-			zap.String("channel name", leaders.GetChannelName()),
-			zap.Int64s("leaders", leaders.GetNodeIds()), zap.Error(err))
+		log.Warn("no shard leaders available",
+			zap.String("leaders", fmt.Sprintf("%v", leaders)), zap.Error(err))
 		// needs to return the error from query
 		return err
 	}

@@ -5,11 +5,53 @@ import (
 	"fmt"
 	"testing"

-	"github.com/milvus-io/milvus/internal/proto/querypb"
+	"github.com/milvus-io/milvus/internal/log"
 	"github.com/milvus-io/milvus/internal/types"

 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+
+	"go.uber.org/zap"
 )

+func TestUpdateShardsWithRoundRobin(t *testing.T) {
+	in := map[string][]queryNode{
+		"channel-1": {
+			{1, "addr1"},
+			{2, "addr2"},
+		},
+		"channel-2": {
+			{20, "addr20"},
+			{21, "addr21"},
+		},
+	}
+
+	out := updateShardsWithRoundRobin(in)
+
+	assert.Equal(t, int64(2), out["channel-1"][0].nodeID)
+	assert.Equal(t, "addr2", out["channel-1"][0].address)
+	assert.Equal(t, int64(21), out["channel-2"][0].nodeID)
+	assert.Equal(t, "addr21", out["channel-2"][0].address)
+
+	t.Run("check print", func(t *testing.T) {
+		qns := []queryNode{
+			{1, "addr1"},
+			{2, "addr2"},
+			{20, "addr20"},
+			{21, "addr21"},
+		}
+
+		res := fmt.Sprintf("list: %v", qns)
+
+		log.Debug("Check String func",
+			zap.Any("Any", qns),
+			zap.Any("ok", qns[0]),
+			zap.String("ok2", res),
+		)
+	})
+}
+
 func TestRoundRobinPolicy(t *testing.T) {
 	var (
 		getQueryNodePolicy = mockGetQueryNodePolicy

@@ -31,11 +73,12 @@
 		t.Run(test.description, func(t *testing.T) {
 			query := (&mockQuery{isvalid: false}).query

-			leaders := &querypb.ShardLeadersList{
-				ChannelName: t.Name(),
-				NodeIds:     test.leaderIDs,
-				NodeAddrs:   make([]string, len(test.leaderIDs)),
+			leaders := make([]queryNode, 0, len(test.leaderIDs))
+			for _, ID := range test.leaderIDs {
+				leaders = append(leaders, queryNode{ID, "random-addr"})
+
 			}

 			err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
 			require.Error(t, err)
 		})

@@ -55,10 +98,10 @@
 	for _, test := range allPassTests {
 		query := (&mockQuery{isvalid: true}).query
-		leaders := &querypb.ShardLeadersList{
-			ChannelName: t.Name(),
-			NodeIds:     test.leaderIDs,
-			NodeAddrs:   make([]string, len(test.leaderIDs)),
+		leaders := make([]queryNode, 0, len(test.leaderIDs))
+		for _, ID := range test.leaderIDs {
+			leaders = append(leaders, queryNode{ID, "random-addr"})
+
 		}
 		err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
 		require.NoError(t, err)

@@ -77,10 +120,10 @@
 	for _, test := range passAtLast {
 		query := (&mockQuery{isvalid: true}).query
-		leaders := &querypb.ShardLeadersList{
-			ChannelName: t.Name(),
-			NodeIds:     test.leaderIDs,
-			NodeAddrs:   make([]string, len(test.leaderIDs)),
+		leaders := make([]queryNode, 0, len(test.leaderIDs))
+		for _, ID := range test.leaderIDs {
+			leaders = append(leaders, queryNode{ID, "random-addr"})
+
 		}
 		err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
 		require.NoError(t, err)

@@ -244,16 +244,17 @@ func (t *queryTask) Execute(ctx context.Context) error {
 	t.resultBuf = make(chan *internalpb.RetrieveResults, len(shards))
 	t.toReduceResults = make([]*internalpb.RetrieveResults, 0, len(shards))
 	t.runningGroup, t.runningGroupCtx = errgroup.WithContext(ctx)
-	for _, shard := range shards {
-		s := shard
+	for channelID, leaders := range shards {
+		channelID := channelID
+		leaders := leaders
 		t.runningGroup.Go(func() error {
 			log.Debug("proxy starting to query one shard",
 				zap.Int64("collectionID", t.CollectionID),
 				zap.String("collection name", t.collectionName),
-				zap.String("shard channel", s.GetChannelName()),
+				zap.String("shard channel", channelID),
 				zap.Uint64("timeoutTs", t.TimeoutTimestamp))

-			err := t.queryShard(t.runningGroupCtx, s)
+			err := t.queryShard(t.runningGroupCtx, leaders, channelID)
 			if err != nil {
 				return err
 			}

@@ -344,12 +345,12 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
 	return nil
 }

-func (t *queryTask) queryShard(ctx context.Context, leaders *querypb.ShardLeadersList) error {
+func (t *queryTask) queryShard(ctx context.Context, leaders []queryNode, channelID string) error {
 	query := func(nodeID UniqueID, qn types.QueryNode) error {
 		req := &querypb.QueryRequest{
 			Req:           t.RetrieveRequest,
 			IsShardLeader: true,
-			DmlChannel:    leaders.GetChannelName(),
+			DmlChannel:    channelID,
 		}

 		result, err := qn.Query(ctx, req)

@@ -364,14 +365,14 @@ func (t *queryTask) queryShard(ctx context.Context, leaders *querypb.ShardLeader
 			return fmt.Errorf("fail to Query, QueryNode ID = %d, reason=%s", nodeID, result.GetStatus().GetReason())
 		}

-		log.Debug("get query result", zap.Int64("nodeID", nodeID), zap.String("channelID", leaders.GetChannelName()))
+		log.Debug("get query result", zap.Int64("nodeID", nodeID), zap.String("channelID", channelID))
 		t.resultBuf <- result
 		return nil
 	}

 	err := t.queryShardPolicy(t.TraceCtx(), t.getQueryNodePolicy, query, leaders)
 	if err != nil {
-		log.Warn("fail to Query to all shard leaders", zap.Int64("taskID", t.ID()), zap.Any("shard leaders", leaders.GetNodeIds()))
+		log.Warn("fail to Query to all shard leaders", zap.Int64("taskID", t.ID()), zap.Any("shard leaders", leaders))
 		return err
 	}

@@ -258,27 +258,28 @@ func (t *searchTask) Execute(ctx context.Context) error {
 	defer tr.Elapse("done")

 	executeSearch := func(withCache bool) error {
-		shards, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName, t.qc)
+		shard2Leaders, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName, t.qc)
 		if err != nil {
 			return err
 		}

-		t.resultBuf = make(chan *internalpb.SearchResults, len(shards))
-		t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shards))
+		t.resultBuf = make(chan *internalpb.SearchResults, len(shard2Leaders))
+		t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shard2Leaders))
 		t.runningGroup, t.runningGroupCtx = errgroup.WithContext(ctx)

 		// TODO: try to merge rpc send to different shard leaders.
 		// If two shard leader is on the same querynode maybe we should merge request to save rpc
-		for _, shard := range shards {
-			s := shard
+		for channelID, leaders := range shard2Leaders {
+			channelID := channelID
+			leaders := leaders
 			t.runningGroup.Go(func() error {
 				log.Debug("proxy starting to query one shard",
 					zap.Int64("collectionID", t.CollectionID),
 					zap.String("collection name", t.collectionName),
-					zap.String("shard channel", s.GetChannelName()),
+					zap.String("shard channel", channelID),
 					zap.Uint64("timeoutTs", t.TimeoutTimestamp))

-				err := t.searchShard(t.runningGroupCtx, s)
+				err := t.searchShard(t.runningGroupCtx, leaders, channelID)
 				if err != nil {
 					return err
 				}

@@ -393,13 +394,13 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
 	return nil
 }

-func (t *searchTask) searchShard(ctx context.Context, leaders *querypb.ShardLeadersList) error {
+func (t *searchTask) searchShard(ctx context.Context, leaders []queryNode, channelID string) error {

 	search := func(nodeID UniqueID, qn types.QueryNode) error {
 		req := &querypb.SearchRequest{
 			Req:           t.SearchRequest,
 			IsShardLeader: true,
-			DmlChannel:    leaders.GetChannelName(),
+			DmlChannel:    channelID,
 		}

 		result, err := qn.Search(ctx, req)

@@ -420,7 +421,7 @@ func (t *searchTask) searchShard(ctx context.Context, leaders *querypb.ShardLead
 	err := t.searchShardPolicy(t.TraceCtx(), t.getQueryNodePolicy, search, leaders)
 	if err != nil {
-		log.Warn("fail to search to all shard leaders", zap.Any("shard leaders", leaders.GetNodeIds()))
+		log.Warn("fail to search to all shard leaders", zap.Any("shard leaders", leaders))
 		return err
 	}

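To close, a second hedged sketch of how the Execute paths above consume the new GetShards result: one errgroup goroutine per DML channel, each handing that channel's leader list to the per-shard query function, which in turn applies the pick-shard policy. This is a condensed, hypothetical rendering of queryTask.Execute / searchTask.Execute, not the exact task code; fanOut and queryOneShard are illustrative names.

package proxysketch

import (
	"context"

	"golang.org/x/sync/errgroup"
)

// queryNode is the shard-leader descriptor introduced by the commit (ID + address).
type queryNode struct {
	nodeID  int64
	address string
}

// fanOut launches one goroutine per DML channel and hands each channel's leader
// list to a per-shard query function (which would apply a policy like roundRobinPolicy).
func fanOut(ctx context.Context, shard2Leaders map[string][]queryNode,
	queryOneShard func(ctx context.Context, channelID string, leaders []queryNode) error) error {

	g, gctx := errgroup.WithContext(ctx)
	for channelID, leaders := range shard2Leaders {
		channelID, leaders := channelID, leaders // capture loop variables for the closure
		g.Go(func() error {
			return queryOneShard(gctx, channelID, leaders)
		})
	}
	// Wait returns the first non-nil error; a caller could then clear the shard
	// cache and retry with withCache == false.
	return g.Wait()
}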