From 0817802db8f45c95c934bdb8ccdb7939ef5b9787 Mon Sep 17 00:00:00 2001 From: SimFG Date: Fri, 5 Jul 2024 15:11:31 +0800 Subject: [PATCH] enhance: use the key lock and concurrent map in the msg dispatcher client (#34278) /kind improvement Signed-off-by: SimFG --- pkg/mq/msgdispatcher/client.go | 46 ++++++++++++++++------------- pkg/mq/msgdispatcher/client_test.go | 4 +-- 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/pkg/mq/msgdispatcher/client.go b/pkg/mq/msgdispatcher/client.go index 0f6effc932..95e5ff1730 100644 --- a/pkg/mq/msgdispatcher/client.go +++ b/pkg/mq/msgdispatcher/client.go @@ -18,7 +18,6 @@ package msgdispatcher import ( "context" - "sync" "go.uber.org/zap" @@ -27,6 +26,8 @@ import ( "github.com/milvus-io/milvus/pkg/mq/common" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/lock" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type ( @@ -46,17 +47,18 @@ var _ Client = (*client)(nil) type client struct { role string nodeID int64 - managers map[string]DispatcherManager - managerMut sync.Mutex + managers *typeutil.ConcurrentMap[string, DispatcherManager] + managerMut *lock.KeyLock[string] factory msgstream.Factory } func NewClient(factory msgstream.Factory, role string, nodeID int64) Client { return &client{ - role: role, - nodeID: nodeID, - factory: factory, - managers: make(map[string]DispatcherManager), + role: role, + nodeID: nodeID, + factory: factory, + managers: typeutil.NewConcurrentMap[string, DispatcherManager](), + managerMut: lock.NewKeyLock[string](), } } @@ -64,20 +66,20 @@ func (c *client) Register(ctx context.Context, vchannel string, pos *Pos, subPos log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel)) pchannel := funcutil.ToPhysicalChannel(vchannel) - c.managerMut.Lock() - defer c.managerMut.Unlock() + c.managerMut.Lock(pchannel) + defer c.managerMut.Unlock(pchannel) var manager DispatcherManager - manager, ok := c.managers[pchannel] + manager, ok := c.managers.Get(pchannel) if !ok { manager = NewDispatcherManager(pchannel, c.role, c.nodeID, c.factory) - c.managers[pchannel] = manager + c.managers.Insert(pchannel, manager) go manager.Run() } ch, err := manager.Add(ctx, vchannel, pos, subPos) if err != nil { if manager.Num() == 0 { manager.Close() - delete(c.managers, pchannel) + c.managers.Remove(pchannel) } log.Error("register failed", zap.Error(err)) return nil, err @@ -88,13 +90,13 @@ func (c *client) Register(ctx context.Context, vchannel string, pos *Pos, subPos func (c *client) Deregister(vchannel string) { pchannel := funcutil.ToPhysicalChannel(vchannel) - c.managerMut.Lock() - defer c.managerMut.Unlock() - if manager, ok := c.managers[pchannel]; ok { + c.managerMut.Lock(pchannel) + defer c.managerMut.Unlock(pchannel) + if manager, ok := c.managers.Get(pchannel); ok { manager.Remove(vchannel) if manager.Num() == 0 { manager.Close() - delete(c.managers, pchannel) + c.managers.Remove(pchannel) } log.Info("deregister done", zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel)) @@ -104,12 +106,14 @@ func (c *client) Deregister(vchannel string) { func (c *client) Close() { log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID)) - c.managerMut.Lock() - defer c.managerMut.Unlock() - for pchannel, manager := range c.managers { + + c.managers.Range(func(pchannel string, manager DispatcherManager) bool { + c.managerMut.Lock(pchannel) + defer c.managerMut.Unlock(pchannel) log.Info("close manager", zap.String("channel", pchannel)) - delete(c.managers, pchannel) + c.managers.Remove(pchannel) manager.Close() - } + return true + }) log.Info("dispatcher client closed") } diff --git a/pkg/mq/msgdispatcher/client_test.go b/pkg/mq/msgdispatcher/client_test.go index 255de91969..707e0becfd 100644 --- a/pkg/mq/msgdispatcher/client_test.go +++ b/pkg/mq/msgdispatcher/client_test.go @@ -79,8 +79,6 @@ func TestClient_Concurrency(t *testing.T) { expected := int(total - deregisterCount.Load()) c := client1.(*client) - c.managerMut.Lock() - n := len(c.managers) - c.managerMut.Unlock() + n := c.managers.Len() assert.Equal(t, expected, n) }