mirror of https://github.com/milvus-io/milvus.git
Fix shardLeader cache concurrent access (#17120)
Fix write map without mutex control Also GetShards returns a copy of leader list instead of original one Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>pull/17129/head
parent
2f8a7e7793
commit
90ee23df22
|
@ -72,11 +72,24 @@ type collectionInfo struct {
|
|||
schema *schemapb.CollectionSchema
|
||||
partInfo map[string]*partitionInfo
|
||||
shardLeaders map[string][]queryNode
|
||||
leaderMutex sync.Mutex
|
||||
createdTimestamp uint64
|
||||
createdUtcTimestamp uint64
|
||||
isLoaded bool
|
||||
}
|
||||
|
||||
// CloneShardLeaders returns a copy of shard leaders
|
||||
// leaderMutex shall be accuired before invoking this method
|
||||
func (c *collectionInfo) CloneShardLeaders() map[string][]queryNode {
|
||||
m := make(map[string][]queryNode)
|
||||
for channel, leaders := range c.shardLeaders {
|
||||
l := make([]queryNode, len(leaders))
|
||||
copy(l, leaders)
|
||||
m[channel] = l
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
type partitionInfo struct {
|
||||
partitionID typeutil.UniqueID
|
||||
createdTimestamp uint64
|
||||
|
@ -584,11 +597,11 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
|
|||
|
||||
if withCache {
|
||||
if len(info.shardLeaders) > 0 {
|
||||
shards := updateShardsWithRoundRobin(info.shardLeaders)
|
||||
info.leaderMutex.Lock()
|
||||
updateShardsWithRoundRobin(info.shardLeaders)
|
||||
|
||||
m.mu.Lock()
|
||||
m.collInfo[collectionName].shardLeaders = shards
|
||||
m.mu.Unlock()
|
||||
shards := info.CloneShardLeaders()
|
||||
info.leaderMutex.Unlock()
|
||||
return shards, nil
|
||||
}
|
||||
log.Info("no shard cache for collection, try to get shard leaders from QueryCoord",
|
||||
|
@ -612,13 +625,16 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
|
|||
|
||||
shards := parseShardLeaderList2QueryNode(resp.GetShards())
|
||||
|
||||
shards = updateShardsWithRoundRobin(shards)
|
||||
// manipulate info in map, get map returns a copy of the information
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
info = m.collInfo[collectionName]
|
||||
// lock leader
|
||||
info.leaderMutex.Lock()
|
||||
defer info.leaderMutex.Unlock()
|
||||
info.shardLeaders = shards
|
||||
|
||||
m.mu.Lock()
|
||||
m.collInfo[collectionName].shardLeaders = shards
|
||||
m.mu.Unlock()
|
||||
|
||||
return shards, nil
|
||||
return info.CloneShardLeaders(), nil
|
||||
}
|
||||
|
||||
func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]queryNode {
|
||||
|
|
|
@ -49,8 +49,7 @@ func (q queryNode) String() string {
|
|||
return fmt.Sprintf("<NodeID: %d>", q.nodeID)
|
||||
}
|
||||
|
||||
func updateShardsWithRoundRobin(shardsLeaders map[string][]queryNode) map[string][]queryNode {
|
||||
|
||||
func updateShardsWithRoundRobin(shardsLeaders map[string][]queryNode) {
|
||||
for channelID, leaders := range shardsLeaders {
|
||||
if len(leaders) <= 1 {
|
||||
continue
|
||||
|
@ -58,8 +57,6 @@ func updateShardsWithRoundRobin(shardsLeaders map[string][]queryNode) map[string
|
|||
|
||||
shardsLeaders[channelID] = append(leaders[1:], leaders[0])
|
||||
}
|
||||
|
||||
return shardsLeaders
|
||||
}
|
||||
|
||||
func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders []queryNode) error {
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
)
|
||||
|
||||
func TestUpdateShardsWithRoundRobin(t *testing.T) {
|
||||
in := map[string][]queryNode{
|
||||
list := map[string][]queryNode{
|
||||
"channel-1": {
|
||||
{1, "addr1"},
|
||||
{2, "addr2"},
|
||||
|
@ -26,12 +26,12 @@ func TestUpdateShardsWithRoundRobin(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
out := updateShardsWithRoundRobin(in)
|
||||
updateShardsWithRoundRobin(list)
|
||||
|
||||
assert.Equal(t, int64(2), out["channel-1"][0].nodeID)
|
||||
assert.Equal(t, "addr2", out["channel-1"][0].address)
|
||||
assert.Equal(t, int64(21), out["channel-2"][0].nodeID)
|
||||
assert.Equal(t, "addr21", out["channel-2"][0].address)
|
||||
assert.Equal(t, int64(2), list["channel-1"][0].nodeID)
|
||||
assert.Equal(t, "addr2", list["channel-1"][0].address)
|
||||
assert.Equal(t, int64(21), list["channel-2"][0].nodeID)
|
||||
assert.Equal(t, "addr21", list["channel-2"][0].address)
|
||||
|
||||
t.Run("check print", func(t *testing.T) {
|
||||
qns := []queryNode{
|
||||
|
|
Loading…
Reference in New Issue