diff --git a/pkg/mq/msgdispatcher/client_test.go b/pkg/mq/msgdispatcher/client_test.go index 0cfb69fb89..8eafe90c59 100644 --- a/pkg/mq/msgdispatcher/client_test.go +++ b/pkg/mq/msgdispatcher/client_test.go @@ -82,3 +82,40 @@ func TestClient_Concurrency(t *testing.T) { n := c.managers.Len() assert.Equal(t, expected, n) } + +func TestClientMainDispatcherLeak(t *testing.T) { + client := NewClient(newMockFactory(), typeutil.ProxyRole, 1) + assert.NotNil(t, client) + pchannel := "mock_vchannel_0" + + vchannel1 := fmt.Sprintf("%s_abc_v0", pchannel) //"mock_vchannel_0_abc_v0" + vchannel2 := fmt.Sprintf("%s_abc_v1", pchannel) //"mock_vchannel_0_abc_v0" + _, err := client.Register(context.Background(), NewStreamConfig(vchannel1, nil, common.SubscriptionPositionUnknown)) + assert.NoError(t, err) + + _, err = client.Register(context.Background(), NewStreamConfig(vchannel2, nil, common.SubscriptionPositionUnknown)) + assert.NoError(t, err) + + client.Deregister(vchannel2) + client.Deregister(vchannel1) + + assert.NotPanics( + t, func() { + _, err = client.Register(context.Background(), NewStreamConfig(vchannel1, nil, common.SubscriptionPositionUnknown)) + assert.NoError(t, err) + _, err = client.Register(context.Background(), NewStreamConfig(vchannel2, nil, common.SubscriptionPositionUnknown)) + assert.NoError(t, err) + }, + ) + + client.Deregister(vchannel1) + client.Deregister(vchannel2) + assert.NotPanics( + t, func() { + _, err = client.Register(context.Background(), NewStreamConfig(vchannel1, nil, common.SubscriptionPositionUnknown)) + assert.NoError(t, err) + _, err = client.Register(context.Background(), NewStreamConfig(vchannel2, nil, common.SubscriptionPositionUnknown)) + assert.NoError(t, err) + }, + ) +} diff --git a/pkg/mq/msgdispatcher/manager.go b/pkg/mq/msgdispatcher/manager.go index 33589b14e8..e0d64c34d0 100644 --- a/pkg/mq/msgdispatcher/manager.go +++ b/pkg/mq/msgdispatcher/manager.go @@ -126,16 +126,6 @@ func (c *dispatcherManager) Remove(vchannel string) { zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel)) c.mu.Lock() defer c.mu.Unlock() - if c.mainDispatcher != nil { - c.mainDispatcher.Handle(pause) - c.mainDispatcher.CloseTarget(vchannel) - if c.mainDispatcher.TargetNum() == 0 && len(c.soloDispatchers) == 0 { - c.mainDispatcher.Handle(terminate) - c.mainDispatcher = nil - } else { - c.mainDispatcher.Handle(resume) - } - } if _, ok := c.soloDispatchers[vchannel]; ok { c.soloDispatchers[vchannel].Handle(terminate) c.soloDispatchers[vchannel].CloseTarget(vchannel) @@ -144,6 +134,18 @@ func (c *dispatcherManager) Remove(vchannel string) { log.Info("remove soloDispatcher done") } c.lagTargets.GetAndRemove(vchannel) + + if c.mainDispatcher != nil { + c.mainDispatcher.Handle(pause) + c.mainDispatcher.CloseTarget(vchannel) + if c.mainDispatcher.TargetNum() == 0 && len(c.soloDispatchers) == 0 { + c.mainDispatcher.Handle(terminate) + c.mainDispatcher = nil + log.Info("remove mainDispatcher done") + } else { + c.mainDispatcher.Handle(resume) + } + } } func (c *dispatcherManager) NumTarget() int {