mirror of https://github.com/milvus-io/milvus.git
Update Seek interface (#5492)
* update Seek Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * update Seek for mqTtMsgStream Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * use Retry in Seek Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix static-check Signed-off-by: yudong.cai <yudong.cai@zilliz.com>pull/5498/head
parent
67c4c915b7
commit
b414800d49
|
@ -347,7 +347,7 @@ func (s *Server) startStatsChannel(ctx context.Context) {
|
|||
// try to restore last processed pos
|
||||
pos, err := s.loadStreamLastPos(streamTypeStats)
|
||||
if err == nil {
|
||||
err = statsStream.Seek(pos)
|
||||
err = statsStream.Seek([]*internalpb.MsgPosition{pos})
|
||||
if err != nil {
|
||||
log.Error("Failed to seek to last pos for statsStream",
|
||||
zap.String("StatisChanName", Params.StatisticsChannelName),
|
||||
|
@ -403,7 +403,7 @@ func (s *Server) startSegmentFlushChannel(ctx context.Context) {
|
|||
// try to restore last processed pos
|
||||
pos, err := s.loadStreamLastPos(streamTypeFlush)
|
||||
if err == nil {
|
||||
err = flushStream.Seek(pos)
|
||||
err = flushStream.Seek([]*internalpb.MsgPosition{pos})
|
||||
if err != nil {
|
||||
log.Error("Failed to seek to last pos for segment flush Stream",
|
||||
zap.String("SegInfoChannelName", Params.SegmentInfoChannelName),
|
||||
|
|
|
@ -203,6 +203,6 @@ func (mms *MemMsgStream) Chan() <-chan *MsgPack {
|
|||
return mms.receiveBuf
|
||||
}
|
||||
|
||||
func (mms *MemMsgStream) Seek(offset *MsgPosition) error {
|
||||
func (mms *MemMsgStream) Seek(offset []*MsgPosition) error {
|
||||
return errors.New("MemMsgStream seek not implemented")
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ package msgstream
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -351,9 +352,12 @@ func (ms *mqMsgStream) Chan() <-chan *MsgPack {
|
|||
return ms.receiveBuf
|
||||
}
|
||||
|
||||
func (ms *mqMsgStream) Seek(mp *internalpb.MsgPosition) error {
|
||||
if _, ok := ms.consumers[mp.ChannelName]; ok {
|
||||
consumer := ms.consumers[mp.ChannelName]
|
||||
func (ms *mqMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error {
|
||||
for _, mp := range msgPositions {
|
||||
consumer, ok := ms.consumers[mp.ChannelName]
|
||||
if !ok {
|
||||
return fmt.Errorf("channel %s not subscribed", mp.ChannelName)
|
||||
}
|
||||
messageID, err := ms.client.BytesToMsgID(mp.MsgID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -362,10 +366,8 @@ func (ms *mqMsgStream) Seek(mp *internalpb.MsgPosition) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("msgStream seek fail")
|
||||
return nil
|
||||
}
|
||||
|
||||
type MqTtMsgStream struct {
|
||||
|
@ -661,28 +663,20 @@ func checkTimeTickMsg(msg map[mqclient.Consumer]Timestamp,
|
|||
return 0, false
|
||||
}
|
||||
|
||||
func (ms *MqTtMsgStream) Seek(mp *internalpb.MsgPosition) error {
|
||||
if len(mp.MsgID) == 0 {
|
||||
return errors.New("when msgID's length equal to 0, please use AsConsumer interface")
|
||||
}
|
||||
// Seek to the specified position
|
||||
func (ms *MqTtMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error {
|
||||
var consumer mqclient.Consumer
|
||||
var mp *MsgPosition
|
||||
var err error
|
||||
var hasWatched bool
|
||||
seekChannel := mp.ChannelName
|
||||
subName := mp.MsgGroup
|
||||
ms.consumerLock.Lock()
|
||||
defer ms.consumerLock.Unlock()
|
||||
consumer, hasWatched = ms.consumers[seekChannel]
|
||||
|
||||
if hasWatched {
|
||||
return errors.New("the channel should has not been subscribed")
|
||||
}
|
||||
|
||||
fn := func() error {
|
||||
if _, ok := ms.consumers[mp.ChannelName]; ok {
|
||||
return fmt.Errorf("the channel should not been subscribed")
|
||||
}
|
||||
|
||||
receiveChannel := make(chan mqclient.ConsumerMessage, ms.bufSize)
|
||||
consumer, err = ms.client.Subscribe(mqclient.ConsumerOptions{
|
||||
Topic: seekChannel,
|
||||
SubscriptionName: subName,
|
||||
Topic: mp.ChannelName,
|
||||
SubscriptionName: mp.MsgGroup,
|
||||
SubscriptionInitialPosition: mqclient.SubscriptionPositionEarliest,
|
||||
Type: mqclient.KeyShared,
|
||||
MessageChannel: receiveChannel,
|
||||
|
@ -691,70 +685,74 @@ func (ms *MqTtMsgStream) Seek(mp *internalpb.MsgPosition) error {
|
|||
return err
|
||||
}
|
||||
if consumer == nil {
|
||||
err = errors.New("consumer is nil")
|
||||
log.Debug("subscribe error", zap.String("error = ", err.Error()))
|
||||
return err
|
||||
return fmt.Errorf("consumer is nil")
|
||||
}
|
||||
|
||||
seekMsgID, err := ms.client.BytesToMsgID(mp.MsgID)
|
||||
if err != nil {
|
||||
log.Debug("convert messageID error", zap.String("error = ", err.Error()))
|
||||
return err
|
||||
}
|
||||
err = consumer.Seek(seekMsgID)
|
||||
if err != nil {
|
||||
log.Debug("seek error ", zap.String("error = ", err.Error()))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
err = Retry(20, time.Millisecond*200, fn)
|
||||
if err != nil {
|
||||
errMsg := "Failed to seek, error = " + err.Error()
|
||||
panic(errMsg)
|
||||
}
|
||||
ms.addConsumer(consumer, seekChannel)
|
||||
|
||||
//TODO: May cause problem
|
||||
//if len(consumer.Chan()) == 0 {
|
||||
// return nil
|
||||
//}
|
||||
ms.consumerLock.Lock()
|
||||
defer ms.consumerLock.Unlock()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ms.ctx.Done():
|
||||
return nil
|
||||
case msg, ok := <-consumer.Chan():
|
||||
if !ok {
|
||||
return errors.New("consumer closed")
|
||||
}
|
||||
consumer.Ack(msg)
|
||||
for idx := range msgPositions {
|
||||
mp = msgPositions[idx]
|
||||
if len(mp.MsgID) == 0 {
|
||||
return fmt.Errorf("when msgID's length equal to 0, please use AsConsumer interface")
|
||||
}
|
||||
|
||||
headerMsg := commonpb.MsgHeader{}
|
||||
err := proto.Unmarshal(msg.Payload(), &headerMsg)
|
||||
if err != nil {
|
||||
log.Error("Failed to unmarshal message header", zap.Error(err))
|
||||
}
|
||||
tsMsg, err := ms.unmarshal.Unmarshal(msg.Payload(), headerMsg.Base.MsgType)
|
||||
if err != nil {
|
||||
log.Error("Failed to unmarshal tsMsg", zap.Error(err))
|
||||
}
|
||||
if tsMsg.Type() == commonpb.MsgType_TimeTick {
|
||||
if tsMsg.BeginTs() >= mp.Timestamp {
|
||||
return nil
|
||||
if err = Retry(20, time.Millisecond*200, fn); err != nil {
|
||||
return fmt.Errorf("Failed to seek, error %s", err.Error())
|
||||
}
|
||||
ms.addConsumer(consumer, mp.ChannelName)
|
||||
|
||||
//TODO: May cause problem
|
||||
//if len(consumer.Chan()) == 0 {
|
||||
// return nil
|
||||
//}
|
||||
|
||||
runLoop := true
|
||||
for runLoop {
|
||||
select {
|
||||
case <-ms.ctx.Done():
|
||||
return nil
|
||||
case msg, ok := <-consumer.Chan():
|
||||
if !ok {
|
||||
return fmt.Errorf("consumer closed")
|
||||
}
|
||||
consumer.Ack(msg)
|
||||
|
||||
headerMsg := commonpb.MsgHeader{}
|
||||
err := proto.Unmarshal(msg.Payload(), &headerMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to unmarshal message header, err %s", err.Error())
|
||||
}
|
||||
tsMsg, err := ms.unmarshal.Unmarshal(msg.Payload(), headerMsg.Base.MsgType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to unmarshal tsMsg, err %s", err.Error())
|
||||
}
|
||||
if tsMsg.Type() == commonpb.MsgType_TimeTick && tsMsg.BeginTs() >= mp.Timestamp {
|
||||
runLoop = false
|
||||
break
|
||||
} else if tsMsg.BeginTs() > mp.Timestamp {
|
||||
tsMsg.SetPosition(&MsgPosition{
|
||||
ChannelName: filepath.Base(msg.Topic()),
|
||||
MsgID: msg.ID().Serialize(),
|
||||
})
|
||||
ms.unsolvedBuf[consumer] = append(ms.unsolvedBuf[consumer], tsMsg)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if tsMsg.BeginTs() > mp.Timestamp {
|
||||
tsMsg.SetPosition(&MsgPosition{
|
||||
ChannelName: filepath.Base(msg.Topic()),
|
||||
MsgID: msg.ID().Serialize(),
|
||||
})
|
||||
ms.unsolvedBuf[consumer] = append(ms.unsolvedBuf[consumer], tsMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//TODO test InMemMsgStream
|
||||
|
|
|
@ -246,13 +246,8 @@ func getPulsarTtOutputStreamAndSeek(pulsarAddress string, positions []*MsgPositi
|
|||
factory := ProtoUDFactory{}
|
||||
pulsarClient, _ := mqclient.NewPulsarClient(pulsar.ClientOptions{URL: pulsarAddress})
|
||||
outputStream, _ := NewMqTtMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
|
||||
//outputStream.AsConsumer(consumerChannels, consumerSubName)
|
||||
for _, pos := range positions {
|
||||
pos.MsgGroup = funcutil.RandomString(4)
|
||||
outputStream.Seek(pos)
|
||||
}
|
||||
outputStream.Seek(positions)
|
||||
outputStream.Start()
|
||||
//outputStream.Start()
|
||||
return outputStream
|
||||
}
|
||||
|
||||
|
|
|
@ -45,7 +45,7 @@ type MsgStream interface {
|
|||
Produce(*MsgPack) error
|
||||
Broadcast(*MsgPack) error
|
||||
Consume() *MsgPack
|
||||
Seek(offset *MsgPosition) error
|
||||
Seek(offset []*MsgPosition) error
|
||||
}
|
||||
|
||||
type Factory interface {
|
||||
|
|
|
@ -90,7 +90,7 @@ func (ms *SimpleMsgStream) Consume() *MsgPack {
|
|||
return <-ms.msgChan
|
||||
}
|
||||
|
||||
func (ms *SimpleMsgStream) Seek(offset *MsgPosition) error {
|
||||
func (ms *SimpleMsgStream) Seek(offset []*MsgPosition) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -131,7 +131,7 @@ func (dsService *dataSyncService) initNodes() {
|
|||
}
|
||||
|
||||
func (dsService *dataSyncService) seekSegment(position *internalpb.MsgPosition) error {
|
||||
err := dsService.dmStream.Seek(position)
|
||||
err := dsService.dmStream.Seek([]*internalpb.MsgPosition{position})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -157,7 +157,7 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) error {
|
|||
|
||||
ds.dmStream.AsConsumer(toDirSubChannels, consumeSubName)
|
||||
for _, pos := range toSeekInfo {
|
||||
err := ds.dmStream.Seek(pos)
|
||||
err := ds.dmStream.Seek([]*internalpb.MsgPosition{pos})
|
||||
if err != nil {
|
||||
errMsg := "msgStream seek error :" + err.Error()
|
||||
log.Error(errMsg)
|
||||
|
|
Loading…
Reference in New Issue