diff --git a/internal/mq/msgstream/mq_msgstream.go b/internal/mq/msgstream/mq_msgstream.go index 89971801f7..709e0bc968 100644 --- a/internal/mq/msgstream/mq_msgstream.go +++ b/internal/mq/msgstream/mq_msgstream.go @@ -55,7 +55,6 @@ type mqMsgStream struct { bufSize int64 producerLock *sync.Mutex consumerLock *sync.Mutex - readerLock *sync.Mutex closed int32 onceChan sync.Once } @@ -88,7 +87,6 @@ func NewMqMsgStream(ctx context.Context, streamCancel: streamCancel, producerLock: &sync.Mutex{}, consumerLock: &sync.Mutex{}, - readerLock: &sync.Mutex{}, wait: &sync.WaitGroup{}, closed: 0, } @@ -185,11 +183,8 @@ func (ms *mqMsgStream) Start() { } func (ms *mqMsgStream) Close() { - ms.streamCancel() - ms.readerLock.Lock() ms.wait.Wait() - ms.readerLock.Unlock() for _, producer := range ms.producers { if producer != nil { @@ -515,7 +510,11 @@ func (ms *mqMsgStream) receiveMsg(consumer mqwrapper.Consumer) { StartPositions: []*internalpb.MsgPosition{tsMsg.Position()}, EndPositions: []*internalpb.MsgPosition{tsMsg.Position()}, } - ms.receiveBuf <- &msgPack + select { + case ms.receiveBuf <- &msgPack: + case <-ms.ctx.Done(): + return + } sp.Finish() } @@ -525,9 +524,7 @@ func (ms *mqMsgStream) receiveMsg(consumer mqwrapper.Consumer) { func (ms *mqMsgStream) Chan() <-chan *MsgPack { ms.onceChan.Do(func() { for _, c := range ms.consumers { - ms.readerLock.Lock() ms.wait.Add(1) - ms.readerLock.Unlock() go ms.receiveMsg(c) } }) @@ -760,7 +757,11 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() { } //log.Debug("send msg pack", zap.Int("len", len(msgPack.Msgs)), zap.Uint64("currTs", currTs)) - ms.receiveBuf <- &msgPack + select { + case ms.receiveBuf <- &msgPack: + case <-ms.ctx.Done(): + return + } ms.lastTimeStamp = currTs } } @@ -925,9 +926,7 @@ func (ms *MqTtMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error { func (ms *MqTtMsgStream) Chan() <-chan *MsgPack { ms.onceChan.Do(func() { if ms.consumers != nil { - ms.readerLock.Lock() ms.wait.Add(1) - ms.readerLock.Unlock() go ms.bufMsgPackToChannel() } })