enhance: refine pular related mq interfaces (#38007)

issue: #35917 
Refines the pulsar-related mq APIs to allow the ctx to be passed down

Signed-off-by: tinswzy <zhenyuan.wei@zilliz.com>
pull/37378/merge
tinswzy 2024-12-04 20:50:39 +08:00 committed by GitHub
parent 73aa95f596
commit 5768dbbb5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
50 changed files with 380 additions and 367 deletions

View File

@ -406,7 +406,7 @@ func (t *compactionTrigger) handleSignal(signal *compactionSignal) {
return
}
segment := t.meta.GetHealthySegment(t.meta.ctx, signal.segmentID)
segment := t.meta.GetHealthySegment(context.TODO(), signal.segmentID)
if segment == nil {
log.Warn("segment in compaction signal not found in meta", zap.Int64("segmentID", signal.segmentID))
return

View File

@ -68,7 +68,7 @@ func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.MsgPack {
return make(chan *msgstream.MsgPack, 100)
}
func (mtm *mockTtMsgStream) AsProducer(channels []string) {}
func (mtm *mockTtMsgStream) AsProducer(ctx context.Context, channels []string) {}
func (mtm *mockTtMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error {
return nil
@ -80,11 +80,11 @@ func (mtm *mockTtMsgStream) GetProduceChannels() []string {
return make([]string, 0)
}
func (mtm *mockTtMsgStream) Produce(*msgstream.MsgPack) error {
func (mtm *mockTtMsgStream) Produce(context.Context, *msgstream.MsgPack) error {
return nil
}
func (mtm *mockTtMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
func (mtm *mockTtMsgStream) Broadcast(context.Context, *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
return nil, nil
}

View File

@ -39,7 +39,7 @@ import (
type channelsMgr interface {
getChannels(collectionID UniqueID) ([]pChan, error)
getVChannels(collectionID UniqueID) ([]vChan, error)
getOrCreateDmlStream(collectionID UniqueID) (msgstream.MsgStream, error)
getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error)
removeDMLStream(collectionID UniqueID)
removeAllDMLStream()
}
@ -172,7 +172,7 @@ func (mgr *singleTypeChannelsMgr) streamExistPrivate(collectionID UniqueID) bool
return ok && streamInfos.stream != nil
}
func createStream(factory msgstream.Factory, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) {
func createStream(ctx context.Context, factory msgstream.Factory, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) {
var stream msgstream.MsgStream
var err error
@ -181,7 +181,7 @@ func createStream(factory msgstream.Factory, pchans []pChan, repack repackFuncTy
return nil, err
}
stream.AsProducer(pchans)
stream.AsProducer(ctx, pchans)
if repack != nil {
stream.SetRepackFunc(repack)
}
@ -202,7 +202,7 @@ func decPChanMetrics(pchans []pChan) {
// createMsgStream create message stream for specified collection. Idempotent.
// If stream already exists, directly return it and no error will be returned.
func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) (msgstream.MsgStream, error) {
func (mgr *singleTypeChannelsMgr) createMsgStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
mgr.mu.RLock()
infos, ok := mgr.infos[collectionID]
if ok && infos.stream != nil {
@ -219,7 +219,7 @@ func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) (msgstr
return nil, err
}
stream, err := createStream(mgr.msgStreamFactory, channelInfos.pchans, mgr.repackFunc)
stream, err := createStream(ctx, mgr.msgStreamFactory, channelInfos.pchans, mgr.repackFunc)
if err != nil {
// What if stream created by other goroutines?
log.Error("failed to create message stream", zap.Error(err), zap.Int64("collection", collectionID))
@ -253,12 +253,12 @@ func (mgr *singleTypeChannelsMgr) lockGetStream(collectionID UniqueID) (msgstrea
// getOrCreateStream get message stream of specified collection.
// If stream doesn't exist, call createMsgStream to create for it.
func (mgr *singleTypeChannelsMgr) getOrCreateStream(collectionID UniqueID) (msgstream.MsgStream, error) {
func (mgr *singleTypeChannelsMgr) getOrCreateStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
if stream, err := mgr.lockGetStream(collectionID); err == nil {
return stream, nil
}
return mgr.createMsgStream(collectionID)
return mgr.createMsgStream(ctx, collectionID)
}
// removeStream remove the corresponding stream of the specified collection. Idempotent.
@ -315,8 +315,8 @@ func (mgr *channelsMgrImpl) getVChannels(collectionID UniqueID) ([]vChan, error)
return mgr.dmlChannelsMgr.getVChannels(collectionID)
}
func (mgr *channelsMgrImpl) getOrCreateDmlStream(collectionID UniqueID) (msgstream.MsgStream, error) {
return mgr.dmlChannelsMgr.getOrCreateStream(collectionID)
func (mgr *channelsMgrImpl) getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
return mgr.dmlChannelsMgr.getOrCreateStream(ctx, collectionID)
}
func (mgr *channelsMgrImpl) removeDMLStream(collectionID UniqueID) {

View File

@ -214,7 +214,7 @@ func Test_createStream(t *testing.T) {
factory.fQStream = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock")
}
_, err := createStream(factory, nil, nil)
_, err := createStream(context.TODO(), factory, nil, nil)
assert.Error(t, err)
})
@ -223,7 +223,7 @@ func Test_createStream(t *testing.T) {
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock")
}
_, err := createStream(factory, nil, nil)
_, err := createStream(context.TODO(), factory, nil, nil)
assert.Error(t, err)
})
@ -232,7 +232,7 @@ func Test_createStream(t *testing.T) {
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return newMockMsgStream(), nil
}
_, err := createStream(factory, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
_, err := createStream(context.TODO(), factory, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
return nil, nil
})
assert.NoError(t, err)
@ -247,7 +247,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
100: {stream: newMockMsgStream()},
},
}
stream, err := m.createMsgStream(100)
stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
@ -275,7 +275,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
stream, err := m.createMsgStream(100)
stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
}()
@ -295,7 +295,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
return channelInfos{}, errors.New("mock")
},
}
_, err := m.createMsgStream(100)
_, err := m.createMsgStream(context.TODO(), 100)
assert.Error(t, err)
})
@ -311,7 +311,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
msgStreamFactory: factory,
repackFunc: nil,
}
_, err := m.createMsgStream(100)
_, err := m.createMsgStream(context.TODO(), 100)
assert.Error(t, err)
})
@ -328,10 +328,10 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
msgStreamFactory: factory,
repackFunc: nil,
}
stream, err := m.createMsgStream(100)
stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
stream, err = m.getOrCreateStream(100)
stream, err = m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
@ -365,7 +365,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
100: {stream: newMockMsgStream()},
},
}
stream, err := m.getOrCreateStream(100)
stream, err := m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
@ -377,7 +377,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
return channelInfos{}, errors.New("mock")
},
}
_, err := m.getOrCreateStream(100)
_, err := m.getOrCreateStream(context.TODO(), 100)
assert.Error(t, err)
})
@ -394,7 +394,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
msgStreamFactory: factory,
repackFunc: nil,
}
stream, err := m.getOrCreateStream(100)
stream, err := m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})

View File

@ -6323,7 +6323,7 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
Status: merr.Status(err),
}, nil
}
messageIDsMap, err := msgStream.Broadcast(msgPack)
messageIDsMap, err := msgStream.Broadcast(ctx, msgPack)
if err != nil {
log.Ctx(ctx).Warn("failed to produce msg", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil

View File

@ -440,7 +440,7 @@ func TestProxy_FlushAll_DbCollection(t *testing.T) {
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err)
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel})
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000")
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
@ -483,7 +483,7 @@ func TestProxy_FlushAll(t *testing.T) {
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err)
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel})
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000")
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
@ -955,7 +955,7 @@ func TestProxyCreateDatabase(t *testing.T) {
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err)
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel})
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
t.Run("create database fail", func(t *testing.T) {
rc := mocks.NewMockRootCoordClient(t)
@ -1015,7 +1015,7 @@ func TestProxyDropDatabase(t *testing.T) {
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err)
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel})
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
t.Run("drop database fail", func(t *testing.T) {
rc := mocks.NewMockRootCoordClient(t)
@ -1496,13 +1496,13 @@ func TestProxy_ReplicateMessage(t *testing.T) {
factory := newMockMsgStreamFactory()
msgStreamObj := msgstream.NewMockMsgStream(t)
msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return()
msgStreamObj.EXPECT().AsProducer(mock.Anything).Return()
msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return()
msgStreamObj.EXPECT().EnableProduce(mock.Anything).Return()
msgStreamObj.EXPECT().Close().Return()
mockMsgID1 := mqcommon.NewMockMessageID(t)
mockMsgID2 := mqcommon.NewMockMessageID(t)
mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2"))
broadcastMock := msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(map[string][]mqcommon.MessageID{
broadcastMock := msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
"unit_test_replicate_message": {mockMsgID1, mockMsgID2},
}, nil)
@ -1581,7 +1581,7 @@ func TestProxy_ReplicateMessage(t *testing.T) {
{
broadcastMock.Unset()
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(nil, errors.New("mock error: broadcast"))
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: broadcast"))
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)
assert.NoError(t, err)
assert.NotEqualValues(t, 0, resp.GetStatus().GetCode())
@ -1590,7 +1590,7 @@ func TestProxy_ReplicateMessage(t *testing.T) {
}
{
broadcastMock.Unset()
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(map[string][]mqcommon.MessageID{
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
"unit_test_replicate_message": {},
}, nil)
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)

View File

@ -3,6 +3,8 @@
package proxy
import (
context "context"
msgstream "github.com/milvus-io/milvus/pkg/mq/msgstream"
mock "github.com/stretchr/testify/mock"
)
@ -78,9 +80,9 @@ func (_c *MockChannelsMgr_getChannels_Call) RunAndReturn(run func(int64) ([]stri
return _c
}
// getOrCreateDmlStream provides a mock function with given fields: collectionID
func (_m *MockChannelsMgr) getOrCreateDmlStream(collectionID int64) (msgstream.MsgStream, error) {
ret := _m.Called(collectionID)
// getOrCreateDmlStream provides a mock function with given fields: ctx, collectionID
func (_m *MockChannelsMgr) getOrCreateDmlStream(ctx context.Context, collectionID int64) (msgstream.MsgStream, error) {
ret := _m.Called(ctx, collectionID)
if len(ret) == 0 {
panic("no return value specified for getOrCreateDmlStream")
@ -88,19 +90,19 @@ func (_m *MockChannelsMgr) getOrCreateDmlStream(collectionID int64) (msgstream.M
var r0 msgstream.MsgStream
var r1 error
if rf, ok := ret.Get(0).(func(int64) (msgstream.MsgStream, error)); ok {
return rf(collectionID)
if rf, ok := ret.Get(0).(func(context.Context, int64) (msgstream.MsgStream, error)); ok {
return rf(ctx, collectionID)
}
if rf, ok := ret.Get(0).(func(int64) msgstream.MsgStream); ok {
r0 = rf(collectionID)
if rf, ok := ret.Get(0).(func(context.Context, int64) msgstream.MsgStream); ok {
r0 = rf(ctx, collectionID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(msgstream.MsgStream)
}
}
if rf, ok := ret.Get(1).(func(int64) error); ok {
r1 = rf(collectionID)
if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok {
r1 = rf(ctx, collectionID)
} else {
r1 = ret.Error(1)
}
@ -114,14 +116,15 @@ type MockChannelsMgr_getOrCreateDmlStream_Call struct {
}
// getOrCreateDmlStream is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
func (_e *MockChannelsMgr_Expecter) getOrCreateDmlStream(collectionID interface{}) *MockChannelsMgr_getOrCreateDmlStream_Call {
return &MockChannelsMgr_getOrCreateDmlStream_Call{Call: _e.mock.On("getOrCreateDmlStream", collectionID)}
func (_e *MockChannelsMgr_Expecter) getOrCreateDmlStream(ctx interface{}, collectionID interface{}) *MockChannelsMgr_getOrCreateDmlStream_Call {
return &MockChannelsMgr_getOrCreateDmlStream_Call{Call: _e.mock.On("getOrCreateDmlStream", ctx, collectionID)}
}
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Run(run func(collectionID int64)) *MockChannelsMgr_getOrCreateDmlStream_Call {
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Run(run func(ctx context.Context, collectionID int64)) *MockChannelsMgr_getOrCreateDmlStream_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
@ -131,7 +134,7 @@ func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Return(_a0 msgstream.MsgStr
return _c
}
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) RunAndReturn(run func(int64) (msgstream.MsgStream, error)) *MockChannelsMgr_getOrCreateDmlStream_Call {
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) RunAndReturn(run func(context.Context, int64) (msgstream.MsgStream, error)) *MockChannelsMgr_getOrCreateDmlStream_Call {
_c.Call.Return(run)
return _c
}

View File

@ -16,7 +16,7 @@ type mockMsgStream struct {
enableProduce func(bool)
}
func (m *mockMsgStream) AsProducer(producers []string) {
func (m *mockMsgStream) AsProducer(ctx context.Context, producers []string) {
if m.asProducer != nil {
m.asProducer(producers)
}

View File

@ -255,7 +255,7 @@ func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.MsgPack {
return ms.msgChan
}
func (ms *simpleMockMsgStream) AsProducer(channels []string) {
func (ms *simpleMockMsgStream) AsProducer(ctx context.Context, channels []string) {
}
func (ms *simpleMockMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error {
@ -283,7 +283,7 @@ func (ms *simpleMockMsgStream) decreaseMsgCount(delta int) {
ms.increaseMsgCount(-delta)
}
func (ms *simpleMockMsgStream) Produce(pack *msgstream.MsgPack) error {
func (ms *simpleMockMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) error {
defer ms.increaseMsgCount(1)
ms.msgChan <- pack
@ -291,7 +291,7 @@ func (ms *simpleMockMsgStream) Produce(pack *msgstream.MsgPack) error {
return nil
}
func (ms *simpleMockMsgStream) Broadcast(pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
func (ms *simpleMockMsgStream) Broadcast(ctx context.Context, pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
return map[string][]msgstream.MessageID{}, nil
}

View File

@ -278,7 +278,7 @@ func (node *Proxy) Init() error {
return err
}
node.replicateMsgStream.EnableProduce(true)
node.replicateMsgStream.AsProducer([]string{replicateMsgChannel})
node.replicateMsgStream.AsProducer(node.ctx, []string{replicateMsgChannel})
node.sched, err = newTaskScheduler(node.ctx, node.tsoAllocator, node.factory)
if err != nil {

View File

@ -34,15 +34,15 @@ func NewReplicateStreamManager(ctx context.Context, factory msgstream.Factory, r
return manager
}
func (m *ReplicateStreamManager) newMsgStreamResource(channel string) resource.NewResourceFunc {
func (m *ReplicateStreamManager) newMsgStreamResource(ctx context.Context, channel string) resource.NewResourceFunc {
return func() (resource.Resource, error) {
msgStream, err := m.factory.NewMsgStream(m.ctx)
msgStream, err := m.factory.NewMsgStream(ctx)
if err != nil {
log.Ctx(m.ctx).Warn("failed to create msg stream", zap.String("channel", channel), zap.Error(err))
return nil, err
}
msgStream.SetRepackFunc(replicatePackFunc)
msgStream.AsProducer([]string{channel})
msgStream.AsProducer(ctx, []string{channel})
msgStream.EnableProduce(true)
res := resource.NewSimpleResource(msgStream, ReplicateMsgStreamTyp, channel, ReplicateMsgStreamExpireTime, func() {
@ -55,7 +55,7 @@ func (m *ReplicateStreamManager) newMsgStreamResource(channel string) resource.N
func (m *ReplicateStreamManager) GetReplicateMsgStream(ctx context.Context, channel string) (msgstream.MsgStream, error) {
ctxLog := log.Ctx(ctx).With(zap.String("proxy_channel", channel))
res, err := m.resourceManager.Get(ReplicateMsgStreamTyp, channel, m.newMsgStreamResource(channel))
res, err := m.resourceManager.Get(ReplicateMsgStreamTyp, channel, m.newMsgStreamResource(ctx, channel))
if err != nil {
ctxLog.Warn("failed to get replicate msg stream", zap.String("channel", channel), zap.Error(err))
return nil, err

View File

@ -142,7 +142,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
}
dt.tr = timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID()))
stream, err := dt.chMgr.getOrCreateDmlStream(dt.collectionID)
stream, err := dt.chMgr.getOrCreateDmlStream(ctx, dt.collectionID)
if err != nil {
return err
}
@ -178,7 +178,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
zap.Int64("taskID", dt.ID()),
zap.Duration("prepare duration", dt.tr.RecordSpan()))
err = stream.Produce(msgPack)
err = stream.Produce(ctx, msgPack)
if err != nil {
return err
}

View File

@ -161,7 +161,7 @@ func TestDeleteTask_Execute(t *testing.T) {
},
}
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(nil, errors.New("mock error"))
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
assert.Error(t, dt.Execute(context.Background()))
})
@ -190,7 +190,7 @@ func TestDeleteTask_Execute(t *testing.T) {
primaryKeys: pk,
}
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
assert.Error(t, dt.Execute(context.Background()))
})
@ -226,8 +226,8 @@ func TestDeleteTask_Execute(t *testing.T) {
primaryKeys: pk,
}
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil)
stream.EXPECT().Produce(mock.Anything).Return(errors.New("mock error"))
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error"))
assert.Error(t, dt.Execute(context.Background()))
})
}
@ -535,9 +535,9 @@ func TestDeleteRunner_Run(t *testing.T) {
},
}
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
stream.EXPECT().Produce(mock.Anything).Return(fmt.Errorf("mock error"))
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error"))
assert.Error(t, dr.Run(context.Background()))
assert.Equal(t, int64(0), dr.result.DeleteCnt)
@ -644,9 +644,9 @@ func TestDeleteRunner_Run(t *testing.T) {
},
}
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
stream.EXPECT().Produce(mock.Anything).Return(nil)
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
@ -768,7 +768,7 @@ func TestDeleteRunner_Run(t *testing.T) {
},
}
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
@ -792,7 +792,7 @@ func TestDeleteRunner_Run(t *testing.T) {
server.FinishSend(nil)
return client
}, nil)
stream.EXPECT().Produce(mock.Anything).Return(errors.New("mock error"))
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error"))
assert.Error(t, dr.Run(ctx))
assert.Equal(t, int64(0), dr.result.DeleteCnt)
@ -830,7 +830,7 @@ func TestDeleteRunner_Run(t *testing.T) {
},
}
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
@ -854,7 +854,7 @@ func TestDeleteRunner_Run(t *testing.T) {
server.FinishSend(nil)
return client
}, nil)
stream.EXPECT().Produce(mock.Anything).Return(nil)
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
assert.NoError(t, dr.Run(ctx))
assert.Equal(t, int64(3), dr.result.DeleteCnt)
@ -911,7 +911,7 @@ func TestDeleteRunner_Run(t *testing.T) {
},
}
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
@ -936,7 +936,7 @@ func TestDeleteRunner_Run(t *testing.T) {
return client
}, nil)
stream.EXPECT().Produce(mock.Anything).Return(nil)
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
assert.NoError(t, dr.Run(ctx))
assert.Equal(t, int64(3), dr.result.DeleteCnt)
})

View File

@ -243,7 +243,7 @@ func (it *insertTask) Execute(ctx context.Context) error {
it.insertMsg.CollectionID = collID
getCacheDur := tr.RecordSpan()
stream, err := it.chMgr.getOrCreateDmlStream(collID)
stream, err := it.chMgr.getOrCreateDmlStream(ctx, collID)
if err != nil {
return err
}
@ -280,7 +280,7 @@ func (it *insertTask) Execute(ctx context.Context) error {
log.Debug("assign segmentID for insert data success",
zap.Duration("assign segmentID duration", assignSegmentIDDur))
err = stream.Produce(msgPack)
err = stream.Produce(ctx, msgPack)
if err != nil {
log.Warn("fail to produce insert msg", zap.Error(err))
it.result.Status = merr.Status(err)

View File

@ -1755,7 +1755,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream()
_, err = chMgr.getOrCreateDmlStream(collectionID)
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
assert.NoError(t, err)
pchans, err := chMgr.getChannels(collectionID)
assert.NoError(t, err)
@ -2004,7 +2004,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream()
_, err = chMgr.getOrCreateDmlStream(collectionID)
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
assert.NoError(t, err)
pchans, err := chMgr.getChannels(collectionID)
assert.NoError(t, err)
@ -3460,7 +3460,7 @@ func TestPartitionKey(t *testing.T) {
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream()
_, err = chMgr.getOrCreateDmlStream(collectionID)
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
assert.NoError(t, err)
pchans, err := chMgr.getChannels(collectionID)
assert.NoError(t, err)

View File

@ -393,7 +393,7 @@ func (it *upsertTask) insertExecute(ctx context.Context, msgPack *msgstream.MsgP
zap.Int64("collectionID", collID))
getCacheDur := tr.RecordSpan()
_, err = it.chMgr.getOrCreateDmlStream(collID)
_, err = it.chMgr.getOrCreateDmlStream(ctx, collID)
if err != nil {
return err
}
@ -526,7 +526,7 @@ func (it *upsertTask) Execute(ctx context.Context) (err error) {
log := log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName))
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute upsert %d", it.ID()))
stream, err := it.chMgr.getOrCreateDmlStream(it.collectionID)
stream, err := it.chMgr.getOrCreateDmlStream(ctx, it.collectionID)
if err != nil {
return err
}
@ -547,7 +547,7 @@ func (it *upsertTask) Execute(ctx context.Context) (err error) {
}
tr.RecordSpan()
err = stream.Produce(msgPack)
err = stream.Produce(ctx, msgPack)
if err != nil {
it.result.Status = merr.Status(err)
return err

View File

@ -1985,7 +1985,7 @@ func SendReplicateMessagePack(ctx context.Context, replicateMsgStream msgstream.
EndTs: ts,
Msgs: []msgstream.TsMsg{tsMsg},
}
msgErr := replicateMsgStream.Produce(msgPack)
msgErr := replicateMsgStream.Produce(ctx, msgPack)
// ignore the error if the msg stream failed to produce the msg,
// because it can be manually fixed in this error
if msgErr != nil {

View File

@ -2430,7 +2430,7 @@ func TestSendReplicateMessagePack(t *testing.T) {
})
t.Run("produce fail", func(t *testing.T) {
mockStream.EXPECT().Produce(mock.Anything).Return(errors.New("produce error")).Once()
mockStream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("produce error")).Once()
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{
Base: &commonpb.MsgBase{ReplicateInfo: &commonpb.ReplicateInfo{
IsReplicate: true,
@ -2444,7 +2444,7 @@ func TestSendReplicateMessagePack(t *testing.T) {
})
t.Run("normal case", func(t *testing.T) {
mockStream.EXPECT().Produce(mock.Anything).Return(nil)
mockStream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropDatabaseRequest{})

View File

@ -188,7 +188,7 @@ func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePref
d.checkPreCreatedTopic(ctx, factory, name)
}
ms.AsProducer([]string{name})
ms.AsProducer(ctx, []string{name})
dms := &dmlMsgStream{
ms: ms,
refcnt: 0,
@ -291,7 +291,7 @@ func (d *dmlChannels) broadcast(chanNames []string, pack *msgstream.MsgPack) err
dms.mutex.RLock()
if dms.refcnt > 0 {
if _, err := dms.ms.Broadcast(pack); err != nil {
if _, err := dms.ms.Broadcast(d.ctx, pack); err != nil {
log.Error("Broadcast failed", zap.Error(err), zap.String("chanName", chanName))
dms.mutex.RUnlock()
return err
@ -312,7 +312,7 @@ func (d *dmlChannels) broadcastMark(chanNames []string, pack *msgstream.MsgPack)
dms.mutex.RLock()
if dms.refcnt > 0 {
ids, err := dms.ms.Broadcast(pack)
ids, err := dms.ms.Broadcast(d.ctx, pack)
if err != nil {
log.Error("BroadcastMark failed", zap.Error(err), zap.String("chanName", chanName))
dms.mutex.RUnlock()

View File

@ -277,17 +277,17 @@ type FailMsgStream struct {
errBroadcast bool
}
func (ms *FailMsgStream) Close() {}
func (ms *FailMsgStream) Chan() <-chan *msgstream.MsgPack { return nil }
func (ms *FailMsgStream) AsProducer(channels []string) {}
func (ms *FailMsgStream) AsReader(channels []string, subName string) {}
func (ms *FailMsgStream) Close() {}
func (ms *FailMsgStream) Chan() <-chan *msgstream.MsgPack { return nil }
func (ms *FailMsgStream) AsProducer(ctx context.Context, channels []string) {}
func (ms *FailMsgStream) AsReader(channels []string, subName string) {}
func (ms *FailMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error {
return nil
}
func (ms *FailMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {}
func (ms *FailMsgStream) GetProduceChannels() []string { return nil }
func (ms *FailMsgStream) Produce(*msgstream.MsgPack) error { return nil }
func (ms *FailMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
func (ms *FailMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {}
func (ms *FailMsgStream) GetProduceChannels() []string { return nil }
func (ms *FailMsgStream) Produce(context.Context, *msgstream.MsgPack) error { return nil }
func (ms *FailMsgStream) Broadcast(context.Context, *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
if ms.errBroadcast {
return nil, errors.New("broadcast error")
}

View File

@ -42,8 +42,8 @@ func TestInputNode(t *testing.T) {
msgPack := generateMsgPack()
produceStream, _ := factory.NewMsgStream(context.TODO())
produceStream.AsProducer(channels)
produceStream.Produce(&msgPack)
produceStream.AsProducer(context.TODO(), channels)
produceStream.Produce(context.TODO(), &msgPack)
nodeName := "input_node"
inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "")
@ -84,7 +84,7 @@ func Test_InputNodeSkipMode(t *testing.T) {
msgStream.AsConsumer(context.Background(), channels, "sub", common.SubscriptionPositionEarliest)
produceStream, _ := factory.NewMsgStream(context.TODO())
produceStream.AsProducer(channels)
produceStream.AsProducer(context.TODO(), channels)
closeCh := make(chan struct{})
outputCh := make(chan bool)
@ -110,7 +110,7 @@ func Test_InputNodeSkipMode(t *testing.T) {
defer close(closeCh)
msgPack := generateMsgPack()
produceStream.Produce(&msgPack)
produceStream.Produce(context.TODO(), &msgPack)
log.Info("produce empty ttmsg")
<-outputCh
assert.Equal(t, 1, outputCount)
@ -118,7 +118,7 @@ func Test_InputNodeSkipMode(t *testing.T) {
time.Sleep(3 * time.Second)
assert.Equal(t, false, inputNode.skipMode)
produceStream.Produce(&msgPack)
produceStream.Produce(context.TODO(), &msgPack)
log.Info("after 3 seconds with no active msg receive, input node will turn on skip mode")
<-outputCh
assert.Equal(t, 2, outputCount)
@ -126,13 +126,13 @@ func Test_InputNodeSkipMode(t *testing.T) {
log.Info("some ttmsg will be skipped in skip mode")
// this msg will be skipped
produceStream.Produce(&msgPack)
produceStream.Produce(context.TODO(), &msgPack)
<-outputCh
assert.Equal(t, 2, outputCount)
assert.Equal(t, true, inputNode.skipMode)
// this msg will be consumed
produceStream.Produce(&msgPack)
produceStream.Produce(context.TODO(), &msgPack)
<-outputCh
assert.Equal(t, 3, outputCount)
assert.Equal(t, true, inputNode.skipMode)

View File

@ -80,13 +80,13 @@ func TestNodeManager_Start(t *testing.T) {
msgStream.AsConsumer(context.TODO(), channels, "sub", common.SubscriptionPositionEarliest)
produceStream, _ := factory.NewMsgStream(context.TODO())
produceStream.AsProducer(channels)
produceStream.AsProducer(context.TODO(), channels)
msgPack := generateMsgPack()
produceStream.Produce(&msgPack)
produceStream.Produce(context.TODO(), &msgPack)
time.Sleep(time.Millisecond * 2)
msgPack = generateMsgPack()
produceStream.Produce(&msgPack)
produceStream.Produce(context.TODO(), &msgPack)
nodeName := "input_node"
inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "")

View File

@ -226,7 +226,7 @@ func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64)
insNum := rand.Intn(10)
for j := 0; j < insNum; j++ {
vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string)
err := suite.producer.Produce(&msgstream.MsgPack{
err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genInsertMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)},
})
assert.NoError(suite.T(), err)
@ -237,7 +237,7 @@ func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64)
delNum := rand.Intn(2)
for j := 0; j < delNum; j++ {
vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string)
err := suite.producer.Produce(&msgstream.MsgPack{
err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genDeleteMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)},
})
assert.NoError(suite.T(), err)
@ -247,7 +247,7 @@ func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64)
// produce random ddl
ddlNum := rand.Intn(2)
for j := 0; j < ddlNum; j++ {
err := suite.producer.Produce(&msgstream.MsgPack{
err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genDDLMsg(commonpb.MsgType_DropCollection, collectionID)},
})
assert.NoError(suite.T(), err)
@ -257,7 +257,7 @@ func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64)
}
// produce time tick
ts := uint64(i * 100)
err := suite.producer.Produce(&msgstream.MsgPack{
err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
})
assert.NoError(suite.T(), err)
@ -305,7 +305,7 @@ func (suite *SimulationSuite) produceTimeTickOnly(ctx context.Context) {
return
case <-ticker.C:
ts := uint64(tt * 1000)
err := suite.producer.Produce(&msgstream.MsgPack{
err := suite.producer.Produce(ctx, &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
})
assert.NoError(suite.T(), err)

View File

@ -55,7 +55,7 @@ func newMockProducer(factory msgstream.Factory, pchannel string) (msgstream.MsgS
if err != nil {
return nil, err
}
stream.AsProducer([]string{pchannel})
stream.AsProducer(context.TODO(), []string{pchannel})
stream.SetRepackFunc(defaultInsertRepackFunc)
return stream, nil
}

View File

@ -173,11 +173,11 @@ func testTimeTickerAndInsert(t *testing.T, f []Factory) {
defer consumer.Close()
var err error
_, err = producer.Broadcast(&msgPack0)
_, err = producer.Broadcast(ctx, &msgPack0)
assert.NoError(t, err)
err = producer.Produce(&msgPack1)
err = producer.Produce(ctx, &msgPack1)
assert.NoError(t, err)
_, err = producer.Broadcast(&msgPack2)
_, err = producer.Broadcast(ctx, &msgPack2)
assert.NoError(t, err)
receiveAndValidateMsg(ctx, consumer, len(msgPack1.Msgs))
@ -210,17 +210,17 @@ func testTimeTickerNoSeek(t *testing.T, f []Factory) {
defer producer.Close()
var err error
_, err = producer.Broadcast(&msgPack0)
_, err = producer.Broadcast(ctx, &msgPack0)
assert.NoError(t, err)
err = producer.Produce(&msgPack1)
err = producer.Produce(ctx, &msgPack1)
assert.NoError(t, err)
_, err = producer.Broadcast(&msgPack2)
_, err = producer.Broadcast(ctx, &msgPack2)
assert.NoError(t, err)
err = producer.Produce(&msgPack3)
err = producer.Produce(ctx, &msgPack3)
assert.NoError(t, err)
_, err = producer.Broadcast(&msgPack4)
_, err = producer.Broadcast(ctx, &msgPack4)
assert.NoError(t, err)
_, err = producer.Broadcast(&msgPack5)
_, err = producer.Broadcast(ctx, &msgPack5)
assert.NoError(t, err)
o1 := consume(ctx, consumer)
@ -259,7 +259,7 @@ func testSeekToLast(t *testing.T, f []Factory) {
}
// produce test data
err := producer.Produce(msgPack)
err := producer.Produce(ctx, msgPack)
assert.NoError(t, err)
// pick a seekPosition
@ -346,21 +346,21 @@ func testTimeTickerSeek(t *testing.T, f []Factory) {
defer producer.Close()
// Send message
_, err := producer.Broadcast(&msgPack0)
_, err := producer.Broadcast(ctx, &msgPack0)
assert.NoError(t, err)
err = producer.Produce(&msgPack1)
err = producer.Produce(ctx, &msgPack1)
assert.NoError(t, err)
_, err = producer.Broadcast(&msgPack2)
_, err = producer.Broadcast(ctx, &msgPack2)
assert.NoError(t, err)
err = producer.Produce(&msgPack3)
err = producer.Produce(ctx, &msgPack3)
assert.NoError(t, err)
_, err = producer.Broadcast(&msgPack4)
_, err = producer.Broadcast(ctx, &msgPack4)
assert.NoError(t, err)
err = producer.Produce(&msgPack5)
err = producer.Produce(ctx, &msgPack5)
assert.NoError(t, err)
_, err = producer.Broadcast(&msgPack6)
_, err = producer.Broadcast(ctx, &msgPack6)
assert.NoError(t, err)
_, err = producer.Broadcast(&msgPack7)
_, err = producer.Broadcast(ctx, &msgPack7)
assert.NoError(t, err)
// Test received message
@ -434,13 +434,13 @@ func testTimeTickUnmarshalHeader(t *testing.T, f []Factory) {
defer producer.Close()
defer consumer.Close()
_, err := producer.Broadcast(&msgPack0)
_, err := producer.Broadcast(ctx, &msgPack0)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
err = producer.Produce(&msgPack1)
err = producer.Produce(ctx, &msgPack1)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
_, err = producer.Broadcast(&msgPack2)
_, err = producer.Broadcast(ctx, &msgPack2)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveAndValidateMsg(ctx, consumer, len(msgPack1.Msgs))
@ -571,7 +571,7 @@ func testMqMsgStreamSeek(t *testing.T, f []Factory) {
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
err := producer.Produce(msgPack)
err := producer.Produce(ctx, msgPack)
assert.NoError(t, err)
var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ {
@ -605,7 +605,7 @@ func testMqMsgStreamSeekInvalidMessage(t *testing.T, f []Factory, pg positionGen
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
err := producer.Produce(msgPack)
err := producer.Produce(ctx, msgPack)
assert.NoError(t, err)
var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ {
@ -622,7 +622,7 @@ func testMqMsgStreamSeekInvalidMessage(t *testing.T, f []Factory, pg positionGen
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
err = producer.Produce(msgPack)
err = producer.Produce(ctx, msgPack)
assert.NoError(t, err)
result := consume(ctx, consumer2)
assert.Equal(t, result.Msgs[0].ID(), int64(1))
@ -642,7 +642,7 @@ func testMqMsgStreamSeekLatest(t *testing.T, f []Factory) {
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
err := producer.Produce(msgPack)
err := producer.Produce(ctx, msgPack)
assert.NoError(t, err)
consumer2 := createLatestConsumer(ctx, t, f[1].NewMsgStream, channels)
defer consumer2.Close()
@ -653,7 +653,7 @@ func testMqMsgStreamSeekLatest(t *testing.T, f []Factory) {
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
err = producer.Produce(msgPack)
err = producer.Produce(ctx, msgPack)
assert.NoError(t, err)
for i := 10; i < 20; i++ {
@ -673,7 +673,7 @@ func testBroadcastMark(t *testing.T, f []Factory) {
msgPack0 := MsgPack{}
msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0))
ids, err := producer.Broadcast(&msgPack0)
ids, err := producer.Broadcast(ctx, &msgPack0)
assert.NoError(t, err)
assert.NotNil(t, ids)
assert.Equal(t, len(channels), len(ids))
@ -687,7 +687,7 @@ func testBroadcastMark(t *testing.T, f []Factory) {
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3))
ids, err = producer.Broadcast(&msgPack1)
ids, err = producer.Broadcast(ctx, &msgPack1)
assert.NoError(t, err)
assert.NotNil(t, ids)
assert.Equal(t, len(channels), len(ids))
@ -698,12 +698,12 @@ func testBroadcastMark(t *testing.T, f []Factory) {
}
// edge cases
_, err = producer.Broadcast(nil)
_, err = producer.Broadcast(ctx, nil)
assert.Error(t, err)
msgPack2 := MsgPack{}
msgPack2.Msgs = append(msgPack2.Msgs, &MarshalFailTsMsg{})
_, err = producer.Broadcast(&msgPack2)
_, err = producer.Broadcast(ctx, &msgPack2)
assert.Error(t, err)
}
@ -712,7 +712,7 @@ func applyBroadCastAndConsume(t *testing.T, msgPack *MsgPack, newer []streamNewe
defer producer.Close()
defer consumer.Close()
_, err := producer.Broadcast(msgPack)
_, err := producer.Broadcast(context.TODO(), msgPack)
assert.NoError(t, err)
receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs)*channelNum)
}
@ -728,7 +728,7 @@ func applyProduceAndConsumeWithRepack(
defer producer.Close()
defer consumer.Close()
err := producer.Produce(msgPack)
err := producer.Produce(context.TODO(), msgPack)
assert.NoError(t, err)
receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs))
}
@ -743,7 +743,7 @@ func applyProduceAndConsume(
defer producer.Close()
defer consumer.Close()
err := producer.Produce(msgPack)
err := producer.Produce(context.TODO(), msgPack)
assert.NoError(t, err)
receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs))
}
@ -774,7 +774,7 @@ func createAndSeekConsumer(ctx context.Context, t *testing.T, newer streamNewer,
func createProducer(ctx context.Context, t *testing.T, newer streamNewer, channels []string) MsgStream {
producer, err := newer(ctx)
assert.NoError(t, err)
producer.AsProducer(channels)
producer.AsProducer(ctx, channels)
return producer
}
@ -798,7 +798,7 @@ func createStream(ctx context.Context, t *testing.T, newer []streamNewer, channe
assert.NotEmpty(t, channels)
producer, err := newer[0](ctx)
assert.NoError(t, err)
producer.AsProducer(channels)
producer.AsProducer(ctx, channels)
consumer, err := newer[1](ctx)
assert.NoError(t, err)

View File

@ -74,9 +74,9 @@ func (_c *MockMsgStream_AsConsumer_Call) RunAndReturn(run func(context.Context,
return _c
}
// AsProducer provides a mock function with given fields: channels
func (_m *MockMsgStream) AsProducer(channels []string) {
_m.Called(channels)
// AsProducer provides a mock function with given fields: ctx, channels
func (_m *MockMsgStream) AsProducer(ctx context.Context, channels []string) {
_m.Called(ctx, channels)
}
// MockMsgStream_AsProducer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AsProducer'
@ -85,14 +85,15 @@ type MockMsgStream_AsProducer_Call struct {
}
// AsProducer is a helper method to define mock.On call
// - ctx context.Context
// - channels []string
func (_e *MockMsgStream_Expecter) AsProducer(channels interface{}) *MockMsgStream_AsProducer_Call {
return &MockMsgStream_AsProducer_Call{Call: _e.mock.On("AsProducer", channels)}
func (_e *MockMsgStream_Expecter) AsProducer(ctx interface{}, channels interface{}) *MockMsgStream_AsProducer_Call {
return &MockMsgStream_AsProducer_Call{Call: _e.mock.On("AsProducer", ctx, channels)}
}
func (_c *MockMsgStream_AsProducer_Call) Run(run func(channels []string)) *MockMsgStream_AsProducer_Call {
func (_c *MockMsgStream_AsProducer_Call) Run(run func(ctx context.Context, channels []string)) *MockMsgStream_AsProducer_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]string))
run(args[0].(context.Context), args[1].([]string))
})
return _c
}
@ -102,14 +103,14 @@ func (_c *MockMsgStream_AsProducer_Call) Return() *MockMsgStream_AsProducer_Call
return _c
}
func (_c *MockMsgStream_AsProducer_Call) RunAndReturn(run func([]string)) *MockMsgStream_AsProducer_Call {
func (_c *MockMsgStream_AsProducer_Call) RunAndReturn(run func(context.Context, []string)) *MockMsgStream_AsProducer_Call {
_c.Call.Return(run)
return _c
}
// Broadcast provides a mock function with given fields: _a0
func (_m *MockMsgStream) Broadcast(_a0 *MsgPack) (map[string][]common.MessageID, error) {
ret := _m.Called(_a0)
// Broadcast provides a mock function with given fields: _a0, _a1
func (_m *MockMsgStream) Broadcast(_a0 context.Context, _a1 *MsgPack) (map[string][]common.MessageID, error) {
ret := _m.Called(_a0, _a1)
if len(ret) == 0 {
panic("no return value specified for Broadcast")
@ -117,19 +118,19 @@ func (_m *MockMsgStream) Broadcast(_a0 *MsgPack) (map[string][]common.MessageID,
var r0 map[string][]common.MessageID
var r1 error
if rf, ok := ret.Get(0).(func(*MsgPack) (map[string][]common.MessageID, error)); ok {
return rf(_a0)
if rf, ok := ret.Get(0).(func(context.Context, *MsgPack) (map[string][]common.MessageID, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(*MsgPack) map[string][]common.MessageID); ok {
r0 = rf(_a0)
if rf, ok := ret.Get(0).(func(context.Context, *MsgPack) map[string][]common.MessageID); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string][]common.MessageID)
}
}
if rf, ok := ret.Get(1).(func(*MsgPack) error); ok {
r1 = rf(_a0)
if rf, ok := ret.Get(1).(func(context.Context, *MsgPack) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
@ -143,14 +144,15 @@ type MockMsgStream_Broadcast_Call struct {
}
// Broadcast is a helper method to define mock.On call
// - _a0 *MsgPack
func (_e *MockMsgStream_Expecter) Broadcast(_a0 interface{}) *MockMsgStream_Broadcast_Call {
return &MockMsgStream_Broadcast_Call{Call: _e.mock.On("Broadcast", _a0)}
// - _a0 context.Context
// - _a1 *MsgPack
func (_e *MockMsgStream_Expecter) Broadcast(_a0 interface{}, _a1 interface{}) *MockMsgStream_Broadcast_Call {
return &MockMsgStream_Broadcast_Call{Call: _e.mock.On("Broadcast", _a0, _a1)}
}
func (_c *MockMsgStream_Broadcast_Call) Run(run func(_a0 *MsgPack)) *MockMsgStream_Broadcast_Call {
func (_c *MockMsgStream_Broadcast_Call) Run(run func(_a0 context.Context, _a1 *MsgPack)) *MockMsgStream_Broadcast_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*MsgPack))
run(args[0].(context.Context), args[1].(*MsgPack))
})
return _c
}
@ -160,7 +162,7 @@ func (_c *MockMsgStream_Broadcast_Call) Return(_a0 map[string][]common.MessageID
return _c
}
func (_c *MockMsgStream_Broadcast_Call) RunAndReturn(run func(*MsgPack) (map[string][]common.MessageID, error)) *MockMsgStream_Broadcast_Call {
func (_c *MockMsgStream_Broadcast_Call) RunAndReturn(run func(context.Context, *MsgPack) (map[string][]common.MessageID, error)) *MockMsgStream_Broadcast_Call {
_c.Call.Return(run)
return _c
}
@ -428,17 +430,17 @@ func (_c *MockMsgStream_GetProduceChannels_Call) RunAndReturn(run func() []strin
return _c
}
// Produce provides a mock function with given fields: _a0
func (_m *MockMsgStream) Produce(_a0 *MsgPack) error {
ret := _m.Called(_a0)
// Produce provides a mock function with given fields: _a0, _a1
func (_m *MockMsgStream) Produce(_a0 context.Context, _a1 *MsgPack) error {
ret := _m.Called(_a0, _a1)
if len(ret) == 0 {
panic("no return value specified for Produce")
}
var r0 error
if rf, ok := ret.Get(0).(func(*MsgPack) error); ok {
r0 = rf(_a0)
if rf, ok := ret.Get(0).(func(context.Context, *MsgPack) error); ok {
r0 = rf(_a0, _a1)
} else {
r0 = ret.Error(0)
}
@ -452,14 +454,15 @@ type MockMsgStream_Produce_Call struct {
}
// Produce is a helper method to define mock.On call
// - _a0 *MsgPack
func (_e *MockMsgStream_Expecter) Produce(_a0 interface{}) *MockMsgStream_Produce_Call {
return &MockMsgStream_Produce_Call{Call: _e.mock.On("Produce", _a0)}
// - _a0 context.Context
// - _a1 *MsgPack
func (_e *MockMsgStream_Expecter) Produce(_a0 interface{}, _a1 interface{}) *MockMsgStream_Produce_Call {
return &MockMsgStream_Produce_Call{Call: _e.mock.On("Produce", _a0, _a1)}
}
func (_c *MockMsgStream_Produce_Call) Run(run func(_a0 *MsgPack)) *MockMsgStream_Produce_Call {
func (_c *MockMsgStream_Produce_Call) Run(run func(_a0 context.Context, _a1 *MsgPack)) *MockMsgStream_Produce_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*MsgPack))
run(args[0].(context.Context), args[1].(*MsgPack))
})
return _c
}
@ -469,7 +472,7 @@ func (_c *MockMsgStream_Produce_Call) Return(_a0 error) *MockMsgStream_Produce_C
return _c
}
func (_c *MockMsgStream_Produce_Call) RunAndReturn(run func(*MsgPack) error) *MockMsgStream_Produce_Call {
func (_c *MockMsgStream_Produce_Call) RunAndReturn(run func(context.Context, *MsgPack) error) *MockMsgStream_Produce_Call {
_c.Call.Return(run)
return _c
}

View File

@ -123,7 +123,7 @@ func TestStream_KafkaMsgStream_SeekToLast(t *testing.T) {
}
// produce test data
err := inputStream.Produce(msgPack)
err := inputStream.Produce(ctx, msgPack)
assert.NoError(t, err)
// pick a seekPosition
@ -219,21 +219,21 @@ func TestStream_KafkaTtMsgStream_Seek(t *testing.T) {
inputStream := getKafkaInputStream(ctx, kafkaAddress, producerChannels)
outputStream := getKafkaTtOutputStream(ctx, kafkaAddress, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0)
_, err := inputStream.Broadcast(ctx, &msgPack0)
assert.NoError(t, err)
err = inputStream.Produce(&msgPack1)
err = inputStream.Produce(ctx, &msgPack1)
assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack2)
_, err = inputStream.Broadcast(ctx, &msgPack2)
assert.NoError(t, err)
err = inputStream.Produce(&msgPack3)
err = inputStream.Produce(ctx, &msgPack3)
assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack4)
_, err = inputStream.Broadcast(ctx, &msgPack4)
assert.NoError(t, err)
err = inputStream.Produce(&msgPack5)
err = inputStream.Produce(ctx, &msgPack5)
assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack6)
_, err = inputStream.Broadcast(ctx, &msgPack6)
assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack7)
_, err = inputStream.Broadcast(ctx, &msgPack7)
assert.NoError(t, err)
receivedMsg := consumer(ctx, outputStream)
@ -450,7 +450,7 @@ func getKafkaInputStream(ctx context.Context, kafkaAddress string, producerChann
}
kafkaClient := kafkawrapper.NewKafkaClientInstanceWithConfigMap(config, nil, nil)
inputStream, _ := NewMqMsgStream(ctx, 100, 100, kafkaClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels)
inputStream.AsProducer(ctx, producerChannels)
for _, opt := range opts {
inputStream.SetRepackFunc(opt)
}

View File

@ -121,7 +121,7 @@ func NewMqMsgStream(ctx context.Context,
}
// AsProducer create producer to send message to channels
func (ms *mqMsgStream) AsProducer(channels []string) {
func (ms *mqMsgStream) AsProducer(ctx context.Context, channels []string) {
for _, channel := range channels {
if len(channel) == 0 {
log.Error("MsgStream asProducer's channel is an empty string")
@ -129,7 +129,7 @@ func (ms *mqMsgStream) AsProducer(channels []string) {
}
fn := func() error {
pp, err := ms.client.CreateProducer(common.ProducerOptions{Topic: channel, EnableCompression: true})
pp, err := ms.client.CreateProducer(ctx, common.ProducerOptions{Topic: channel, EnableCompression: true})
if err != nil {
return err
}
@ -176,7 +176,7 @@ func (ms *mqMsgStream) AsConsumer(ctx context.Context, channels []string, subNam
continue
}
fn := func() error {
pc, err := ms.client.Subscribe(mqwrapper.ConsumerOptions{
pc, err := ms.client.Subscribe(ctx, mqwrapper.ConsumerOptions{
Topic: channel,
SubscriptionName: subName,
SubscriptionInitialPosition: position,
@ -273,7 +273,7 @@ func (ms *mqMsgStream) isEnabledProduce() bool {
return ms.enableProduce.Load().(bool)
}
func (ms *mqMsgStream) Produce(msgPack *MsgPack) error {
func (ms *mqMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error {
if !ms.isEnabledProduce() {
log.Warn("can't produce the msg in the backup instance", zap.Stack("stack"))
return merr.ErrDenyProduceMsg
@ -346,7 +346,7 @@ func (ms *mqMsgStream) Produce(msgPack *MsgPack) error {
// BroadcastMark broadcast msg pack to all producers and returns corresponding msg id
// the returned message id serves as marking
func (ms *mqMsgStream) Broadcast(msgPack *MsgPack) (map[string][]MessageID, error) {
func (ms *mqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) (map[string][]MessageID, error) {
ids := make(map[string][]MessageID)
if msgPack == nil || len(msgPack.Msgs) <= 0 {
return ids, errors.New("empty msgs")
@ -581,7 +581,7 @@ func (ms *MqTtMsgStream) AsConsumer(ctx context.Context, channels []string, subN
continue
}
fn := func() error {
pc, err := ms.client.Subscribe(mqwrapper.ConsumerOptions{
pc, err := ms.client.Subscribe(ctx, mqwrapper.ConsumerOptions{
Topic: channel,
SubscriptionName: subName,
SubscriptionInitialPosition: position,

View File

@ -130,12 +130,12 @@ func TestStream_PulsarMsgStream_Insert(t *testing.T) {
{
inputStream.EnableProduce(false)
err := inputStream.Produce(&msgPack)
err := inputStream.Produce(ctx, &msgPack)
require.Error(t, err)
}
inputStream.EnableProduce(true)
err := inputStream.Produce(&msgPack)
err := inputStream.Produce(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
@ -156,7 +156,7 @@ func TestStream_PulsarMsgStream_Delete(t *testing.T) {
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack)
err := inputStream.Produce(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
@ -178,7 +178,7 @@ func TestStream_PulsarMsgStream_TimeTick(t *testing.T) {
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack)
err := inputStream.Produce(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
@ -203,12 +203,12 @@ func TestStream_PulsarMsgStream_BroadCast(t *testing.T) {
{
inputStream.EnableProduce(false)
_, err := inputStream.Broadcast(&msgPack)
_, err := inputStream.Broadcast(ctx, &msgPack)
require.Error(t, err)
}
inputStream.EnableProduce(true)
_, err := inputStream.Broadcast(&msgPack)
_, err := inputStream.Broadcast(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveMsg(ctx, outputStream, len(consumerChannels)*len(msgPack.Msgs))
@ -230,7 +230,7 @@ func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) {
ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels, repackFunc)
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack)
err := inputStream.Produce(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
@ -277,14 +277,14 @@ func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) {
ctx := context.Background()
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels)
inputStream.AsProducer(ctx, producerChannels)
pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest)
var output MsgStream = outputStream
err := (*inputStream).Produce(&msgPack)
err := (*inputStream).Produce(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, output, len(msgPack.Msgs)*2)
@ -328,14 +328,14 @@ func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) {
ctx := context.Background()
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels)
inputStream.AsProducer(ctx, producerChannels)
pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest)
var output MsgStream = outputStream
err := (*inputStream).Produce(&msgPack)
err := (*inputStream).Produce(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, output, len(msgPack.Msgs)*1)
@ -360,14 +360,14 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) {
ctx := context.Background()
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels)
inputStream.AsProducer(ctx, producerChannels)
pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest)
var output MsgStream = outputStream
err := (*inputStream).Produce(&msgPack)
err := (*inputStream).Produce(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, output, len(msgPack.Msgs))
@ -395,13 +395,13 @@ func TestStream_PulsarTtMsgStream_Insert(t *testing.T) {
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0)
_, err := inputStream.Broadcast(ctx, &msgPack0)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
err = inputStream.Produce(&msgPack1)
err = inputStream.Produce(ctx, &msgPack1)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
_, err = inputStream.Broadcast(&msgPack2)
_, err = inputStream.Broadcast(ctx, &msgPack2)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack1.Msgs))
@ -440,17 +440,17 @@ func TestStream_PulsarTtMsgStream_NoSeek(t *testing.T) {
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0)
_, err := inputStream.Broadcast(ctx, &msgPack0)
assert.NoError(t, err)
err = inputStream.Produce(&msgPack1)
err = inputStream.Produce(ctx, &msgPack1)
assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack2)
_, err = inputStream.Broadcast(ctx, &msgPack2)
assert.NoError(t, err)
err = inputStream.Produce(&msgPack3)
err = inputStream.Produce(ctx, &msgPack3)
assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack4)
_, err = inputStream.Broadcast(ctx, &msgPack4)
assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack5)
_, err = inputStream.Broadcast(ctx, &msgPack5)
assert.NoError(t, err)
o1 := consumer(ctx, outputStream)
@ -495,7 +495,7 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) {
}
// produce test data
err := inputStream.Produce(msgPack)
err := inputStream.Produce(ctx, msgPack)
assert.NoError(t, err)
// pick a seekPosition
@ -617,21 +617,21 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0)
_, err := inputStream.Broadcast(ctx, &msgPack0)
assert.NoError(t, err)
err = inputStream.Produce(&msgPack1)
err = inputStream.Produce(ctx, &msgPack1)
assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack2)
_, err = inputStream.Broadcast(ctx, &msgPack2)
assert.NoError(t, err)
err = inputStream.Produce(&msgPack3)
err = inputStream.Produce(ctx, &msgPack3)
assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack4)
_, err = inputStream.Broadcast(ctx, &msgPack4)
assert.NoError(t, err)
err = inputStream.Produce(&msgPack5)
err = inputStream.Produce(ctx, &msgPack5)
assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack6)
_, err = inputStream.Broadcast(ctx, &msgPack6)
assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack7)
_, err = inputStream.Broadcast(ctx, &msgPack7)
assert.NoError(t, err)
receivedMsg := consumer(ctx, outputStream)
@ -711,13 +711,13 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) {
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0)
_, err := inputStream.Broadcast(ctx, &msgPack0)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
err = inputStream.Produce(&msgPack1)
err = inputStream.Produce(ctx, &msgPack1)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
_, err = inputStream.Broadcast(&msgPack2)
_, err = inputStream.Broadcast(ctx, &msgPack2)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack1.Msgs))
@ -748,16 +748,16 @@ func TestStream_PulsarTtMsgStream_DropCollection(t *testing.T) {
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0)
_, err := inputStream.Broadcast(ctx, &msgPack0)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
err = inputStream.Produce(&msgPack1)
err = inputStream.Produce(ctx, &msgPack1)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
_, err = inputStream.Broadcast(&msgPack2)
_, err = inputStream.Broadcast(ctx, &msgPack2)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
_, err = inputStream.Broadcast(&msgPack3)
_, err = inputStream.Broadcast(ctx, &msgPack3)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveMsg(ctx, outputStream, 2)
@ -803,12 +803,12 @@ func sendMsgPacks(ms MsgStream, msgPacks []*MsgPack) error {
printMsgPack(msgPacks[i])
if i%2 == 0 {
// insert msg use Produce
if err := ms.Produce(msgPacks[i]); err != nil {
if err := ms.Produce(context.TODO(), msgPacks[i]); err != nil {
return err
}
} else {
// tt msg use Broadcast
if _, err := ms.Broadcast(msgPacks[i]); err != nil {
if _, err := ms.Broadcast(context.TODO(), msgPacks[i]); err != nil {
return err
}
}
@ -971,7 +971,7 @@ func TestStream_MqMsgStream_Seek(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
err := inputStream.Produce(msgPack)
err := inputStream.Produce(ctx, msgPack)
assert.NoError(t, err)
var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ {
@ -1015,7 +1015,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
err := inputStream.Produce(msgPack)
err := inputStream.Produce(ctx, msgPack)
assert.NoError(t, err)
var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ {
@ -1049,7 +1049,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
err = inputStream.Produce(msgPack)
err = inputStream.Produce(ctx, msgPack)
assert.NoError(t, err)
result := consumer(ctx, outputStream2)
assert.Equal(t, result.Msgs[0].ID(), int64(1))
@ -1074,7 +1074,7 @@ func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
err := inputStream.Produce(msgPack)
err := inputStream.Produce(ctx, msgPack)
assert.NoError(t, err)
var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ {
@ -1086,7 +1086,7 @@ func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) {
// produce timetick for mqtt msgstream seek
msgPack = &MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, getTimeTickMsg(1000))
err = inputStream.Produce(msgPack)
err = inputStream.Produce(ctx, msgPack)
assert.NoError(t, err)
factory := ProtoUDFactory{}
@ -1139,7 +1139,7 @@ func TestStream_MqMsgStream_SeekLatest(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
err := inputStream.Produce(msgPack)
err := inputStream.Produce(ctx, msgPack)
assert.NoError(t, err)
factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
@ -1152,7 +1152,7 @@ func TestStream_MqMsgStream_SeekLatest(t *testing.T) {
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
err = inputStream.Produce(msgPack)
err = inputStream.Produce(ctx, msgPack)
assert.NoError(t, err)
for i := 10; i < 20; i++ {
@ -1169,6 +1169,7 @@ func TestStream_BroadcastMark(t *testing.T) {
c1 := funcutil.RandomString(8)
c2 := funcutil.RandomString(8)
producerChannels := []string{c1, c2}
ctx := context.Background()
factory := ProtoUDFactory{}
pulsarClient, err := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
@ -1177,12 +1178,12 @@ func TestStream_BroadcastMark(t *testing.T) {
assert.NoError(t, err)
// add producer channels
outputStream.AsProducer(producerChannels)
outputStream.AsProducer(ctx, producerChannels)
msgPack0 := MsgPack{}
msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0))
ids, err := outputStream.Broadcast(&msgPack0)
ids, err := outputStream.Broadcast(ctx, &msgPack0)
assert.NoError(t, err)
assert.NotNil(t, ids)
assert.Equal(t, len(producerChannels), len(ids))
@ -1196,7 +1197,7 @@ func TestStream_BroadcastMark(t *testing.T) {
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3))
ids, err = outputStream.Broadcast(&msgPack1)
ids, err = outputStream.Broadcast(ctx, &msgPack1)
assert.NoError(t, err)
assert.NotNil(t, ids)
assert.Equal(t, len(producerChannels), len(ids))
@ -1207,19 +1208,19 @@ func TestStream_BroadcastMark(t *testing.T) {
}
// edge cases
_, err = outputStream.Broadcast(nil)
_, err = outputStream.Broadcast(ctx, nil)
assert.Error(t, err)
msgPack2 := MsgPack{}
msgPack2.Msgs = append(msgPack2.Msgs, &MarshalFailTsMsg{})
_, err = outputStream.Broadcast(&msgPack2)
_, err = outputStream.Broadcast(ctx, &msgPack2)
assert.Error(t, err)
// mock send fail
for k, p := range outputStream.producers {
outputStream.producers[k] = &mockSendFailProducer{Producer: p}
}
_, err = outputStream.Broadcast(&msgPack1)
_, err = outputStream.Broadcast(ctx, &msgPack1)
assert.Error(t, err)
outputStream.Close()
@ -1497,7 +1498,7 @@ func getPulsarInputStream(ctx context.Context, pulsarAddress string, producerCha
factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels)
inputStream.AsProducer(ctx, producerChannels)
for _, opt := range opts {
inputStream.SetRepackFunc(opt)
}

View File

@ -52,7 +52,7 @@ func TestMqMsgStream_AsProducer(t *testing.T) {
assert.NoError(t, err)
// empty channel name
m.AsProducer([]string{""})
m.AsProducer(context.TODO(), []string{""})
}
// TODO(wxyu): add a mock implement of mqwrapper.Client, then inject errors to improve coverage
@ -121,7 +121,7 @@ func TestMqMsgStream_GetProduceChannels(t *testing.T) {
assert.Equal(t, 0, len(chs))
// not empty after AsProducer
m.AsProducer([]string{"a"})
m.AsProducer(context.TODO(), []string{"a"})
chs = m.GetProduceChannels()
assert.Equal(t, 1, len(chs))
}
@ -160,7 +160,7 @@ func TestMqMsgStream_Produce(t *testing.T) {
msgPack := &MsgPack{
Msgs: []TsMsg{insertMsg},
}
err = m.Produce(msgPack)
err = m.Produce(context.TODO(), msgPack)
assert.Error(t, err)
}
@ -173,7 +173,7 @@ func TestMqMsgStream_Broadcast(t *testing.T) {
assert.NoError(t, err)
// Broadcast nil pointer
_, err = m.Broadcast(nil)
_, err = m.Broadcast(context.TODO(), nil)
assert.Error(t, err)
}
@ -241,7 +241,7 @@ func initRmqStream(ctx context.Context,
rmqClient, _ := rmq.NewClientWithDefaultOptions(ctx)
inputStream, _ := NewMqMsgStream(ctx, 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels)
inputStream.AsProducer(ctx, producerChannels)
for _, opt := range opts {
inputStream.SetRepackFunc(opt)
}
@ -265,7 +265,7 @@ func initRmqTtStream(ctx context.Context,
rmqClient, _ := rmq.NewClientWithDefaultOptions(ctx)
inputStream, _ := NewMqMsgStream(ctx, 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels)
inputStream.AsProducer(ctx, producerChannels)
for _, opt := range opts {
inputStream.SetRepackFunc(opt)
}
@ -290,7 +290,7 @@ func TestStream_RmqMsgStream_Insert(t *testing.T) {
ctx := context.Background()
inputStream, outputStream := initRmqStream(ctx, producerChannels, consumerChannels, consumerGroupName)
err := inputStream.Produce(&msgPack)
err := inputStream.Produce(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
@ -316,13 +316,13 @@ func TestStream_RmqTtMsgStream_Insert(t *testing.T) {
ctx := context.Background()
inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0)
_, err := inputStream.Broadcast(ctx, &msgPack0)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
err = inputStream.Produce(&msgPack1)
err = inputStream.Produce(ctx, &msgPack1)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
_, err = inputStream.Broadcast(&msgPack2)
_, err = inputStream.Broadcast(ctx, &msgPack2)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack1.Msgs))
@ -355,13 +355,13 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) {
ctx := context.Background()
inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0)
_, err := inputStream.Broadcast(ctx, &msgPack0)
assert.NoError(t, err)
err = inputStream.Produce(&msgPack1)
err = inputStream.Produce(ctx, &msgPack1)
assert.NoError(t, err)
err = inputStream.Produce(&msgPack2)
err = inputStream.Produce(ctx, &msgPack2)
assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack3)
_, err = inputStream.Broadcast(ctx, &msgPack3)
assert.NoError(t, err)
receivedMsg := consumer(ctx, outputStream)
@ -425,21 +425,21 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) {
ctx := context.Background()
inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0)
_, err := inputStream.Broadcast(ctx, &msgPack0)
assert.NoError(t, err)
err = inputStream.Produce(&msgPack1)
err = inputStream.Produce(ctx, &msgPack1)
assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack2)
_, err = inputStream.Broadcast(ctx, &msgPack2)
assert.NoError(t, err)
err = inputStream.Produce(&msgPack3)
err = inputStream.Produce(ctx, &msgPack3)
assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack4)
_, err = inputStream.Broadcast(ctx, &msgPack4)
assert.NoError(t, err)
err = inputStream.Produce(&msgPack5)
err = inputStream.Produce(ctx, &msgPack5)
assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack6)
_, err = inputStream.Broadcast(ctx, &msgPack6)
assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack7)
_, err = inputStream.Broadcast(ctx, &msgPack7)
assert.NoError(t, err)
receivedMsg := consumer(ctx, outputStream)
@ -512,7 +512,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
err := inputStream.Produce(msgPack)
err := inputStream.Produce(ctx, msgPack)
assert.NoError(t, err)
var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ {
@ -546,7 +546,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
err = inputStream.Produce(msgPack)
err = inputStream.Produce(ctx, msgPack)
assert.NoError(t, err)
result := consumer(ctx, outputStream2)
@ -560,27 +560,28 @@ func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) {
producerChannels := []string{"insert1"}
consumerChannels := []string{"insert1"}
consumerSubName := "subInsert"
ctx := context.Background()
factory := ProtoUDFactory{}
rmqClient, _ := rmq.NewClientWithDefaultOptions(context.Background())
otherInputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
otherInputStream.AsProducer([]string{"root_timetick"})
otherInputStream.Produce(getTimeTickMsgPack(999))
otherInputStream.AsProducer(context.TODO(), []string{"root_timetick"})
otherInputStream.Produce(ctx, getTimeTickMsgPack(999))
inputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels)
inputStream.AsProducer(context.TODO(), producerChannels)
for i := 0; i < 100; i++ {
inputStream.Produce(getTimeTickMsgPack(int64(i)))
inputStream.Produce(ctx, getTimeTickMsgPack(int64(i)))
}
rmqClient2, _ := rmq.NewClientWithDefaultOptions(context.Background())
outputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient2, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, mqcommon.SubscriptionPositionLatest)
inputStream.Produce(getTimeTickMsgPack(1000))
inputStream.Produce(ctx, getTimeTickMsgPack(1000))
pack := <-outputStream.Chan()
assert.NotNil(t, pack)
assert.Equal(t, 1, len(pack.Msgs))

View File

@ -17,16 +17,18 @@
package mqwrapper
import (
"context"
"github.com/milvus-io/milvus/pkg/mq/common"
)
// Client is the interface that provides operations of message queues
type Client interface {
// CreateProducer creates a producer instance
CreateProducer(options common.ProducerOptions) (Producer, error)
CreateProducer(ctx context.Context, options common.ProducerOptions) (Producer, error)
// Subscribe creates a consumer instance and subscribe a topic
Subscribe(options ConsumerOptions) (Consumer, error)
Subscribe(ctx context.Context, options ConsumerOptions) (Consumer, error)
// Get the earliest MessageID
EarliestMessageID() common.MessageID

View File

@ -205,7 +205,7 @@ func (kc *kafkaClient) newConsumerConfig(group string, offset common.Subscriptio
return newConf
}
func (kc *kafkaClient) CreateProducer(options common.ProducerOptions) (mqwrapper.Producer, error) {
func (kc *kafkaClient) CreateProducer(ctx context.Context, options common.ProducerOptions) (mqwrapper.Producer, error) {
start := timerecord.NewTimeRecorder("create producer")
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc()
@ -224,7 +224,7 @@ func (kc *kafkaClient) CreateProducer(options common.ProducerOptions) (mqwrapper
return producer, nil
}
func (kc *kafkaClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
func (kc *kafkaClient) Subscribe(ctx context.Context, options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
start := timerecord.NewTimeRecorder("create consumer")
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc()

View File

@ -64,7 +64,7 @@ func BytesToInt(b []byte) int {
// Consume1 will consume random messages and record the last MessageID it received
func Consume1(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, subName string, c chan mqcommon.MessageID, total *int) {
consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := kc.Subscribe(ctx, mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: subName,
BufSize: 1024,
@ -103,7 +103,7 @@ func Consume1(ctx context.Context, t *testing.T, kc *kafkaClient, topic string,
// Consume2 will consume messages from specified MessageID
func Consume2(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, subName string, msgID mqcommon.MessageID, total *int) {
consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := kc.Subscribe(ctx, mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: subName,
BufSize: 1024,
@ -139,7 +139,7 @@ func Consume2(ctx context.Context, t *testing.T, kc *kafkaClient, topic string,
}
func Consume3(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, subName string, total *int) {
consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := kc.Subscribe(ctx, mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: subName,
BufSize: 1024,
@ -418,7 +418,7 @@ func createConsumer(t *testing.T,
groupID string,
initPosition mqcommon.SubscriptionInitialPosition,
) mqwrapper.Consumer {
consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := kc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: groupID,
BufSize: 1024,
@ -429,7 +429,7 @@ func createConsumer(t *testing.T,
}
func createProducer(t *testing.T, kc *kafkaClient, topic string) mqwrapper.Producer {
producer, err := kc.CreateProducer(mqcommon.ProducerOptions{Topic: topic})
producer, err := kc.CreateProducer(context.TODO(), mqcommon.ProducerOptions{Topic: topic})
assert.NoError(t, err)
assert.NotNil(t, producer)
return producer

View File

@ -23,7 +23,7 @@ func TestKafkaProducer_SendSuccess(t *testing.T) {
rand.Seed(time.Now().UnixNano())
topic := fmt.Sprintf("test-topic-%d", rand.Int())
producer, err := kc.CreateProducer(common.ProducerOptions{Topic: topic})
producer, err := kc.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
assert.NoError(t, err)
assert.NotNil(t, producer)
@ -76,7 +76,7 @@ func TestKafkaProducer_SendFailAfterClose(t *testing.T) {
rand.Seed(time.Now().UnixNano())
topic := fmt.Sprintf("test-topic-%d", rand.Int())
producer, err := kc.CreateProducer(common.ProducerOptions{Topic: topic})
producer, err := kc.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
assert.Nil(t, err)
assert.NotNil(t, producer)

View File

@ -80,7 +80,7 @@ func NewClient(url string, options ...nats.Option) (*nmqClient, error) {
}
// CreateProducer creates a producer for natsmq client
func (nc *nmqClient) CreateProducer(options common.ProducerOptions) (mqwrapper.Producer, error) {
func (nc *nmqClient) CreateProducer(ctx context.Context, options common.ProducerOptions) (mqwrapper.Producer, error) {
start := timerecord.NewTimeRecorder("create producer")
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc()
@ -112,7 +112,7 @@ func (nc *nmqClient) CreateProducer(options common.ProducerOptions) (mqwrapper.P
return &rp, nil
}
func (nc *nmqClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
func (nc *nmqClient) Subscribe(ctx context.Context, options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
start := timerecord.NewTimeRecorder("create consumer")
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc()

View File

@ -86,7 +86,7 @@ func TestNmqClient_CreateProducer(t *testing.T) {
topic := "TestNmqClient_CreateProducer"
proOpts := common.ProducerOptions{Topic: topic}
producer, err := client.CreateProducer(proOpts)
producer, err := client.CreateProducer(context.TODO(), proOpts)
assert.NoError(t, err)
assert.NotNil(t, producer)
defer producer.Close()
@ -102,7 +102,7 @@ func TestNmqClient_CreateProducer(t *testing.T) {
assert.NoError(t, err)
invalidOpts := common.ProducerOptions{Topic: ""}
producer, e := client.CreateProducer(invalidOpts)
producer, e := client.CreateProducer(context.TODO(), invalidOpts)
assert.Nil(t, producer)
assert.Error(t, e)
}
@ -114,7 +114,7 @@ func TestNmqClient_GetLatestMsg(t *testing.T) {
topic := fmt.Sprintf("t2GetLatestMsg-%d", rand.Int())
proOpts := common.ProducerOptions{Topic: topic}
producer, err := client.CreateProducer(proOpts)
producer, err := client.CreateProducer(context.TODO(), proOpts)
assert.NoError(t, err)
defer producer.Close()
@ -135,7 +135,7 @@ func TestNmqClient_GetLatestMsg(t *testing.T) {
BufSize: 1024,
}
consumer, err := client.Subscribe(consumerOpts)
consumer, err := client.Subscribe(context.TODO(), consumerOpts)
assert.NoError(t, err)
expectLastMsg, err := consumer.GetLatestMsgID()
@ -166,13 +166,13 @@ func TestNmqClient_IllegalSubscribe(t *testing.T) {
assert.NotNil(t, client)
defer client.Close()
sub, err := client.Subscribe(mqwrapper.ConsumerOptions{
sub, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: "",
})
assert.Nil(t, sub)
assert.Error(t, err)
sub, err = client.Subscribe(mqwrapper.ConsumerOptions{
sub, err = client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: "123",
SubscriptionName: "",
})
@ -188,7 +188,7 @@ func TestNmqClient_Subscribe(t *testing.T) {
topic := "TestNmqClient_Subscribe"
proOpts := common.ProducerOptions{Topic: topic}
producer, err := client.CreateProducer(proOpts)
producer, err := client.CreateProducer(context.TODO(), proOpts)
assert.NoError(t, err)
assert.NotNil(t, producer)
defer producer.Close()
@ -201,12 +201,12 @@ func TestNmqClient_Subscribe(t *testing.T) {
BufSize: 1024,
}
consumer, err := client.Subscribe(consumerOpts)
consumer, err := client.Subscribe(context.TODO(), consumerOpts)
assert.Error(t, err)
assert.Nil(t, consumer)
consumerOpts.Topic = topic
consumer, err = client.Subscribe(consumerOpts)
consumer, err = client.Subscribe(context.TODO(), consumerOpts)
assert.NoError(t, err)
assert.NotNil(t, consumer)
defer consumer.Close()

View File

@ -36,10 +36,10 @@ func TestNatsConsumer_Subscription(t *testing.T) {
topic := t.Name()
proOpts := common.ProducerOptions{Topic: topic}
_, err = client.CreateProducer(proOpts)
_, err = client.CreateProducer(context.TODO(), proOpts)
assert.NoError(t, err)
consumer, err := client.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -69,7 +69,7 @@ func Test_BadLatestMessageID(t *testing.T) {
assert.NoError(t, err)
defer client.Close()
consumer, err := client.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -88,10 +88,10 @@ func TestComsumeMessage(t *testing.T) {
defer client.Close()
topic := t.Name()
p, err := client.CreateProducer(common.ProducerOptions{Topic: topic})
p, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
assert.NoError(t, err)
c, err := client.Subscribe(mqwrapper.ConsumerOptions{
c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -149,7 +149,7 @@ func TestNatsConsumer_Close(t *testing.T) {
defer client.Close()
topic := t.Name()
c, err := client.Subscribe(mqwrapper.ConsumerOptions{
c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -177,7 +177,7 @@ func TestNatsClientErrorOnUnsubscribeTwice(t *testing.T) {
assert.NoError(t, err)
defer client.Close()
consumer, err := client.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -199,7 +199,7 @@ func TestCheckTopicValid(t *testing.T) {
defer client.Close()
topic := t.Name()
consumer, err := client.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -220,7 +220,7 @@ func TestCheckTopicValid(t *testing.T) {
assert.Error(t, err)
// not empty topic can pass
pub, err := client.CreateProducer(common.ProducerOptions{
pub, err := client.CreateProducer(context.TODO(), common.ProducerOptions{
Topic: topic,
})
assert.NoError(t, err)
@ -240,7 +240,7 @@ func TestCheckTopicValid(t *testing.T) {
func newTestConsumer(t *testing.T, topic string, position common.SubscriptionInitialPosition) (mqwrapper.Consumer, error) {
client, err := createNmqClient()
assert.NoError(t, err)
return client.Subscribe(mqwrapper.ConsumerOptions{
return client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: topic,
SubscriptionInitialPosition: position,
@ -251,7 +251,7 @@ func newTestConsumer(t *testing.T, topic string, position common.SubscriptionIni
func newProducer(t *testing.T, topic string) (*nmqClient, mqwrapper.Producer) {
client, err := createNmqClient()
assert.NoError(t, err)
producer, err := client.CreateProducer(common.ProducerOptions{Topic: topic})
producer, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
assert.NoError(t, err)
return client, producer
}
@ -272,10 +272,10 @@ func TestNmqConsumer_GetLatestMsgID(t *testing.T) {
defer client.Close()
topic := t.Name()
p, err := client.CreateProducer(common.ProducerOptions{Topic: topic})
p, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
assert.NoError(t, err)
c, err := client.Subscribe(mqwrapper.ConsumerOptions{
c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -301,13 +301,13 @@ func TestNmqConsumer_ConsumeFromLatest(t *testing.T) {
defer client.Close()
topic := t.Name()
p, err := client.CreateProducer(common.ProducerOptions{Topic: topic})
p, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
assert.NoError(t, err)
msgs := []string{"111", "222", "333"}
process(t, msgs, p)
c, err := client.Subscribe(mqwrapper.ConsumerOptions{
c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionLatest,
@ -331,13 +331,13 @@ func TestNmqConsumer_ConsumeFromEarliest(t *testing.T) {
defer client.Close()
topic := t.Name()
p, err := client.CreateProducer(common.ProducerOptions{Topic: topic})
p, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
assert.NoError(t, err)
msgs := []string{"111", "222"}
process(t, msgs, p)
c, err := client.Subscribe(mqwrapper.ConsumerOptions{
c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -354,7 +354,7 @@ func TestNmqConsumer_ConsumeFromEarliest(t *testing.T) {
msg = <-c.Chan()
assert.Equal(t, "222", string(msg.Payload()))
c2, err := client.Subscribe(mqwrapper.ConsumerOptions{
c2, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,

View File

@ -3,7 +3,7 @@
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// "License"); you may not use this file exceapt in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0

View File

@ -33,7 +33,7 @@ func TestNatsMQProducer(t *testing.T) {
pOpts := common.ProducerOptions{Topic: topic}
// Check Topic()
p, err := c.CreateProducer(pOpts)
p, err := c.CreateProducer(context.TODO(), pOpts)
assert.NoError(t, err)
assert.Equal(t, p.(*nmqProducer).Topic(), topic)

View File

@ -17,6 +17,7 @@
package pulsar
import (
"context"
"fmt"
"sync"
"time"
@ -66,7 +67,7 @@ func NewClient(tenant string, namespace string, opts pulsar.ClientOptions) (*pul
}
// CreateProducer create a pulsar producer from options
func (pc *pulsarClient) CreateProducer(options mqcommon.ProducerOptions) (mqwrapper.Producer, error) {
func (pc *pulsarClient) CreateProducer(ctx context.Context, options mqcommon.ProducerOptions) (mqwrapper.Producer, error) {
start := timerecord.NewTimeRecorder("create producer")
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc()
@ -102,7 +103,7 @@ func (pc *pulsarClient) CreateProducer(options mqcommon.ProducerOptions) (mqwrap
}
// Subscribe creates a pulsar consumer instance and subscribe a topic
func (pc *pulsarClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
func (pc *pulsarClient) Subscribe(ctx context.Context, options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
start := timerecord.NewTimeRecorder("create consumer")
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc()

View File

@ -78,7 +78,7 @@ func BytesToInt(b []byte) int {
}
func Produce(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, arr []int) {
producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: topic})
producer, err := pc.CreateProducer(ctx, mqcommon.ProducerOptions{Topic: topic})
assert.NoError(t, err)
assert.NotNil(t, producer)
@ -110,7 +110,7 @@ func VerifyMessage(t *testing.T, msg mqcommon.Message) {
// Consume1 will consume random messages and record the last MessageID it received
func Consume1(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, subName string, c chan mqcommon.MessageID, total *int) {
consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := pc.Subscribe(ctx, mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: subName,
BufSize: 1024,
@ -147,7 +147,7 @@ func Consume1(ctx context.Context, t *testing.T, pc *pulsarClient, topic string,
// Consume2 will consume messages from specified MessageID
func Consume2(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, subName string, msgID mqcommon.MessageID, total *int) {
consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := pc.Subscribe(ctx, mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: subName,
BufSize: 1024,
@ -181,7 +181,7 @@ func Consume2(ctx context.Context, t *testing.T, pc *pulsarClient, topic string,
}
func Consume3(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, subName string, total *int) {
consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := pc.Subscribe(ctx, mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: subName,
BufSize: 1024,
@ -420,7 +420,7 @@ func TestPulsarClient_SeekPosition(t *testing.T) {
topic := fmt.Sprintf("test-topic-%d", rand.Int())
subName := fmt.Sprintf("test-subname-%d", rand.Int())
producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: topic})
producer, err := pc.CreateProducer(ctx, mqcommon.ProducerOptions{Topic: topic})
assert.NoError(t, err)
assert.NotNil(t, producer)
@ -498,7 +498,7 @@ func TestPulsarClient_SeekLatest(t *testing.T) {
topic := fmt.Sprintf("test-topic-%d", rand.Int())
subName := fmt.Sprintf("test-subname-%d", rand.Int())
producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: topic})
producer, err := pc.CreateProducer(ctx, mqcommon.ProducerOptions{Topic: topic})
assert.NoError(t, err)
assert.NotNil(t, producer)
@ -671,7 +671,7 @@ func TestPulsarClient_SubscribeExclusiveFail(t *testing.T) {
client: &mockPulsarClient{},
}
_, err := pc.Subscribe(mqwrapper.ConsumerOptions{Topic: "test_topic_name"})
_, err := pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{Topic: "test_topic_name"})
assert.Error(t, err)
assert.True(t, retry.IsRecoverable(err))
})
@ -686,7 +686,7 @@ func TestPulsarClient_WithTenantAndNamespace(t *testing.T) {
pulsarAddress := getPulsarAddress()
pc, err := NewClient(tenant, namespace, pulsar.ClientOptions{URL: pulsarAddress})
assert.NoError(t, err)
producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: topic})
producer, err := pc.CreateProducer(context.TODO(), mqcommon.ProducerOptions{Topic: topic})
defer producer.Close()
assert.NoError(t, err)
assert.NotNil(t, producer)
@ -695,7 +695,7 @@ func TestPulsarClient_WithTenantAndNamespace(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, fullTopicName, producer.(*pulsarProducer).Topic())
consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: subName,
BufSize: 1024,
@ -713,7 +713,7 @@ func TestPulsarCtl(t *testing.T) {
pulsarAddress := getPulsarAddress()
pc, err := NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
assert.NoError(t, err)
consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: subName,
BufSize: 1024,
@ -723,7 +723,7 @@ func TestPulsarCtl(t *testing.T) {
assert.NotNil(t, consumer)
defer consumer.Close()
_, err = pc.Subscribe(mqwrapper.ConsumerOptions{
_, err = pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: subName,
BufSize: 1024,
@ -732,7 +732,7 @@ func TestPulsarCtl(t *testing.T) {
assert.Error(t, err)
_, err = pc.Subscribe(mqwrapper.ConsumerOptions{
_, err = pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: subName,
BufSize: 1024,
@ -762,7 +762,7 @@ func TestPulsarCtl(t *testing.T) {
assert.NoError(t, err)
}
consumer2, err := pc.Subscribe(mqwrapper.ConsumerOptions{
consumer2, err := pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: subName,
BufSize: 1024,

View File

@ -80,9 +80,9 @@ func TestComsumeCompressedMessage(t *testing.T) {
assert.NoError(t, err)
defer consumer.Close()
producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: "TestTopics"})
producer, err := pc.CreateProducer(context.TODO(), mqcommon.ProducerOptions{Topic: "TestTopics"})
assert.NoError(t, err)
compressProducer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: "TestTopics", EnableCompression: true})
compressProducer, err := pc.CreateProducer(context.TODO(), mqcommon.ProducerOptions{Topic: "TestTopics", EnableCompression: true})
assert.NoError(t, err)
msg := []byte("test message")

View File

@ -34,7 +34,7 @@ func TestPulsarProducer(t *testing.T) {
assert.NotNil(t, pc)
topic := "TEST"
producer, err := pc.CreateProducer(common.ProducerOptions{Topic: topic})
producer, err := pc.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
assert.NoError(t, err)
assert.NotNil(t, producer)

View File

@ -58,7 +58,7 @@ func NewClient(opts client.Options) (*rmqClient, error) {
}
// CreateProducer creates a producer for rocksmq client
func (rc *rmqClient) CreateProducer(options common.ProducerOptions) (mqwrapper.Producer, error) {
func (rc *rmqClient) CreateProducer(ctx context.Context, options common.ProducerOptions) (mqwrapper.Producer, error) {
start := timerecord.NewTimeRecorder("create producer")
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc()
@ -77,7 +77,7 @@ func (rc *rmqClient) CreateProducer(options common.ProducerOptions) (mqwrapper.P
}
// Subscribe subscribes a consumer in rmq client
func (rc *rmqClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
func (rc *rmqClient) Subscribe(ctx context.Context, options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
start := timerecord.NewTimeRecorder("create consumer")
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc()

View File

@ -65,7 +65,7 @@ func TestRmqClient_CreateProducer(t *testing.T) {
topic := "TestRmqClient_CreateProducer"
proOpts := common.ProducerOptions{Topic: topic}
producer, err := client.CreateProducer(proOpts)
producer, err := client.CreateProducer(context.TODO(), proOpts)
assert.NoError(t, err)
assert.NotNil(t, producer)
@ -83,7 +83,7 @@ func TestRmqClient_CreateProducer(t *testing.T) {
assert.NoError(t, err)
invalidOpts := common.ProducerOptions{Topic: ""}
producer, e := client.CreateProducer(invalidOpts)
producer, e := client.CreateProducer(context.TODO(), invalidOpts)
assert.Nil(t, producer)
assert.Error(t, e)
}
@ -95,7 +95,7 @@ func TestRmqClient_GetLatestMsg(t *testing.T) {
topic := fmt.Sprintf("t2GetLatestMsg-%d", rand.Int())
proOpts := common.ProducerOptions{Topic: topic}
producer, err := client.CreateProducer(proOpts)
producer, err := client.CreateProducer(context.TODO(), proOpts)
assert.NoError(t, err)
defer producer.Close()
@ -116,7 +116,7 @@ func TestRmqClient_GetLatestMsg(t *testing.T) {
BufSize: 1024,
}
consumer, err := client.Subscribe(consumerOpts)
consumer, err := client.Subscribe(context.TODO(), consumerOpts)
assert.NoError(t, err)
expectLastMsg, err := consumer.GetLatestMsgID()
@ -149,7 +149,7 @@ func TestRmqClient_Subscribe(t *testing.T) {
topic := "TestRmqClient_Subscribe"
proOpts := common.ProducerOptions{Topic: topic}
producer, err := client.CreateProducer(proOpts)
producer, err := client.CreateProducer(context.TODO(), proOpts)
assert.NoError(t, err)
assert.NotNil(t, producer)
defer producer.Close()
@ -161,7 +161,7 @@ func TestRmqClient_Subscribe(t *testing.T) {
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
BufSize: 0,
}
consumer, err := client.Subscribe(consumerOpts)
consumer, err := client.Subscribe(context.TODO(), consumerOpts)
assert.Error(t, err)
assert.Nil(t, consumer)
@ -172,12 +172,12 @@ func TestRmqClient_Subscribe(t *testing.T) {
BufSize: 1024,
}
consumer, err = client.Subscribe(consumerOpts)
consumer, err = client.Subscribe(context.TODO(), consumerOpts)
assert.Error(t, err)
assert.Nil(t, consumer)
consumerOpts.Topic = topic
consumer, err = client.Subscribe(consumerOpts)
consumer, err = client.Subscribe(context.TODO(), consumerOpts)
defer consumer.Close()
assert.NoError(t, err)
assert.NotNil(t, consumer)

View File

@ -55,11 +55,11 @@ type RepackFunc func(msgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, erro
type MsgStream interface {
Close()
AsProducer(channels []string)
Produce(*MsgPack) error
AsProducer(ctx context.Context, channels []string)
Produce(context.Context, *MsgPack) error
SetRepackFunc(repackFunc RepackFunc)
GetProduceChannels() []string
Broadcast(*MsgPack) (map[string][]MessageID, error)
Broadcast(context.Context, *MsgPack) (map[string][]MessageID, error)
AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error
Chan() <-chan *MsgPack

View File

@ -36,7 +36,7 @@ func TestPulsarMsgUtil(t *testing.T) {
defer msgStream.Close()
// create a topic
msgStream.AsProducer([]string{"test"})
msgStream.AsProducer(ctx, []string{"test"})
UnsubscribeChannels(ctx, pmsFactory, "sub", []string{"test"})
}

View File

@ -46,7 +46,7 @@ func benchmarkProduceAndConsume(b *testing.B, mqClient mqwrapper.Client, cases [
go func() {
defer wg.Done()
p, err := mqClient.CreateProducer(common.ProducerOptions{
p, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
Topic: topic,
})
assert.NoError(b, err)
@ -55,7 +55,7 @@ func benchmarkProduceAndConsume(b *testing.B, mqClient mqwrapper.Client, cases [
}()
go func() {
defer wg.Done()
c, _ := mqClient.Subscribe(mqwrapper.ConsumerOptions{
c, _ := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic,
SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,

View File

@ -40,13 +40,13 @@ func testStreamOperation(t *testing.T, mqClient mqwrapper.Client) {
func testConcurrentStream(t *testing.T, mqClient mqwrapper.Client) {
topics := getChannel(2)
producer, err := mqClient.CreateProducer(common.ProducerOptions{
producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
Topic: topics[0],
})
defer producer.Close()
assert.NoError(t, err)
consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topics[0],
SubscriptionName: funcutil.RandomString(8),
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -61,7 +61,7 @@ func testConcurrentStream(t *testing.T, mqClient mqwrapper.Client) {
func testConcurrentStreamAndSubscribeLast(t *testing.T, mqClient mqwrapper.Client) {
topics := getChannel(2)
producer, err := mqClient.CreateProducer(common.ProducerOptions{
producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
Topic: topics[0],
})
defer producer.Close()
@ -69,7 +69,7 @@ func testConcurrentStreamAndSubscribeLast(t *testing.T, mqClient mqwrapper.Clien
ids := sendMessages(context.Background(), t, producer, generateRandMessage(1024, 1000))
consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topics[0],
SubscriptionName: funcutil.RandomString(8),
SubscriptionInitialPosition: common.SubscriptionPositionLatest,
@ -90,7 +90,7 @@ func testConcurrentStreamAndSubscribeLast(t *testing.T, mqClient mqwrapper.Clien
func testConcurrentStreamAndSeekInclusive(t *testing.T, mqClient mqwrapper.Client) {
topics := getChannel(2)
producer, err := mqClient.CreateProducer(common.ProducerOptions{
producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
Topic: topics[0],
})
defer producer.Close()
@ -99,7 +99,7 @@ func testConcurrentStreamAndSeekInclusive(t *testing.T, mqClient mqwrapper.Clien
cases := generateRandMessage(1024, 1000)
ids := sendMessages(context.Background(), t, producer, cases)
consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topics[0],
SubscriptionName: funcutil.RandomString(8),
SubscriptionInitialPosition: common.SubscriptionPositionUnknown,
@ -124,7 +124,7 @@ func testConcurrentStreamAndSeekInclusive(t *testing.T, mqClient mqwrapper.Clien
func testConcurrentStreamAndSeekNoInclusive(t *testing.T, mqClient mqwrapper.Client) {
topics := getChannel(2)
producer, err := mqClient.CreateProducer(common.ProducerOptions{
producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
Topic: topics[0],
})
defer producer.Close()
@ -133,7 +133,7 @@ func testConcurrentStreamAndSeekNoInclusive(t *testing.T, mqClient mqwrapper.Cli
cases := generateRandMessage(1024, 1000)
ids := sendMessages(context.Background(), t, producer, cases)
consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topics[0],
SubscriptionName: funcutil.RandomString(8),
SubscriptionInitialPosition: common.SubscriptionPositionUnknown,
@ -158,7 +158,7 @@ func testConcurrentStreamAndSeekNoInclusive(t *testing.T, mqClient mqwrapper.Cli
func testConcurrentStreamAndSeekToLast(t *testing.T, mqClient mqwrapper.Client) {
topics := getChannel(2)
producer, err := mqClient.CreateProducer(common.ProducerOptions{
producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
Topic: topics[0],
})
defer producer.Close()
@ -167,7 +167,7 @@ func testConcurrentStreamAndSeekToLast(t *testing.T, mqClient mqwrapper.Client)
cases := generateRandMessage(1024, 1000)
sendMessages(context.Background(), t, producer, cases)
consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{
consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topics[0],
SubscriptionName: funcutil.RandomString(8),
SubscriptionInitialPosition: common.SubscriptionPositionUnknown,

View File

@ -1,5 +1,7 @@
package msgstream
import "context"
type WastedMockMsgStream struct {
MsgStream
AsProducerFunc func(channels []string)
@ -12,11 +14,11 @@ func NewWastedMockMsgStream() *WastedMockMsgStream {
return &WastedMockMsgStream{}
}
func (m WastedMockMsgStream) AsProducer(channels []string) {
func (m WastedMockMsgStream) AsProducer(ctx context.Context, channels []string) {
m.AsProducerFunc(channels)
}
func (m WastedMockMsgStream) Broadcast(pack *MsgPack) (map[string][]MessageID, error) {
func (m WastedMockMsgStream) Broadcast(ctx context.Context, pack *MsgPack) (map[string][]MessageID, error) {
return m.BroadcastMarkFunc(pack)
}