Remove Consume() API from mqstream interface (#15886)

Signed-off-by: yun.zhang <yun.zhang@zilliz.com>
pull/15904/head
jaime 2022-03-11 20:09:59 +08:00 committed by GitHub
parent cc0be91ae1
commit 29975a7a26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 247 additions and 213 deletions

View File

@ -483,29 +483,29 @@ func (s *Server) handleDataNodeTimetickMsgstream(ctx context.Context, ttMsgStrea
case <-ctx.Done():
log.Info("DataNode timetick loop shutdown")
return
default:
}
msgPack := ttMsgStream.Consume()
if msgPack == nil {
log.Info("receive nil timetick msg and shutdown timetick channel")
return
}
for _, msg := range msgPack.Msgs {
ttMsg, ok := msg.(*msgstream.DataNodeTtMsg)
if !ok {
log.Warn("receive unexpected msg type from tt channel")
continue
}
if enableTtChecker {
checker.Check()
case msgPack, ok := <-ttMsgStream.Chan():
if !ok || msgPack == nil || len(msgPack.Msgs) == 0 {
log.Info("receive nil timetick msg and shutdown timetick channel")
return
}
if err := s.handleTimetickMessage(ctx, ttMsg); err != nil {
log.Error("failed to handle timetick message", zap.Error(err))
continue
for _, msg := range msgPack.Msgs {
ttMsg, ok := msg.(*msgstream.DataNodeTtMsg)
if !ok {
log.Warn("receive unexpected msg type from tt channel")
continue
}
if enableTtChecker {
checker.Check()
}
if err := s.handleTimetickMessage(ctx, ttMsg); err != nil {
log.Error("failed to handle timetick message", zap.Error(err))
continue
}
}
s.helper.eventAfterHandleDataNodeTt()
}
s.helper.eventAfterHandleDataNodeTt()
}
}

View File

@ -2257,7 +2257,13 @@ func TestIssue15659(t *testing.T) {
},
}
ms := &MockClosePanicMsgstream{}
ms.On("Consume").Return(&msgstream.MsgPack{})
msgChan := make(chan *msgstream.MsgPack)
go func() {
msgChan <- &msgstream.MsgPack{}
}()
ms.On("Chan").Return(msgChan)
ch := make(chan struct{})
go func() {
assert.NotPanics(t, func() {
@ -2279,9 +2285,9 @@ func (ms *MockClosePanicMsgstream) Close() {
panic("mocked close panic")
}
func (ms *MockClosePanicMsgstream) Consume() *msgstream.MsgPack {
func (ms *MockClosePanicMsgstream) Chan() <-chan *msgstream.MsgPack {
args := ms.Called()
return args.Get(0).(*msgstream.MsgPack)
return args.Get(0).(chan *msgstream.MsgPack)
}
func newTestServer(t *testing.T, receiveCh chan interface{}, opts ...Option) *Server {

View File

@ -91,9 +91,6 @@ func (mtm *mockTtMsgStream) Broadcast(*msgstream.MsgPack) error {
func (mtm *mockTtMsgStream) BroadcastMark(*msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
return map[string][]msgstream.MessageID{}, nil
}
func (mtm *mockTtMsgStream) Consume() *msgstream.MsgPack {
return nil
}
func (mtm *mockTtMsgStream) Seek(offset []*internalpb.MsgPosition) error {
return nil
}

View File

@ -475,22 +475,6 @@ func (ms *mqMsgStream) BroadcastMark(msgPack *MsgPack) (map[string][]MessageID,
return ids, nil
}
func (ms *mqMsgStream) Consume() *MsgPack {
for {
select {
case <-ms.ctx.Done():
//log.Debug("context closed")
return nil
case cm, ok := <-ms.receiveBuf:
if !ok {
log.Debug("buf chan closed")
return nil
}
return cm
}
}
}
func (ms *mqMsgStream) getTsMsgFromConsumerMsg(msg mqwrapper.Message) (TsMsg, error) {
header := commonpb.MsgHeader{}
if msg.Payload() == nil {

View File

@ -302,7 +302,7 @@ func TestMqMsgStream_Consume(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
msgPack := m.Consume()
msgPack := consumer(ctx, m)
assert.Nil(t, msgPack)
}()
@ -312,6 +312,20 @@ func TestMqMsgStream_Consume(t *testing.T) {
}
}
func consumer(ctx context.Context, mq MsgStream) *MsgPack {
for {
select {
case msgPack, ok := <-mq.Chan():
if !ok {
panic("Should not reach here")
}
return msgPack
case <-ctx.Done():
return nil
}
}
}
func TestMqMsgStream_Chan(t *testing.T) {
f := &fixture{t: t}
parameters := f.setup()
@ -359,18 +373,19 @@ func TestStream_PulsarMsgStream_Insert(t *testing.T) {
producerChannels := []string{c1, c2}
consumerChannels := []string{c1, c2}
consumerSubName := funcutil.RandomString(8)
ctx := context.Background()
msgPack := MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Insert, 3))
inputStream := getPulsarInputStream(pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(pulsarAddress, consumerChannels, consumerSubName)
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(outputStream, len(msgPack.Msgs))
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
inputStream.Close()
outputStream.Close()
}
@ -383,15 +398,15 @@ func TestStream_PulsarMsgStream_Delete(t *testing.T) {
consumerSubName := funcutil.RandomString(8)
msgPack := MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Delete, 1))
//msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Delete, 3, 3))
inputStream := getPulsarInputStream(pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(pulsarAddress, consumerChannels, consumerSubName)
ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(outputStream, len(msgPack.Msgs))
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
inputStream.Close()
outputStream.Close()
}
@ -407,13 +422,14 @@ func TestStream_PulsarMsgStream_Search(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Search, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Search, 3))
inputStream := getPulsarInputStream(pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(pulsarAddress, consumerChannels, consumerSubName)
ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(outputStream, len(msgPack.Msgs))
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
inputStream.Close()
outputStream.Close()
}
@ -428,13 +444,14 @@ func TestStream_PulsarMsgStream_SearchResult(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_SearchResult, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_SearchResult, 3))
inputStream := getPulsarInputStream(pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(pulsarAddress, consumerChannels, consumerSubName)
ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(outputStream, len(msgPack.Msgs))
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
inputStream.Close()
outputStream.Close()
}
@ -449,13 +466,14 @@ func TestStream_PulsarMsgStream_TimeTick(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_TimeTick, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_TimeTick, 3))
inputStream := getPulsarInputStream(pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(pulsarAddress, consumerChannels, consumerSubName)
ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(outputStream, len(msgPack.Msgs))
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
inputStream.Close()
outputStream.Close()
}
@ -471,13 +489,14 @@ func TestStream_PulsarMsgStream_BroadCast(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_TimeTick, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_TimeTick, 3))
inputStream := getPulsarInputStream(pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(pulsarAddress, consumerChannels, consumerSubName)
ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Broadcast(&msgPack)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveMsg(outputStream, len(consumerChannels)*len(msgPack.Msgs))
receiveMsg(ctx, outputStream, len(consumerChannels)*len(msgPack.Msgs))
inputStream.Close()
outputStream.Close()
}
@ -493,12 +512,13 @@ func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Insert, 3))
inputStream := getPulsarInputStream(pulsarAddress, producerChannels, repackFunc)
outputStream := getPulsarOutputStream(pulsarAddress, consumerChannels, consumerSubName)
ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels, repackFunc)
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(outputStream, len(msgPack.Msgs))
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
inputStream.Close()
outputStream.Close()
}
@ -540,13 +560,14 @@ func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) {
factory := ProtoUDFactory{}
ctx := context.Background()
pulsarClient, _ := pulsarwrapper.NewClient(pulsar.ClientOptions{URL: pulsarAddress})
inputStream, _ := NewMqMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels)
inputStream.Start()
pulsarClient2, _ := pulsarwrapper.NewClient(pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqMsgStream(context.Background(), 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(consumerChannels, consumerSubName)
outputStream.Start()
var output MsgStream = outputStream
@ -554,7 +575,7 @@ func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) {
err := (*inputStream).Produce(&msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(output, len(msgPack.Msgs)*2)
receiveMsg(ctx, output, len(msgPack.Msgs)*2)
(*inputStream).Close()
(*outputStream).Close()
}
@ -593,13 +614,14 @@ func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, deleteMsg)
factory := ProtoUDFactory{}
ctx := context.Background()
pulsarClient, _ := pulsarwrapper.NewClient(pulsar.ClientOptions{URL: pulsarAddress})
inputStream, _ := NewMqMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels)
inputStream.Start()
pulsarClient2, _ := pulsarwrapper.NewClient(pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqMsgStream(context.Background(), 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(consumerChannels, consumerSubName)
outputStream.Start()
var output MsgStream = outputStream
@ -607,7 +629,7 @@ func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) {
err := (*inputStream).Produce(&msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(output, len(msgPack.Msgs)*1)
receiveMsg(ctx, output, len(msgPack.Msgs)*1)
(*inputStream).Close()
(*outputStream).Close()
}
@ -626,13 +648,15 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_QueryNodeStats, 4))
factory := ProtoUDFactory{}
ctx := context.Background()
pulsarClient, _ := pulsarwrapper.NewClient(pulsar.ClientOptions{URL: pulsarAddress})
inputStream, _ := NewMqMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels)
inputStream.Start()
pulsarClient2, _ := pulsarwrapper.NewClient(pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqMsgStream(context.Background(), 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(consumerChannels, consumerSubName)
outputStream.Start()
var output MsgStream = outputStream
@ -640,7 +664,7 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) {
err := (*inputStream).Produce(&msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(output, len(msgPack.Msgs))
receiveMsg(ctx, output, len(msgPack.Msgs))
(*inputStream).Close()
(*outputStream).Close()
}
@ -661,8 +685,9 @@ func TestStream_PulsarTtMsgStream_Insert(t *testing.T) {
msgPack2 := MsgPack{}
msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5))
inputStream := getPulsarInputStream(pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(pulsarAddress, consumerChannels, consumerSubName)
ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Broadcast(&msgPack0)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
@ -673,7 +698,7 @@ func TestStream_PulsarTtMsgStream_Insert(t *testing.T) {
err = inputStream.Broadcast(&msgPack2)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveMsg(outputStream, len(msgPack1.Msgs))
receiveMsg(ctx, outputStream, len(msgPack1.Msgs))
inputStream.Close()
outputStream.Close()
}
@ -705,8 +730,9 @@ func TestStream_PulsarTtMsgStream_NoSeek(t *testing.T) {
msgPack5 := MsgPack{}
msgPack5.Msgs = append(msgPack5.Msgs, getTimeTickMsg(15))
inputStream := getPulsarInputStream(pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(pulsarAddress, consumerChannels, consumerSubName)
ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Broadcast(&msgPack0)
assert.Nil(t, err)
@ -721,28 +747,27 @@ func TestStream_PulsarTtMsgStream_NoSeek(t *testing.T) {
err = inputStream.Broadcast(&msgPack5)
assert.Nil(t, err)
o1 := outputStream.Consume()
o2 := outputStream.Consume()
o3 := outputStream.Consume()
o1 := consumer(ctx, outputStream)
o2 := consumer(ctx, outputStream)
o3 := consumer(ctx, outputStream)
t.Log(o1.BeginTs)
t.Log(o2.BeginTs)
t.Log(o3.BeginTs)
outputStream.Close()
outputStream = getPulsarTtOutputStream(pulsarAddress, consumerChannels, consumerSubName)
p1 := outputStream.Consume()
p2 := outputStream.Consume()
p3 := outputStream.Consume()
outputStream2 := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
p1 := consumer(ctx, outputStream2)
p2 := consumer(ctx, outputStream2)
p3 := consumer(ctx, outputStream2)
t.Log(p1.BeginTs)
t.Log(p2.BeginTs)
t.Log(p3.BeginTs)
outputStream.Close()
outputStream2.Close()
assert.Equal(t, o1.BeginTs, p1.BeginTs)
assert.Equal(t, o2.BeginTs, p2.BeginTs)
assert.Equal(t, o3.BeginTs, p3.BeginTs)
}
func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
@ -780,8 +805,9 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
msgPack7 := MsgPack{}
msgPack7.Msgs = append(msgPack7.Msgs, getTimeTickMsg(20))
inputStream := getPulsarInputStream(pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(pulsarAddress, consumerChannels, consumerSubName)
ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Broadcast(&msgPack0)
assert.Nil(t, err)
@ -800,7 +826,7 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
err = inputStream.Broadcast(&msgPack7)
assert.Nil(t, err)
receivedMsg := outputStream.Consume()
receivedMsg := consumer(ctx, outputStream)
assert.Equal(t, len(receivedMsg.Msgs), 2)
assert.Equal(t, receivedMsg.BeginTs, uint64(0))
assert.Equal(t, receivedMsg.EndTs, uint64(5))
@ -808,21 +834,21 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
assert.Equal(t, receivedMsg.StartPositions[0].Timestamp, uint64(0))
assert.Equal(t, receivedMsg.EndPositions[0].Timestamp, uint64(5))
receivedMsg2 := outputStream.Consume()
receivedMsg2 := consumer(ctx, outputStream)
assert.Equal(t, len(receivedMsg2.Msgs), 1)
assert.Equal(t, receivedMsg2.BeginTs, uint64(5))
assert.Equal(t, receivedMsg2.EndTs, uint64(11))
assert.Equal(t, receivedMsg2.StartPositions[0].Timestamp, uint64(5))
assert.Equal(t, receivedMsg2.EndPositions[0].Timestamp, uint64(11))
receivedMsg3 := outputStream.Consume()
receivedMsg3 := consumer(ctx, outputStream)
assert.Equal(t, len(receivedMsg3.Msgs), 3)
assert.Equal(t, receivedMsg3.BeginTs, uint64(11))
assert.Equal(t, receivedMsg3.EndTs, uint64(15))
assert.Equal(t, receivedMsg3.StartPositions[0].Timestamp, uint64(11))
assert.Equal(t, receivedMsg3.EndPositions[0].Timestamp, uint64(15))
receivedMsg4 := outputStream.Consume()
receivedMsg4 := consumer(ctx, outputStream)
assert.Equal(t, len(receivedMsg4.Msgs), 1)
assert.Equal(t, receivedMsg4.BeginTs, uint64(15))
assert.Equal(t, receivedMsg4.EndTs, uint64(20))
@ -831,30 +857,29 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
outputStream.Close()
outputStream = getPulsarTtOutputStreamAndSeek(pulsarAddress, receivedMsg3.StartPositions)
seekMsg := outputStream.Consume()
outputStream = getPulsarTtOutputStreamAndSeek(ctx, pulsarAddress, receivedMsg3.StartPositions)
seekMsg := consumer(ctx, outputStream)
assert.Equal(t, len(seekMsg.Msgs), 3)
result := []uint64{14, 12, 13}
for i, msg := range seekMsg.Msgs {
assert.Equal(t, msg.BeginTs(), result[i])
}
seekMsg2 := outputStream.Consume()
seekMsg2 := consumer(ctx, outputStream)
assert.Equal(t, len(seekMsg2.Msgs), 1)
for _, msg := range seekMsg2.Msgs {
assert.Equal(t, msg.BeginTs(), uint64(19))
}
//outputStream.Close()
outputStream = getPulsarTtOutputStreamAndSeek(pulsarAddress, receivedMsg3.EndPositions)
seekMsg = outputStream.Consume()
outputStream2 := getPulsarTtOutputStreamAndSeek(ctx, pulsarAddress, receivedMsg3.EndPositions)
seekMsg = consumer(ctx, outputStream2)
assert.Equal(t, len(seekMsg.Msgs), 1)
for _, msg := range seekMsg.Msgs {
assert.Equal(t, msg.BeginTs(), uint64(19))
}
inputStream.Close()
outputStream.Close()
outputStream2.Close()
}
func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) {
@ -874,8 +899,9 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) {
msgPack2 := MsgPack{}
msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5))
inputStream := getPulsarInputStream(pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(pulsarAddress, consumerChannels, consumerSubName)
ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Broadcast(&msgPack0)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
@ -886,7 +912,7 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) {
err = inputStream.Broadcast(&msgPack2)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveMsg(outputStream, len(msgPack1.Msgs))
receiveMsg(ctx, outputStream, len(msgPack1.Msgs))
inputStream.Close()
outputStream.Close()
}
@ -968,21 +994,22 @@ func TestStream_PulsarTtMsgStream_1(t *testing.T) {
consumerChannels := []string{c1, c2}
consumerSubName := funcutil.RandomString(8)
inputStream1 := getPulsarInputStream(pulsarAddr, p1Channels)
ctx := context.Background()
inputStream1 := getPulsarInputStream(ctx, pulsarAddr, p1Channels)
msgPacks1 := createRandMsgPacks(3, 10, 10)
assert.Nil(t, sendMsgPacks(inputStream1, msgPacks1))
inputStream2 := getPulsarInputStream(pulsarAddr, p2Channels)
inputStream2 := getPulsarInputStream(ctx, pulsarAddr, p2Channels)
msgPacks2 := createRandMsgPacks(5, 10, 10)
assert.Nil(t, sendMsgPacks(inputStream2, msgPacks2))
// consume msg
outputStream := getPulsarTtOutputStream(pulsarAddr, consumerChannels, consumerSubName)
outputStream := getPulsarTtOutputStream(ctx, pulsarAddr, consumerChannels, consumerSubName)
log.Println("===============receive msg=================")
checkNMsgPack := func(t *testing.T, outputStream MsgStream, num int) int {
rcvMsg := 0
for i := 0; i < num; i++ {
msgPack := outputStream.Consume()
msgPack := consumer(ctx, outputStream)
rcvMsg += len(msgPack.Msgs)
if len(msgPack.Msgs) > 0 {
for _, msg := range msgPack.Msgs {
@ -1030,11 +1057,12 @@ func TestStream_PulsarTtMsgStream_2(t *testing.T) {
consumerChannels := []string{c1, c2}
consumerSubName := funcutil.RandomString(8)
inputStream1 := getPulsarInputStream(pulsarAddr, p1Channels)
ctx := context.Background()
inputStream1 := getPulsarInputStream(ctx, pulsarAddr, p1Channels)
msgPacks1 := createRandMsgPacks(3, 10, 10)
assert.Nil(t, sendMsgPacks(inputStream1, msgPacks1))
inputStream2 := getPulsarInputStream(pulsarAddr, p2Channels)
inputStream2 := getPulsarInputStream(ctx, pulsarAddr, p2Channels)
msgPacks2 := createRandMsgPacks(5, 10, 10)
assert.Nil(t, sendMsgPacks(inputStream2, msgPacks2))
@ -1046,11 +1074,11 @@ func TestStream_PulsarTtMsgStream_2(t *testing.T) {
var outputStream MsgStream
msgCount := len(rcvMsgPacks)
if msgCount == 0 {
outputStream = getPulsarTtOutputStream(pulsarAddr, consumerChannels, consumerSubName)
outputStream = getPulsarTtOutputStream(ctx, pulsarAddr, consumerChannels, consumerSubName)
} else {
outputStream = getPulsarTtOutputStreamAndSeek(pulsarAddr, rcvMsgPacks[msgCount-1].EndPositions)
outputStream = getPulsarTtOutputStreamAndSeek(ctx, pulsarAddr, rcvMsgPacks[msgCount-1].EndPositions)
}
msgPack := outputStream.Consume()
msgPack := consumer(ctx, outputStream)
rcvMsgPacks = append(rcvMsgPacks, msgPack)
if len(msgPack.Msgs) > 0 {
for _, msg := range msgPack.Msgs {
@ -1084,8 +1112,9 @@ func TestStream_MqMsgStream_Seek(t *testing.T) {
consumerSubName := funcutil.RandomString(8)
msgPack := &MsgPack{}
inputStream := getPulsarInputStream(pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(pulsarAddress, consumerChannels, consumerSubName)
ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
for i := 0; i < 10; i++ {
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
@ -1096,7 +1125,7 @@ func TestStream_MqMsgStream_Seek(t *testing.T) {
assert.Nil(t, err)
var seekPosition *internalpb.MsgPosition
for i := 0; i < 10; i++ {
result := outputStream.Consume()
result := consumer(ctx, outputStream)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
if i == 5 {
seekPosition = result.EndPositions[0]
@ -1106,13 +1135,13 @@ func TestStream_MqMsgStream_Seek(t *testing.T) {
factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(pulsar.ClientOptions{URL: pulsarAddress})
outputStream2, _ := NewMqMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream2, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream2.AsConsumer(consumerChannels, consumerSubName)
outputStream2.Seek([]*internalpb.MsgPosition{seekPosition})
outputStream2.Start()
for i := 6; i < 10; i++ {
result := outputStream2.Consume()
result := consumer(ctx, outputStream2)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
}
outputStream2.Close()
@ -1126,10 +1155,13 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
consumerChannels := []string{c}
msgPack := &MsgPack{}
inputStream := getPulsarInputStream(pulsarAddress, producerChannels)
ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
defer inputStream.Close()
outputStream := getPulsarOutputStream(pulsarAddress, consumerChannels, funcutil.RandomString(8))
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, funcutil.RandomString(8))
defer outputStream.Close()
for i := 0; i < 10; i++ {
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
@ -1139,14 +1171,14 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
assert.Nil(t, err)
var seekPosition *internalpb.MsgPosition
for i := 0; i < 10; i++ {
result := outputStream.Consume()
result := consumer(ctx, outputStream)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
seekPosition = result.EndPositions[0]
}
factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(pulsar.ClientOptions{URL: pulsarAddress})
outputStream2, _ := NewMqMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream2, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream2.AsConsumer(consumerChannels, funcutil.RandomString(8))
defer outputStream2.Close()
messageID, _ := pulsar.DeserializeMessageID(seekPosition.MsgID)
@ -1172,7 +1204,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
}
err = inputStream.Produce(msgPack)
assert.Nil(t, err)
result := outputStream2.Consume()
result := consumer(ctx, outputStream2)
assert.Equal(t, result.Msgs[0].ID(), int64(1))
}
@ -1183,7 +1215,8 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
producerChannels := []string{c}
consumerChannels := []string{c}
consumerSubName := funcutil.RandomString(8)
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
ctx := context.Background()
inputStream, outputStream := initRmqStream(ctx, producerChannels, consumerChannels, consumerSubName)
msgPack := &MsgPack{}
for i := 0; i < 10; i++ {
@ -1195,7 +1228,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
assert.Nil(t, err)
var seekPosition *internalpb.MsgPosition
for i := 0; i < 10; i++ {
result := outputStream.Consume()
result := consumer(ctx, outputStream)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
seekPosition = result.EndPositions[0]
}
@ -1203,7 +1236,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
factory := ProtoUDFactory{}
rmqClient2, _ := rmq.NewClientWithDefaultOptions()
outputStream2, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient2, factory.NewUnmarshalDispatcher())
outputStream2, _ := NewMqMsgStream(ctx, 100, 100, rmqClient2, factory.NewUnmarshalDispatcher())
outputStream2.AsConsumer(consumerChannels, funcutil.RandomString(8))
id := common.Endian.Uint64(seekPosition.MsgID) + 10
@ -1229,7 +1262,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
err = inputStream.Produce(msgPack)
assert.Nil(t, err)
result := outputStream2.Consume()
result := consumer(ctx, outputStream2)
assert.Equal(t, result.Msgs[0].ID(), int64(1))
Close(rocksdbName, inputStream, outputStream2, etcdKV)
@ -1244,7 +1277,8 @@ func TestStream_MqMsgStream_SeekLatest(t *testing.T) {
consumerSubName := funcutil.RandomString(8)
msgPack := &MsgPack{}
inputStream := getPulsarInputStream(pulsarAddress, producerChannels)
ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
for i := 0; i < 10; i++ {
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
@ -1255,7 +1289,7 @@ func TestStream_MqMsgStream_SeekLatest(t *testing.T) {
assert.Nil(t, err)
factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(pulsar.ClientOptions{URL: pulsarAddress})
outputStream2, _ := NewMqMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream2, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream2.AsConsumerWithPosition(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionLatest)
outputStream2.Start()
@ -1269,9 +1303,11 @@ func TestStream_MqMsgStream_SeekLatest(t *testing.T) {
assert.Nil(t, err)
for i := 10; i < 20; i++ {
result := outputStream2.Consume()
result := consumer(ctx, outputStream2)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
}
inputStream.Close()
outputStream2.Close()
}
@ -1283,7 +1319,7 @@ func TestStream_MqMsgStream_Reader(t *testing.T) {
readerChannels := []string{c}
msgPack := &MsgPack{}
inputStream := getPulsarInputStream(pulsarAddress, producerChannels)
inputStream := getPulsarInputStream(context.Background(), pulsarAddress, producerChannels)
defer inputStream.Close()
n := 10
@ -1373,14 +1409,15 @@ func Close(rocksdbName string, intputStream, outputStream MsgStream, etcdKV *etc
log.Println(err)
}
func initRmqStream(producerChannels []string,
func initRmqStream(ctx context.Context,
producerChannels []string,
consumerChannels []string,
consumerGroupName string,
opts ...RepackFunc) (MsgStream, MsgStream) {
factory := ProtoUDFactory{}
rmqClient, _ := rmq.NewClientWithDefaultOptions()
inputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
inputStream, _ := NewMqMsgStream(ctx, 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels)
for _, opt := range opts {
inputStream.SetRepackFunc(opt)
@ -1389,7 +1426,7 @@ func initRmqStream(producerChannels []string,
var input MsgStream = inputStream
rmqClient2, _ := rmq.NewClientWithDefaultOptions()
outputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient2, factory.NewUnmarshalDispatcher())
outputStream, _ := NewMqMsgStream(ctx, 100, 100, rmqClient2, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(consumerChannels, consumerGroupName)
outputStream.Start()
var output MsgStream = outputStream
@ -1397,14 +1434,15 @@ func initRmqStream(producerChannels []string,
return input, output
}
func initRmqTtStream(producerChannels []string,
func initRmqTtStream(ctx context.Context,
producerChannels []string,
consumerChannels []string,
consumerGroupName string,
opts ...RepackFunc) (MsgStream, MsgStream) {
factory := ProtoUDFactory{}
rmqClient, _ := rmq.NewClientWithDefaultOptions()
inputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
inputStream, _ := NewMqMsgStream(ctx, 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels)
for _, opt := range opts {
inputStream.SetRepackFunc(opt)
@ -1413,7 +1451,7 @@ func initRmqTtStream(producerChannels []string,
var input MsgStream = inputStream
rmqClient2, _ := rmq.NewClientWithDefaultOptions()
outputStream, _ := NewMqTtMsgStream(context.Background(), 100, 100, rmqClient2, factory.NewUnmarshalDispatcher())
outputStream, _ := NewMqTtMsgStream(ctx, 100, 100, rmqClient2, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(consumerChannels, consumerGroupName)
outputStream.Start()
var output MsgStream = outputStream
@ -1432,11 +1470,12 @@ func TestStream_RmqMsgStream_Insert(t *testing.T) {
rocksdbName := "/tmp/rocksmq_insert"
etcdKV := initRmq(rocksdbName)
inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerGroupName)
ctx := context.Background()
inputStream, outputStream := initRmqStream(ctx, producerChannels, consumerChannels, consumerGroupName)
err := inputStream.Produce(&msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(outputStream, len(msgPack.Msgs))
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
Close(rocksdbName, inputStream, outputStream, etcdKV)
}
@ -1457,7 +1496,8 @@ func TestStream_RmqTtMsgStream_Insert(t *testing.T) {
rocksdbName := "/tmp/rocksmq_insert_tt"
etcdKV := initRmq(rocksdbName)
inputStream, outputStream := initRmqTtStream(producerChannels, consumerChannels, consumerSubName)
ctx := context.Background()
inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName)
err := inputStream.Broadcast(&msgPack0)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
@ -1468,7 +1508,7 @@ func TestStream_RmqTtMsgStream_Insert(t *testing.T) {
err = inputStream.Broadcast(&msgPack2)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveMsg(outputStream, len(msgPack1.Msgs))
receiveMsg(ctx, outputStream, len(msgPack1.Msgs))
Close(rocksdbName, inputStream, outputStream, etcdKV)
}
@ -1509,7 +1549,8 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) {
msgPack7 := MsgPack{}
msgPack7.Msgs = append(msgPack7.Msgs, getTimeTickMsg(20))
inputStream, outputStream := initRmqTtStream(producerChannels, consumerChannels, consumerSubName)
ctx := context.Background()
inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName)
err := inputStream.Broadcast(&msgPack0)
assert.Nil(t, err)
@ -1528,7 +1569,7 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) {
err = inputStream.Broadcast(&msgPack7)
assert.Nil(t, err)
receivedMsg := outputStream.Consume()
receivedMsg := consumer(ctx, outputStream)
assert.Equal(t, len(receivedMsg.Msgs), 2)
assert.Equal(t, receivedMsg.BeginTs, uint64(0))
assert.Equal(t, receivedMsg.EndTs, uint64(5))
@ -1536,21 +1577,21 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) {
assert.Equal(t, receivedMsg.StartPositions[0].Timestamp, uint64(0))
assert.Equal(t, receivedMsg.EndPositions[0].Timestamp, uint64(5))
receivedMsg2 := outputStream.Consume()
receivedMsg2 := consumer(ctx, outputStream)
assert.Equal(t, len(receivedMsg2.Msgs), 1)
assert.Equal(t, receivedMsg2.BeginTs, uint64(5))
assert.Equal(t, receivedMsg2.EndTs, uint64(11))
assert.Equal(t, receivedMsg2.StartPositions[0].Timestamp, uint64(5))
assert.Equal(t, receivedMsg2.EndPositions[0].Timestamp, uint64(11))
receivedMsg3 := outputStream.Consume()
receivedMsg3 := consumer(ctx, outputStream)
assert.Equal(t, len(receivedMsg3.Msgs), 3)
assert.Equal(t, receivedMsg3.BeginTs, uint64(11))
assert.Equal(t, receivedMsg3.EndTs, uint64(15))
assert.Equal(t, receivedMsg3.StartPositions[0].Timestamp, uint64(11))
assert.Equal(t, receivedMsg3.EndPositions[0].Timestamp, uint64(15))
receivedMsg4 := outputStream.Consume()
receivedMsg4 := consumer(ctx, outputStream)
assert.Equal(t, len(receivedMsg4.Msgs), 1)
assert.Equal(t, receivedMsg4.BeginTs, uint64(15))
assert.Equal(t, receivedMsg4.EndTs, uint64(20))
@ -1568,14 +1609,14 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) {
outputStream.Seek(receivedMsg3.StartPositions)
outputStream.Start()
seekMsg := outputStream.Consume()
seekMsg := consumer(ctx, outputStream)
assert.Equal(t, len(seekMsg.Msgs), 3)
result := []uint64{14, 12, 13}
for i, msg := range seekMsg.Msgs {
assert.Equal(t, msg.BeginTs(), result[i])
}
seekMsg2 := outputStream.Consume()
seekMsg2 := consumer(ctx, outputStream)
assert.Equal(t, len(seekMsg2.Msgs), 1)
for _, msg := range seekMsg2.Msgs {
assert.Equal(t, msg.BeginTs(), uint64(19))
@ -1928,10 +1969,10 @@ func getTimeTickMsgPack(reqID UniqueID) *MsgPack {
return &msgPack
}
func getPulsarInputStream(pulsarAddress string, producerChannels []string, opts ...RepackFunc) MsgStream {
func getPulsarInputStream(ctx context.Context, pulsarAddress string, producerChannels []string, opts ...RepackFunc) MsgStream {
factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(pulsar.ClientOptions{URL: pulsarAddress})
inputStream, _ := NewMqMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels)
for _, opt := range opts {
inputStream.SetRepackFunc(opt)
@ -1940,10 +1981,10 @@ func getPulsarInputStream(pulsarAddress string, producerChannels []string, opts
return inputStream
}
func getPulsarOutputStream(pulsarAddress string, consumerChannels []string, consumerSubName string) MsgStream {
func getPulsarOutputStream(ctx context.Context, pulsarAddress string, consumerChannels []string, consumerSubName string) MsgStream {
factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(consumerChannels, consumerSubName)
outputStream.Start()
return outputStream
@ -1957,19 +1998,19 @@ func getPulsarReader(pulsarAddress string, consumerChannels []string) MsgStream
return outputStream
}
func getPulsarTtOutputStream(pulsarAddress string, consumerChannels []string, consumerSubName string) MsgStream {
func getPulsarTtOutputStream(ctx context.Context, pulsarAddress string, consumerChannels []string, consumerSubName string) MsgStream {
factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqTtMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream, _ := NewMqTtMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(consumerChannels, consumerSubName)
outputStream.Start()
return outputStream
}
func getPulsarTtOutputStreamAndSeek(pulsarAddress string, positions []*MsgPosition) MsgStream {
func getPulsarTtOutputStreamAndSeek(ctx context.Context, pulsarAddress string, positions []*MsgPosition) MsgStream {
factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqTtMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream, _ := NewMqTtMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
consumerName := []string{}
for _, c := range positions {
consumerName = append(consumerName, c.ChannelName)
@ -1980,20 +2021,27 @@ func getPulsarTtOutputStreamAndSeek(pulsarAddress string, positions []*MsgPositi
return outputStream
}
func receiveMsg(outputStream MsgStream, msgCount int) {
func receiveMsg(ctx context.Context, outputStream MsgStream, msgCount int) {
receiveCount := 0
for {
result := outputStream.Consume()
if len(result.Msgs) > 0 {
msgs := result.Msgs
for _, v := range msgs {
receiveCount++
log.Println("msg type: ", v.Type(), ", msg value: ", v)
select {
case <-ctx.Done():
return
case result, ok := <-outputStream.Chan():
if !ok || result == nil || len(result.Msgs) == 0 {
return
}
if len(result.Msgs) > 0 {
msgs := result.Msgs
for _, v := range msgs {
receiveCount++
log.Println("msg type: ", v.Type(), ", msg value: ", v)
}
log.Println("================")
}
if receiveCount >= msgCount {
return
}
log.Println("================")
}
if receiveCount >= msgCount {
break
}
}
}
@ -2040,7 +2088,7 @@ func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) {
outputStream.Start()
inputStream.Produce(getTimeTickMsgPack(1000))
pack := outputStream.Consume()
pack := <-outputStream.Chan()
assert.NotNil(t, pack)
assert.Equal(t, 1, len(pack.Msgs))
assert.EqualValues(t, 1000, pack.Msgs[0].BeginTs())

View File

@ -67,7 +67,6 @@ type MsgStream interface {
ProduceMark(*MsgPack) (map[string][]MessageID, error)
Broadcast(*MsgPack) error
BroadcastMark(*MsgPack) (map[string][]MessageID, error)
Consume() *MsgPack
Next(ctx context.Context, channelName string) (TsMsg, error)
HasNext(channelName string) bool
Seek(offset []*MsgPosition) error

View File

@ -269,6 +269,13 @@ func (ms *simpleMockMsgStream) Close() {
}
func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.MsgPack {
if ms.getMsgCount() <= 0 {
ms.msgChan <- nil
return ms.msgChan
}
defer ms.decreaseMsgCount(1)
return ms.msgChan
}
@ -363,16 +370,6 @@ func (ms *simpleMockMsgStream) GetProduceChannels() []string {
return nil
}
func (ms *simpleMockMsgStream) Consume() *msgstream.MsgPack {
if ms.getMsgCount() <= 0 {
return nil
}
defer ms.decreaseMsgCount(1)
return <-ms.msgChan
}
func (ms *simpleMockMsgStream) Seek(offset []*msgstream.MsgPosition) error {
return nil
}

View File

@ -1851,7 +1851,12 @@ func TestSearchTask_all(t *testing.T) {
select {
case <-consumeCtx.Done():
return
case pack := <-stream.Chan():
case pack, ok := <-stream.Chan():
assert.True(t, ok)
if pack == nil {
continue
}
for _, msg := range pack.Msgs {
_, ok := msg.(*msgstream.SearchMsg)
assert.True(t, ok)
@ -2198,7 +2203,12 @@ func TestSearchTaskWithInvalidRoundDecimal(t *testing.T) {
select {
case <-consumeCtx.Done():
return
case pack := <-stream.Chan():
case pack, ok := <-stream.Chan():
assert.True(t, ok)
if pack == nil {
continue
}
for _, msg := range pack.Msgs {
_, ok := msg.(*msgstream.SearchMsg)
assert.True(t, ok)
@ -2539,7 +2549,12 @@ func TestSearchTask_7803_reduce(t *testing.T) {
select {
case <-consumeCtx.Done():
return
case pack := <-stream.Chan():
case pack, ok := <-stream.Chan():
assert.True(t, ok)
if pack == nil {
continue
}
for _, msg := range pack.Msgs {
_, ok := msg.(*msgstream.SearchMsg)
assert.True(t, ok)
@ -3347,7 +3362,13 @@ func TestQueryTask_all(t *testing.T) {
select {
case <-consumeCtx.Done():
return
case pack := <-stream.Chan():
case pack, ok := <-stream.Chan():
assert.True(t, ok)
if pack == nil {
continue
}
for _, msg := range pack.Msgs {
_, ok := msg.(*msgstream.RetrieveMsg)
assert.True(t, ok)

View File

@ -1632,24 +1632,6 @@ func initConsumer(ctx context.Context, queryResultChannel Channel) (msgstream.Ms
return stream, nil
}
func consumeSimpleSearchResult(stream msgstream.MsgStream) (*msgstream.SearchResultMsg, error) {
res := stream.Consume()
if len(res.Msgs) != 1 {
err := errors.New("unexpected message length")
return nil, err
}
return res.Msgs[0].(*msgstream.SearchResultMsg), nil
}
func consumeSimpleRetrieveResult(stream msgstream.MsgStream) (*msgstream.RetrieveResultMsg, error) {
res := stream.Consume()
if len(res.Msgs) != 1 {
err := errors.New("unexpected message length")
return nil, err
}
return res.Msgs[0].(*msgstream.RetrieveResultMsg), nil
}
func genSimpleChangeInfo() *querypb.SealedSegmentsChangeInfo {
changeInfo := &querypb.SegmentChangeInfo{
OnlineNodeID: Params.QueryNodeCfg.QueryNodeID,

View File

@ -364,18 +364,15 @@ func (q *queryCollection) consumeQuery() {
case <-q.releaseCtx.Done():
log.Debug("stop queryCollection's receiveQueryMsg", zap.Int64("collectionID", q.collectionID))
return
default:
msgPack := q.queryMsgStream.Consume()
if msgPack == nil || len(msgPack.Msgs) <= 0 {
//msgPackNil := msgPack == nil
//msgPackEmpty := true
//if msgPack != nil {
// msgPackEmpty = len(msgPack.Msgs) <= 0
//}
//log.Debug("consume query message failed", zap.Any("msgPack is Nil", msgPackNil),
// zap.Any("msgPackEmpty", msgPackEmpty))
case msgPack, ok := <-q.queryMsgStream.Chan():
if !ok {
log.Warn("Receive Query Msg from chan failed", zap.Int64("collectionID", q.collectionID))
return
}
if !ok || msgPack == nil || len(msgPack.Msgs) == 0 {
continue
}
for _, msg := range msgPack.Msgs {
switch sm := msg.(type) {
case *msgstream.SearchMsg:

View File

@ -159,7 +159,6 @@ func (ms *FailMsgStream) BroadcastMark(*msgstream.MsgPack) (map[string][]msgstre
}
return nil, nil
}
func (ms *FailMsgStream) Consume() *msgstream.MsgPack { return nil }
func (ms *FailMsgStream) Next(ctx context.Context, channelName string) (msgstream.TsMsg, error) {
return nil, nil
}

View File

@ -62,7 +62,11 @@ func (inNode *InputNode) InStream() msgstream.MsgStream {
// Operate consume a message pack from msgstream and return
func (inNode *InputNode) Operate(in []Msg) []Msg {
msgPack := inNode.inStream.Consume()
msgPack, ok := <-inNode.inStream.Chan()
if !ok {
log.Warn("Receive Msg failed from upstream node", zap.Any("input node", inNode.Name()))
return []Msg{}
}
// TODO: add status
if msgPack == nil {