From 90ee23df22da9a452aa34c074a54827193c3561f Mon Sep 17 00:00:00 2001 From: congqixia Date: Fri, 20 May 2022 14:33:57 +0800 Subject: [PATCH] Fix shardLeader cache concurrent access (#17120) Fix write map without mutex control Also GetShards returns a copy of leader list instead of original one Signed-off-by: Congqi Xia --- internal/proxy/meta_cache.go | 36 ++++++++++++++++++++-------- internal/proxy/task_policies.go | 5 +--- internal/proxy/task_policies_test.go | 12 +++++----- 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 033ef8abcd..4cf941efef 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -72,11 +72,24 @@ type collectionInfo struct { schema *schemapb.CollectionSchema partInfo map[string]*partitionInfo shardLeaders map[string][]queryNode + leaderMutex sync.Mutex createdTimestamp uint64 createdUtcTimestamp uint64 isLoaded bool } +// CloneShardLeaders returns a copy of shard leaders +// leaderMutex shall be accuired before invoking this method +func (c *collectionInfo) CloneShardLeaders() map[string][]queryNode { + m := make(map[string][]queryNode) + for channel, leaders := range c.shardLeaders { + l := make([]queryNode, len(leaders)) + copy(l, leaders) + m[channel] = l + } + return m +} + type partitionInfo struct { partitionID typeutil.UniqueID createdTimestamp uint64 @@ -584,11 +597,11 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam if withCache { if len(info.shardLeaders) > 0 { - shards := updateShardsWithRoundRobin(info.shardLeaders) + info.leaderMutex.Lock() + updateShardsWithRoundRobin(info.shardLeaders) - m.mu.Lock() - m.collInfo[collectionName].shardLeaders = shards - m.mu.Unlock() + shards := info.CloneShardLeaders() + info.leaderMutex.Unlock() return shards, nil } log.Info("no shard cache for collection, try to get shard leaders from QueryCoord", @@ -612,13 +625,16 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam shards := parseShardLeaderList2QueryNode(resp.GetShards()) - shards = updateShardsWithRoundRobin(shards) + // manipulate info in map, get map returns a copy of the information + m.mu.RLock() + defer m.mu.RUnlock() + info = m.collInfo[collectionName] + // lock leader + info.leaderMutex.Lock() + defer info.leaderMutex.Unlock() + info.shardLeaders = shards - m.mu.Lock() - m.collInfo[collectionName].shardLeaders = shards - m.mu.Unlock() - - return shards, nil + return info.CloneShardLeaders(), nil } func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]queryNode { diff --git a/internal/proxy/task_policies.go b/internal/proxy/task_policies.go index f293adf602..68eadf63ad 100644 --- a/internal/proxy/task_policies.go +++ b/internal/proxy/task_policies.go @@ -49,8 +49,7 @@ func (q queryNode) String() string { return fmt.Sprintf("", q.nodeID) } -func updateShardsWithRoundRobin(shardsLeaders map[string][]queryNode) map[string][]queryNode { - +func updateShardsWithRoundRobin(shardsLeaders map[string][]queryNode) { for channelID, leaders := range shardsLeaders { if len(leaders) <= 1 { continue @@ -58,8 +57,6 @@ func updateShardsWithRoundRobin(shardsLeaders map[string][]queryNode) map[string shardsLeaders[channelID] = append(leaders[1:], leaders[0]) } - - return shardsLeaders } func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders []queryNode) error { diff --git a/internal/proxy/task_policies_test.go b/internal/proxy/task_policies_test.go index 568da81936..2be147cc62 100644 --- a/internal/proxy/task_policies_test.go +++ b/internal/proxy/task_policies_test.go @@ -15,7 +15,7 @@ import ( ) func TestUpdateShardsWithRoundRobin(t *testing.T) { - in := map[string][]queryNode{ + list := map[string][]queryNode{ "channel-1": { {1, "addr1"}, {2, "addr2"}, @@ -26,12 +26,12 @@ func TestUpdateShardsWithRoundRobin(t *testing.T) { }, } - out := updateShardsWithRoundRobin(in) + updateShardsWithRoundRobin(list) - assert.Equal(t, int64(2), out["channel-1"][0].nodeID) - assert.Equal(t, "addr2", out["channel-1"][0].address) - assert.Equal(t, int64(21), out["channel-2"][0].nodeID) - assert.Equal(t, "addr21", out["channel-2"][0].address) + assert.Equal(t, int64(2), list["channel-1"][0].nodeID) + assert.Equal(t, "addr2", list["channel-1"][0].address) + assert.Equal(t, int64(21), list["channel-2"][0].nodeID) + assert.Equal(t, "addr21", list["channel-2"][0].address) t.Run("check print", func(t *testing.T) { qns := []queryNode{