Update Seek interface (#5492)

* update Seek

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* update Seek for mqTtMsgStream

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* use Retry in Seek

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* fix static-check

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/5498/head
Cai Yudong 2021-05-29 23:21:34 +08:00 committed by GitHub
parent 67c4c915b7
commit b414800d49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 73 additions and 80 deletions

View File

@ -347,7 +347,7 @@ func (s *Server) startStatsChannel(ctx context.Context) {
// try to restore last processed pos
pos, err := s.loadStreamLastPos(streamTypeStats)
if err == nil {
err = statsStream.Seek(pos)
err = statsStream.Seek([]*internalpb.MsgPosition{pos})
if err != nil {
log.Error("Failed to seek to last pos for statsStream",
zap.String("StatisChanName", Params.StatisticsChannelName),
@ -403,7 +403,7 @@ func (s *Server) startSegmentFlushChannel(ctx context.Context) {
// try to restore last processed pos
pos, err := s.loadStreamLastPos(streamTypeFlush)
if err == nil {
err = flushStream.Seek(pos)
err = flushStream.Seek([]*internalpb.MsgPosition{pos})
if err != nil {
log.Error("Failed to seek to last pos for segment flush Stream",
zap.String("SegInfoChannelName", Params.SegmentInfoChannelName),

View File

@ -203,6 +203,6 @@ func (mms *MemMsgStream) Chan() <-chan *MsgPack {
return mms.receiveBuf
}
func (mms *MemMsgStream) Seek(offset *MsgPosition) error {
func (mms *MemMsgStream) Seek(offset []*MsgPosition) error {
return errors.New("MemMsgStream seek not implemented")
}

View File

@ -14,6 +14,7 @@ package msgstream
import (
"context"
"errors"
"fmt"
"path/filepath"
"sync"
"time"
@ -351,9 +352,12 @@ func (ms *mqMsgStream) Chan() <-chan *MsgPack {
return ms.receiveBuf
}
func (ms *mqMsgStream) Seek(mp *internalpb.MsgPosition) error {
if _, ok := ms.consumers[mp.ChannelName]; ok {
consumer := ms.consumers[mp.ChannelName]
func (ms *mqMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error {
for _, mp := range msgPositions {
consumer, ok := ms.consumers[mp.ChannelName]
if !ok {
return fmt.Errorf("channel %s not subscribed", mp.ChannelName)
}
messageID, err := ms.client.BytesToMsgID(mp.MsgID)
if err != nil {
return err
@ -362,10 +366,8 @@ func (ms *mqMsgStream) Seek(mp *internalpb.MsgPosition) error {
if err != nil {
return err
}
return nil
}
return errors.New("msgStream seek fail")
return nil
}
type MqTtMsgStream struct {
@ -661,28 +663,20 @@ func checkTimeTickMsg(msg map[mqclient.Consumer]Timestamp,
return 0, false
}
func (ms *MqTtMsgStream) Seek(mp *internalpb.MsgPosition) error {
if len(mp.MsgID) == 0 {
return errors.New("when msgID's length equal to 0, please use AsConsumer interface")
}
// Seek to the specified position
func (ms *MqTtMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error {
var consumer mqclient.Consumer
var mp *MsgPosition
var err error
var hasWatched bool
seekChannel := mp.ChannelName
subName := mp.MsgGroup
ms.consumerLock.Lock()
defer ms.consumerLock.Unlock()
consumer, hasWatched = ms.consumers[seekChannel]
if hasWatched {
return errors.New("the channel should has not been subscribed")
}
fn := func() error {
if _, ok := ms.consumers[mp.ChannelName]; ok {
return fmt.Errorf("the channel should not been subscribed")
}
receiveChannel := make(chan mqclient.ConsumerMessage, ms.bufSize)
consumer, err = ms.client.Subscribe(mqclient.ConsumerOptions{
Topic: seekChannel,
SubscriptionName: subName,
Topic: mp.ChannelName,
SubscriptionName: mp.MsgGroup,
SubscriptionInitialPosition: mqclient.SubscriptionPositionEarliest,
Type: mqclient.KeyShared,
MessageChannel: receiveChannel,
@ -691,70 +685,74 @@ func (ms *MqTtMsgStream) Seek(mp *internalpb.MsgPosition) error {
return err
}
if consumer == nil {
err = errors.New("consumer is nil")
log.Debug("subscribe error", zap.String("error = ", err.Error()))
return err
return fmt.Errorf("consumer is nil")
}
seekMsgID, err := ms.client.BytesToMsgID(mp.MsgID)
if err != nil {
log.Debug("convert messageID error", zap.String("error = ", err.Error()))
return err
}
err = consumer.Seek(seekMsgID)
if err != nil {
log.Debug("seek error ", zap.String("error = ", err.Error()))
return err
}
return nil
}
err = Retry(20, time.Millisecond*200, fn)
if err != nil {
errMsg := "Failed to seek, error = " + err.Error()
panic(errMsg)
}
ms.addConsumer(consumer, seekChannel)
//TODO: May cause problem
//if len(consumer.Chan()) == 0 {
// return nil
//}
ms.consumerLock.Lock()
defer ms.consumerLock.Unlock()
for {
select {
case <-ms.ctx.Done():
return nil
case msg, ok := <-consumer.Chan():
if !ok {
return errors.New("consumer closed")
}
consumer.Ack(msg)
for idx := range msgPositions {
mp = msgPositions[idx]
if len(mp.MsgID) == 0 {
return fmt.Errorf("when msgID's length equal to 0, please use AsConsumer interface")
}
headerMsg := commonpb.MsgHeader{}
err := proto.Unmarshal(msg.Payload(), &headerMsg)
if err != nil {
log.Error("Failed to unmarshal message header", zap.Error(err))
}
tsMsg, err := ms.unmarshal.Unmarshal(msg.Payload(), headerMsg.Base.MsgType)
if err != nil {
log.Error("Failed to unmarshal tsMsg", zap.Error(err))
}
if tsMsg.Type() == commonpb.MsgType_TimeTick {
if tsMsg.BeginTs() >= mp.Timestamp {
return nil
if err = Retry(20, time.Millisecond*200, fn); err != nil {
return fmt.Errorf("Failed to seek, error %s", err.Error())
}
ms.addConsumer(consumer, mp.ChannelName)
//TODO: May cause problem
//if len(consumer.Chan()) == 0 {
// return nil
//}
runLoop := true
for runLoop {
select {
case <-ms.ctx.Done():
return nil
case msg, ok := <-consumer.Chan():
if !ok {
return fmt.Errorf("consumer closed")
}
consumer.Ack(msg)
headerMsg := commonpb.MsgHeader{}
err := proto.Unmarshal(msg.Payload(), &headerMsg)
if err != nil {
return fmt.Errorf("Failed to unmarshal message header, err %s", err.Error())
}
tsMsg, err := ms.unmarshal.Unmarshal(msg.Payload(), headerMsg.Base.MsgType)
if err != nil {
return fmt.Errorf("Failed to unmarshal tsMsg, err %s", err.Error())
}
if tsMsg.Type() == commonpb.MsgType_TimeTick && tsMsg.BeginTs() >= mp.Timestamp {
runLoop = false
break
} else if tsMsg.BeginTs() > mp.Timestamp {
tsMsg.SetPosition(&MsgPosition{
ChannelName: filepath.Base(msg.Topic()),
MsgID: msg.ID().Serialize(),
})
ms.unsolvedBuf[consumer] = append(ms.unsolvedBuf[consumer], tsMsg)
}
continue
}
if tsMsg.BeginTs() > mp.Timestamp {
tsMsg.SetPosition(&MsgPosition{
ChannelName: filepath.Base(msg.Topic()),
MsgID: msg.ID().Serialize(),
})
ms.unsolvedBuf[consumer] = append(ms.unsolvedBuf[consumer], tsMsg)
}
}
}
return nil
}
//TODO test InMemMsgStream

View File

@ -246,13 +246,8 @@ func getPulsarTtOutputStreamAndSeek(pulsarAddress string, positions []*MsgPositi
factory := ProtoUDFactory{}
pulsarClient, _ := mqclient.NewPulsarClient(pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqTtMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
//outputStream.AsConsumer(consumerChannels, consumerSubName)
for _, pos := range positions {
pos.MsgGroup = funcutil.RandomString(4)
outputStream.Seek(pos)
}
outputStream.Seek(positions)
outputStream.Start()
//outputStream.Start()
return outputStream
}

View File

@ -45,7 +45,7 @@ type MsgStream interface {
Produce(*MsgPack) error
Broadcast(*MsgPack) error
Consume() *MsgPack
Seek(offset *MsgPosition) error
Seek(offset []*MsgPosition) error
}
type Factory interface {

View File

@ -90,7 +90,7 @@ func (ms *SimpleMsgStream) Consume() *MsgPack {
return <-ms.msgChan
}
func (ms *SimpleMsgStream) Seek(offset *MsgPosition) error {
func (ms *SimpleMsgStream) Seek(offset []*MsgPosition) error {
return nil
}

View File

@ -131,7 +131,7 @@ func (dsService *dataSyncService) initNodes() {
}
func (dsService *dataSyncService) seekSegment(position *internalpb.MsgPosition) error {
err := dsService.dmStream.Seek(position)
err := dsService.dmStream.Seek([]*internalpb.MsgPosition{position})
if err != nil {
return err
}

View File

@ -157,7 +157,7 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) error {
ds.dmStream.AsConsumer(toDirSubChannels, consumeSubName)
for _, pos := range toSeekInfo {
err := ds.dmStream.Seek(pos)
err := ds.dmStream.Seek([]*internalpb.MsgPosition{pos})
if err != nil {
errMsg := "msgStream seek error :" + err.Error()
log.Error(errMsg)