mirror of https://github.com/milvus-io/milvus.git
enhance: refine pular related mq interfaces (#38007)
issue: #35917 Refines the pulsar-related mq APIs to allow the ctx to be passed down Signed-off-by: tinswzy <zhenyuan.wei@zilliz.com>pull/37378/merge
parent
73aa95f596
commit
5768dbbb5d
|
@ -406,7 +406,7 @@ func (t *compactionTrigger) handleSignal(signal *compactionSignal) {
|
|||
return
|
||||
}
|
||||
|
||||
segment := t.meta.GetHealthySegment(t.meta.ctx, signal.segmentID)
|
||||
segment := t.meta.GetHealthySegment(context.TODO(), signal.segmentID)
|
||||
if segment == nil {
|
||||
log.Warn("segment in compaction signal not found in meta", zap.Int64("segmentID", signal.segmentID))
|
||||
return
|
||||
|
|
|
@ -68,7 +68,7 @@ func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.MsgPack {
|
|||
return make(chan *msgstream.MsgPack, 100)
|
||||
}
|
||||
|
||||
func (mtm *mockTtMsgStream) AsProducer(channels []string) {}
|
||||
func (mtm *mockTtMsgStream) AsProducer(ctx context.Context, channels []string) {}
|
||||
|
||||
func (mtm *mockTtMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error {
|
||||
return nil
|
||||
|
@ -80,11 +80,11 @@ func (mtm *mockTtMsgStream) GetProduceChannels() []string {
|
|||
return make([]string, 0)
|
||||
}
|
||||
|
||||
func (mtm *mockTtMsgStream) Produce(*msgstream.MsgPack) error {
|
||||
func (mtm *mockTtMsgStream) Produce(context.Context, *msgstream.MsgPack) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mtm *mockTtMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
|
||||
func (mtm *mockTtMsgStream) Broadcast(context.Context, *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ import (
|
|||
type channelsMgr interface {
|
||||
getChannels(collectionID UniqueID) ([]pChan, error)
|
||||
getVChannels(collectionID UniqueID) ([]vChan, error)
|
||||
getOrCreateDmlStream(collectionID UniqueID) (msgstream.MsgStream, error)
|
||||
getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error)
|
||||
removeDMLStream(collectionID UniqueID)
|
||||
removeAllDMLStream()
|
||||
}
|
||||
|
@ -172,7 +172,7 @@ func (mgr *singleTypeChannelsMgr) streamExistPrivate(collectionID UniqueID) bool
|
|||
return ok && streamInfos.stream != nil
|
||||
}
|
||||
|
||||
func createStream(factory msgstream.Factory, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) {
|
||||
func createStream(ctx context.Context, factory msgstream.Factory, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) {
|
||||
var stream msgstream.MsgStream
|
||||
var err error
|
||||
|
||||
|
@ -181,7 +181,7 @@ func createStream(factory msgstream.Factory, pchans []pChan, repack repackFuncTy
|
|||
return nil, err
|
||||
}
|
||||
|
||||
stream.AsProducer(pchans)
|
||||
stream.AsProducer(ctx, pchans)
|
||||
if repack != nil {
|
||||
stream.SetRepackFunc(repack)
|
||||
}
|
||||
|
@ -202,7 +202,7 @@ func decPChanMetrics(pchans []pChan) {
|
|||
|
||||
// createMsgStream create message stream for specified collection. Idempotent.
|
||||
// If stream already exists, directly return it and no error will be returned.
|
||||
func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) (msgstream.MsgStream, error) {
|
||||
func (mgr *singleTypeChannelsMgr) createMsgStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
|
||||
mgr.mu.RLock()
|
||||
infos, ok := mgr.infos[collectionID]
|
||||
if ok && infos.stream != nil {
|
||||
|
@ -219,7 +219,7 @@ func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) (msgstr
|
|||
return nil, err
|
||||
}
|
||||
|
||||
stream, err := createStream(mgr.msgStreamFactory, channelInfos.pchans, mgr.repackFunc)
|
||||
stream, err := createStream(ctx, mgr.msgStreamFactory, channelInfos.pchans, mgr.repackFunc)
|
||||
if err != nil {
|
||||
// What if stream created by other goroutines?
|
||||
log.Error("failed to create message stream", zap.Error(err), zap.Int64("collection", collectionID))
|
||||
|
@ -253,12 +253,12 @@ func (mgr *singleTypeChannelsMgr) lockGetStream(collectionID UniqueID) (msgstrea
|
|||
|
||||
// getOrCreateStream get message stream of specified collection.
|
||||
// If stream doesn't exist, call createMsgStream to create for it.
|
||||
func (mgr *singleTypeChannelsMgr) getOrCreateStream(collectionID UniqueID) (msgstream.MsgStream, error) {
|
||||
func (mgr *singleTypeChannelsMgr) getOrCreateStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
|
||||
if stream, err := mgr.lockGetStream(collectionID); err == nil {
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
return mgr.createMsgStream(collectionID)
|
||||
return mgr.createMsgStream(ctx, collectionID)
|
||||
}
|
||||
|
||||
// removeStream remove the corresponding stream of the specified collection. Idempotent.
|
||||
|
@ -315,8 +315,8 @@ func (mgr *channelsMgrImpl) getVChannels(collectionID UniqueID) ([]vChan, error)
|
|||
return mgr.dmlChannelsMgr.getVChannels(collectionID)
|
||||
}
|
||||
|
||||
func (mgr *channelsMgrImpl) getOrCreateDmlStream(collectionID UniqueID) (msgstream.MsgStream, error) {
|
||||
return mgr.dmlChannelsMgr.getOrCreateStream(collectionID)
|
||||
func (mgr *channelsMgrImpl) getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
|
||||
return mgr.dmlChannelsMgr.getOrCreateStream(ctx, collectionID)
|
||||
}
|
||||
|
||||
func (mgr *channelsMgrImpl) removeDMLStream(collectionID UniqueID) {
|
||||
|
|
|
@ -214,7 +214,7 @@ func Test_createStream(t *testing.T) {
|
|||
factory.fQStream = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
_, err := createStream(factory, nil, nil)
|
||||
_, err := createStream(context.TODO(), factory, nil, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
|
@ -223,7 +223,7 @@ func Test_createStream(t *testing.T) {
|
|||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
_, err := createStream(factory, nil, nil)
|
||||
_, err := createStream(context.TODO(), factory, nil, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
|
@ -232,7 +232,7 @@ func Test_createStream(t *testing.T) {
|
|||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return newMockMsgStream(), nil
|
||||
}
|
||||
_, err := createStream(factory, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
|
||||
_, err := createStream(context.TODO(), factory, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
@ -247,7 +247,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
|
|||
100: {stream: newMockMsgStream()},
|
||||
},
|
||||
}
|
||||
stream, err := m.createMsgStream(100)
|
||||
stream, err := m.createMsgStream(context.TODO(), 100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
})
|
||||
|
@ -275,7 +275,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
|
|||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
stream, err := m.createMsgStream(100)
|
||||
stream, err := m.createMsgStream(context.TODO(), 100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
}()
|
||||
|
@ -295,7 +295,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
|
|||
return channelInfos{}, errors.New("mock")
|
||||
},
|
||||
}
|
||||
_, err := m.createMsgStream(100)
|
||||
_, err := m.createMsgStream(context.TODO(), 100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
|
@ -311,7 +311,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
|
|||
msgStreamFactory: factory,
|
||||
repackFunc: nil,
|
||||
}
|
||||
_, err := m.createMsgStream(100)
|
||||
_, err := m.createMsgStream(context.TODO(), 100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
|
@ -328,10 +328,10 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
|
|||
msgStreamFactory: factory,
|
||||
repackFunc: nil,
|
||||
}
|
||||
stream, err := m.createMsgStream(100)
|
||||
stream, err := m.createMsgStream(context.TODO(), 100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
stream, err = m.getOrCreateStream(100)
|
||||
stream, err = m.getOrCreateStream(context.TODO(), 100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
})
|
||||
|
@ -365,7 +365,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
|
|||
100: {stream: newMockMsgStream()},
|
||||
},
|
||||
}
|
||||
stream, err := m.getOrCreateStream(100)
|
||||
stream, err := m.getOrCreateStream(context.TODO(), 100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
})
|
||||
|
@ -377,7 +377,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
|
|||
return channelInfos{}, errors.New("mock")
|
||||
},
|
||||
}
|
||||
_, err := m.getOrCreateStream(100)
|
||||
_, err := m.getOrCreateStream(context.TODO(), 100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
|
@ -394,7 +394,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
|
|||
msgStreamFactory: factory,
|
||||
repackFunc: nil,
|
||||
}
|
||||
stream, err := m.getOrCreateStream(100)
|
||||
stream, err := m.getOrCreateStream(context.TODO(), 100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
})
|
||||
|
|
|
@ -6323,7 +6323,7 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
|
|||
Status: merr.Status(err),
|
||||
}, nil
|
||||
}
|
||||
messageIDsMap, err := msgStream.Broadcast(msgPack)
|
||||
messageIDsMap, err := msgStream.Broadcast(ctx, msgPack)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("failed to produce msg", zap.Error(err))
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
|
||||
|
|
|
@ -440,7 +440,7 @@ func TestProxy_FlushAll_DbCollection(t *testing.T) {
|
|||
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
|
||||
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
|
||||
assert.NoError(t, err)
|
||||
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel})
|
||||
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
|
||||
|
||||
Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000")
|
||||
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
|
||||
|
@ -483,7 +483,7 @@ func TestProxy_FlushAll(t *testing.T) {
|
|||
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
|
||||
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
|
||||
assert.NoError(t, err)
|
||||
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel})
|
||||
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
|
||||
|
||||
Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000")
|
||||
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
|
||||
|
@ -955,7 +955,7 @@ func TestProxyCreateDatabase(t *testing.T) {
|
|||
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
|
||||
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
|
||||
assert.NoError(t, err)
|
||||
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel})
|
||||
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
|
||||
|
||||
t.Run("create database fail", func(t *testing.T) {
|
||||
rc := mocks.NewMockRootCoordClient(t)
|
||||
|
@ -1015,7 +1015,7 @@ func TestProxyDropDatabase(t *testing.T) {
|
|||
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
|
||||
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
|
||||
assert.NoError(t, err)
|
||||
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel})
|
||||
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
|
||||
|
||||
t.Run("drop database fail", func(t *testing.T) {
|
||||
rc := mocks.NewMockRootCoordClient(t)
|
||||
|
@ -1496,13 +1496,13 @@ func TestProxy_ReplicateMessage(t *testing.T) {
|
|||
factory := newMockMsgStreamFactory()
|
||||
msgStreamObj := msgstream.NewMockMsgStream(t)
|
||||
msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return()
|
||||
msgStreamObj.EXPECT().AsProducer(mock.Anything).Return()
|
||||
msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return()
|
||||
msgStreamObj.EXPECT().EnableProduce(mock.Anything).Return()
|
||||
msgStreamObj.EXPECT().Close().Return()
|
||||
mockMsgID1 := mqcommon.NewMockMessageID(t)
|
||||
mockMsgID2 := mqcommon.NewMockMessageID(t)
|
||||
mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2"))
|
||||
broadcastMock := msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(map[string][]mqcommon.MessageID{
|
||||
broadcastMock := msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
|
||||
"unit_test_replicate_message": {mockMsgID1, mockMsgID2},
|
||||
}, nil)
|
||||
|
||||
|
@ -1581,7 +1581,7 @@ func TestProxy_ReplicateMessage(t *testing.T) {
|
|||
|
||||
{
|
||||
broadcastMock.Unset()
|
||||
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(nil, errors.New("mock error: broadcast"))
|
||||
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: broadcast"))
|
||||
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqualValues(t, 0, resp.GetStatus().GetCode())
|
||||
|
@ -1590,7 +1590,7 @@ func TestProxy_ReplicateMessage(t *testing.T) {
|
|||
}
|
||||
{
|
||||
broadcastMock.Unset()
|
||||
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(map[string][]mqcommon.MessageID{
|
||||
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
|
||||
"unit_test_replicate_message": {},
|
||||
}, nil)
|
||||
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)
|
||||
|
|
|
@ -3,6 +3,8 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
msgstream "github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
@ -78,9 +80,9 @@ func (_c *MockChannelsMgr_getChannels_Call) RunAndReturn(run func(int64) ([]stri
|
|||
return _c
|
||||
}
|
||||
|
||||
// getOrCreateDmlStream provides a mock function with given fields: collectionID
|
||||
func (_m *MockChannelsMgr) getOrCreateDmlStream(collectionID int64) (msgstream.MsgStream, error) {
|
||||
ret := _m.Called(collectionID)
|
||||
// getOrCreateDmlStream provides a mock function with given fields: ctx, collectionID
|
||||
func (_m *MockChannelsMgr) getOrCreateDmlStream(ctx context.Context, collectionID int64) (msgstream.MsgStream, error) {
|
||||
ret := _m.Called(ctx, collectionID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for getOrCreateDmlStream")
|
||||
|
@ -88,19 +90,19 @@ func (_m *MockChannelsMgr) getOrCreateDmlStream(collectionID int64) (msgstream.M
|
|||
|
||||
var r0 msgstream.MsgStream
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(int64) (msgstream.MsgStream, error)); ok {
|
||||
return rf(collectionID)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int64) (msgstream.MsgStream, error)); ok {
|
||||
return rf(ctx, collectionID)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(int64) msgstream.MsgStream); ok {
|
||||
r0 = rf(collectionID)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int64) msgstream.MsgStream); ok {
|
||||
r0 = rf(ctx, collectionID)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(msgstream.MsgStream)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(int64) error); ok {
|
||||
r1 = rf(collectionID)
|
||||
if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok {
|
||||
r1 = rf(ctx, collectionID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
@ -114,14 +116,15 @@ type MockChannelsMgr_getOrCreateDmlStream_Call struct {
|
|||
}
|
||||
|
||||
// getOrCreateDmlStream is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - collectionID int64
|
||||
func (_e *MockChannelsMgr_Expecter) getOrCreateDmlStream(collectionID interface{}) *MockChannelsMgr_getOrCreateDmlStream_Call {
|
||||
return &MockChannelsMgr_getOrCreateDmlStream_Call{Call: _e.mock.On("getOrCreateDmlStream", collectionID)}
|
||||
func (_e *MockChannelsMgr_Expecter) getOrCreateDmlStream(ctx interface{}, collectionID interface{}) *MockChannelsMgr_getOrCreateDmlStream_Call {
|
||||
return &MockChannelsMgr_getOrCreateDmlStream_Call{Call: _e.mock.On("getOrCreateDmlStream", ctx, collectionID)}
|
||||
}
|
||||
|
||||
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Run(run func(collectionID int64)) *MockChannelsMgr_getOrCreateDmlStream_Call {
|
||||
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Run(run func(ctx context.Context, collectionID int64)) *MockChannelsMgr_getOrCreateDmlStream_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(int64))
|
||||
run(args[0].(context.Context), args[1].(int64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
@ -131,7 +134,7 @@ func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Return(_a0 msgstream.MsgStr
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) RunAndReturn(run func(int64) (msgstream.MsgStream, error)) *MockChannelsMgr_getOrCreateDmlStream_Call {
|
||||
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) RunAndReturn(run func(context.Context, int64) (msgstream.MsgStream, error)) *MockChannelsMgr_getOrCreateDmlStream_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ type mockMsgStream struct {
|
|||
enableProduce func(bool)
|
||||
}
|
||||
|
||||
func (m *mockMsgStream) AsProducer(producers []string) {
|
||||
func (m *mockMsgStream) AsProducer(ctx context.Context, producers []string) {
|
||||
if m.asProducer != nil {
|
||||
m.asProducer(producers)
|
||||
}
|
||||
|
|
|
@ -255,7 +255,7 @@ func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.MsgPack {
|
|||
return ms.msgChan
|
||||
}
|
||||
|
||||
func (ms *simpleMockMsgStream) AsProducer(channels []string) {
|
||||
func (ms *simpleMockMsgStream) AsProducer(ctx context.Context, channels []string) {
|
||||
}
|
||||
|
||||
func (ms *simpleMockMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error {
|
||||
|
@ -283,7 +283,7 @@ func (ms *simpleMockMsgStream) decreaseMsgCount(delta int) {
|
|||
ms.increaseMsgCount(-delta)
|
||||
}
|
||||
|
||||
func (ms *simpleMockMsgStream) Produce(pack *msgstream.MsgPack) error {
|
||||
func (ms *simpleMockMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) error {
|
||||
defer ms.increaseMsgCount(1)
|
||||
|
||||
ms.msgChan <- pack
|
||||
|
@ -291,7 +291,7 @@ func (ms *simpleMockMsgStream) Produce(pack *msgstream.MsgPack) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ms *simpleMockMsgStream) Broadcast(pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
|
||||
func (ms *simpleMockMsgStream) Broadcast(ctx context.Context, pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
|
||||
return map[string][]msgstream.MessageID{}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -278,7 +278,7 @@ func (node *Proxy) Init() error {
|
|||
return err
|
||||
}
|
||||
node.replicateMsgStream.EnableProduce(true)
|
||||
node.replicateMsgStream.AsProducer([]string{replicateMsgChannel})
|
||||
node.replicateMsgStream.AsProducer(node.ctx, []string{replicateMsgChannel})
|
||||
|
||||
node.sched, err = newTaskScheduler(node.ctx, node.tsoAllocator, node.factory)
|
||||
if err != nil {
|
||||
|
|
|
@ -34,15 +34,15 @@ func NewReplicateStreamManager(ctx context.Context, factory msgstream.Factory, r
|
|||
return manager
|
||||
}
|
||||
|
||||
func (m *ReplicateStreamManager) newMsgStreamResource(channel string) resource.NewResourceFunc {
|
||||
func (m *ReplicateStreamManager) newMsgStreamResource(ctx context.Context, channel string) resource.NewResourceFunc {
|
||||
return func() (resource.Resource, error) {
|
||||
msgStream, err := m.factory.NewMsgStream(m.ctx)
|
||||
msgStream, err := m.factory.NewMsgStream(ctx)
|
||||
if err != nil {
|
||||
log.Ctx(m.ctx).Warn("failed to create msg stream", zap.String("channel", channel), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
msgStream.SetRepackFunc(replicatePackFunc)
|
||||
msgStream.AsProducer([]string{channel})
|
||||
msgStream.AsProducer(ctx, []string{channel})
|
||||
msgStream.EnableProduce(true)
|
||||
|
||||
res := resource.NewSimpleResource(msgStream, ReplicateMsgStreamTyp, channel, ReplicateMsgStreamExpireTime, func() {
|
||||
|
@ -55,7 +55,7 @@ func (m *ReplicateStreamManager) newMsgStreamResource(channel string) resource.N
|
|||
|
||||
func (m *ReplicateStreamManager) GetReplicateMsgStream(ctx context.Context, channel string) (msgstream.MsgStream, error) {
|
||||
ctxLog := log.Ctx(ctx).With(zap.String("proxy_channel", channel))
|
||||
res, err := m.resourceManager.Get(ReplicateMsgStreamTyp, channel, m.newMsgStreamResource(channel))
|
||||
res, err := m.resourceManager.Get(ReplicateMsgStreamTyp, channel, m.newMsgStreamResource(ctx, channel))
|
||||
if err != nil {
|
||||
ctxLog.Warn("failed to get replicate msg stream", zap.String("channel", channel), zap.Error(err))
|
||||
return nil, err
|
||||
|
|
|
@ -142,7 +142,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
|
|||
}
|
||||
|
||||
dt.tr = timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID()))
|
||||
stream, err := dt.chMgr.getOrCreateDmlStream(dt.collectionID)
|
||||
stream, err := dt.chMgr.getOrCreateDmlStream(ctx, dt.collectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -178,7 +178,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
|
|||
zap.Int64("taskID", dt.ID()),
|
||||
zap.Duration("prepare duration", dt.tr.RecordSpan()))
|
||||
|
||||
err = stream.Produce(msgPack)
|
||||
err = stream.Produce(ctx, msgPack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -161,7 +161,7 @@ func TestDeleteTask_Execute(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(nil, errors.New("mock error"))
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
|
||||
assert.Error(t, dt.Execute(context.Background()))
|
||||
})
|
||||
|
||||
|
@ -190,7 +190,7 @@ func TestDeleteTask_Execute(t *testing.T) {
|
|||
primaryKeys: pk,
|
||||
}
|
||||
stream := msgstream.NewMockMsgStream(t)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
|
||||
|
||||
assert.Error(t, dt.Execute(context.Background()))
|
||||
})
|
||||
|
@ -226,8 +226,8 @@ func TestDeleteTask_Execute(t *testing.T) {
|
|||
primaryKeys: pk,
|
||||
}
|
||||
stream := msgstream.NewMockMsgStream(t)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil)
|
||||
stream.EXPECT().Produce(mock.Anything).Return(errors.New("mock error"))
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
|
||||
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error"))
|
||||
assert.Error(t, dt.Execute(context.Background()))
|
||||
})
|
||||
}
|
||||
|
@ -535,9 +535,9 @@ func TestDeleteRunner_Run(t *testing.T) {
|
|||
},
|
||||
}
|
||||
stream := msgstream.NewMockMsgStream(t)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
|
||||
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
|
||||
stream.EXPECT().Produce(mock.Anything).Return(fmt.Errorf("mock error"))
|
||||
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error"))
|
||||
|
||||
assert.Error(t, dr.Run(context.Background()))
|
||||
assert.Equal(t, int64(0), dr.result.DeleteCnt)
|
||||
|
@ -644,9 +644,9 @@ func TestDeleteRunner_Run(t *testing.T) {
|
|||
},
|
||||
}
|
||||
stream := msgstream.NewMockMsgStream(t)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
|
||||
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
|
||||
stream.EXPECT().Produce(mock.Anything).Return(nil)
|
||||
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
|
||||
return workload.exec(ctx, 1, qn, "")
|
||||
|
@ -768,7 +768,7 @@ func TestDeleteRunner_Run(t *testing.T) {
|
|||
},
|
||||
}
|
||||
stream := msgstream.NewMockMsgStream(t)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
|
||||
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
|
||||
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
|
||||
return workload.exec(ctx, 1, qn, "")
|
||||
|
@ -792,7 +792,7 @@ func TestDeleteRunner_Run(t *testing.T) {
|
|||
server.FinishSend(nil)
|
||||
return client
|
||||
}, nil)
|
||||
stream.EXPECT().Produce(mock.Anything).Return(errors.New("mock error"))
|
||||
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error"))
|
||||
|
||||
assert.Error(t, dr.Run(ctx))
|
||||
assert.Equal(t, int64(0), dr.result.DeleteCnt)
|
||||
|
@ -830,7 +830,7 @@ func TestDeleteRunner_Run(t *testing.T) {
|
|||
},
|
||||
}
|
||||
stream := msgstream.NewMockMsgStream(t)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
|
||||
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
|
||||
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
|
||||
return workload.exec(ctx, 1, qn, "")
|
||||
|
@ -854,7 +854,7 @@ func TestDeleteRunner_Run(t *testing.T) {
|
|||
server.FinishSend(nil)
|
||||
return client
|
||||
}, nil)
|
||||
stream.EXPECT().Produce(mock.Anything).Return(nil)
|
||||
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
assert.NoError(t, dr.Run(ctx))
|
||||
assert.Equal(t, int64(3), dr.result.DeleteCnt)
|
||||
|
@ -911,7 +911,7 @@ func TestDeleteRunner_Run(t *testing.T) {
|
|||
},
|
||||
}
|
||||
stream := msgstream.NewMockMsgStream(t)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
|
||||
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
|
||||
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
|
||||
return workload.exec(ctx, 1, qn, "")
|
||||
|
@ -936,7 +936,7 @@ func TestDeleteRunner_Run(t *testing.T) {
|
|||
return client
|
||||
}, nil)
|
||||
|
||||
stream.EXPECT().Produce(mock.Anything).Return(nil)
|
||||
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
|
||||
assert.NoError(t, dr.Run(ctx))
|
||||
assert.Equal(t, int64(3), dr.result.DeleteCnt)
|
||||
})
|
||||
|
|
|
@ -243,7 +243,7 @@ func (it *insertTask) Execute(ctx context.Context) error {
|
|||
it.insertMsg.CollectionID = collID
|
||||
|
||||
getCacheDur := tr.RecordSpan()
|
||||
stream, err := it.chMgr.getOrCreateDmlStream(collID)
|
||||
stream, err := it.chMgr.getOrCreateDmlStream(ctx, collID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -280,7 +280,7 @@ func (it *insertTask) Execute(ctx context.Context) error {
|
|||
|
||||
log.Debug("assign segmentID for insert data success",
|
||||
zap.Duration("assign segmentID duration", assignSegmentIDDur))
|
||||
err = stream.Produce(msgPack)
|
||||
err = stream.Produce(ctx, msgPack)
|
||||
if err != nil {
|
||||
log.Warn("fail to produce insert msg", zap.Error(err))
|
||||
it.result.Status = merr.Status(err)
|
||||
|
|
|
@ -1755,7 +1755,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
|
|||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
|
||||
defer chMgr.removeAllDMLStream()
|
||||
|
||||
_, err = chMgr.getOrCreateDmlStream(collectionID)
|
||||
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
|
||||
assert.NoError(t, err)
|
||||
pchans, err := chMgr.getChannels(collectionID)
|
||||
assert.NoError(t, err)
|
||||
|
@ -2004,7 +2004,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
|
|||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
|
||||
defer chMgr.removeAllDMLStream()
|
||||
|
||||
_, err = chMgr.getOrCreateDmlStream(collectionID)
|
||||
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
|
||||
assert.NoError(t, err)
|
||||
pchans, err := chMgr.getChannels(collectionID)
|
||||
assert.NoError(t, err)
|
||||
|
@ -3460,7 +3460,7 @@ func TestPartitionKey(t *testing.T) {
|
|||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
|
||||
defer chMgr.removeAllDMLStream()
|
||||
|
||||
_, err = chMgr.getOrCreateDmlStream(collectionID)
|
||||
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
|
||||
assert.NoError(t, err)
|
||||
pchans, err := chMgr.getChannels(collectionID)
|
||||
assert.NoError(t, err)
|
||||
|
|
|
@ -393,7 +393,7 @@ func (it *upsertTask) insertExecute(ctx context.Context, msgPack *msgstream.MsgP
|
|||
zap.Int64("collectionID", collID))
|
||||
getCacheDur := tr.RecordSpan()
|
||||
|
||||
_, err = it.chMgr.getOrCreateDmlStream(collID)
|
||||
_, err = it.chMgr.getOrCreateDmlStream(ctx, collID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -526,7 +526,7 @@ func (it *upsertTask) Execute(ctx context.Context) (err error) {
|
|||
log := log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName))
|
||||
|
||||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute upsert %d", it.ID()))
|
||||
stream, err := it.chMgr.getOrCreateDmlStream(it.collectionID)
|
||||
stream, err := it.chMgr.getOrCreateDmlStream(ctx, it.collectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -547,7 +547,7 @@ func (it *upsertTask) Execute(ctx context.Context) (err error) {
|
|||
}
|
||||
|
||||
tr.RecordSpan()
|
||||
err = stream.Produce(msgPack)
|
||||
err = stream.Produce(ctx, msgPack)
|
||||
if err != nil {
|
||||
it.result.Status = merr.Status(err)
|
||||
return err
|
||||
|
|
|
@ -1985,7 +1985,7 @@ func SendReplicateMessagePack(ctx context.Context, replicateMsgStream msgstream.
|
|||
EndTs: ts,
|
||||
Msgs: []msgstream.TsMsg{tsMsg},
|
||||
}
|
||||
msgErr := replicateMsgStream.Produce(msgPack)
|
||||
msgErr := replicateMsgStream.Produce(ctx, msgPack)
|
||||
// ignore the error if the msg stream failed to produce the msg,
|
||||
// because it can be manually fixed in this error
|
||||
if msgErr != nil {
|
||||
|
|
|
@ -2430,7 +2430,7 @@ func TestSendReplicateMessagePack(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("produce fail", func(t *testing.T) {
|
||||
mockStream.EXPECT().Produce(mock.Anything).Return(errors.New("produce error")).Once()
|
||||
mockStream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("produce error")).Once()
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{
|
||||
Base: &commonpb.MsgBase{ReplicateInfo: &commonpb.ReplicateInfo{
|
||||
IsReplicate: true,
|
||||
|
@ -2444,7 +2444,7 @@ func TestSendReplicateMessagePack(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
mockStream.EXPECT().Produce(mock.Anything).Return(nil)
|
||||
mockStream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropDatabaseRequest{})
|
||||
|
|
|
@ -188,7 +188,7 @@ func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePref
|
|||
d.checkPreCreatedTopic(ctx, factory, name)
|
||||
}
|
||||
|
||||
ms.AsProducer([]string{name})
|
||||
ms.AsProducer(ctx, []string{name})
|
||||
dms := &dmlMsgStream{
|
||||
ms: ms,
|
||||
refcnt: 0,
|
||||
|
@ -291,7 +291,7 @@ func (d *dmlChannels) broadcast(chanNames []string, pack *msgstream.MsgPack) err
|
|||
|
||||
dms.mutex.RLock()
|
||||
if dms.refcnt > 0 {
|
||||
if _, err := dms.ms.Broadcast(pack); err != nil {
|
||||
if _, err := dms.ms.Broadcast(d.ctx, pack); err != nil {
|
||||
log.Error("Broadcast failed", zap.Error(err), zap.String("chanName", chanName))
|
||||
dms.mutex.RUnlock()
|
||||
return err
|
||||
|
@ -312,7 +312,7 @@ func (d *dmlChannels) broadcastMark(chanNames []string, pack *msgstream.MsgPack)
|
|||
|
||||
dms.mutex.RLock()
|
||||
if dms.refcnt > 0 {
|
||||
ids, err := dms.ms.Broadcast(pack)
|
||||
ids, err := dms.ms.Broadcast(d.ctx, pack)
|
||||
if err != nil {
|
||||
log.Error("BroadcastMark failed", zap.Error(err), zap.String("chanName", chanName))
|
||||
dms.mutex.RUnlock()
|
||||
|
|
|
@ -277,17 +277,17 @@ type FailMsgStream struct {
|
|||
errBroadcast bool
|
||||
}
|
||||
|
||||
func (ms *FailMsgStream) Close() {}
|
||||
func (ms *FailMsgStream) Chan() <-chan *msgstream.MsgPack { return nil }
|
||||
func (ms *FailMsgStream) AsProducer(channels []string) {}
|
||||
func (ms *FailMsgStream) AsReader(channels []string, subName string) {}
|
||||
func (ms *FailMsgStream) Close() {}
|
||||
func (ms *FailMsgStream) Chan() <-chan *msgstream.MsgPack { return nil }
|
||||
func (ms *FailMsgStream) AsProducer(ctx context.Context, channels []string) {}
|
||||
func (ms *FailMsgStream) AsReader(channels []string, subName string) {}
|
||||
func (ms *FailMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error {
|
||||
return nil
|
||||
}
|
||||
func (ms *FailMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {}
|
||||
func (ms *FailMsgStream) GetProduceChannels() []string { return nil }
|
||||
func (ms *FailMsgStream) Produce(*msgstream.MsgPack) error { return nil }
|
||||
func (ms *FailMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
|
||||
func (ms *FailMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {}
|
||||
func (ms *FailMsgStream) GetProduceChannels() []string { return nil }
|
||||
func (ms *FailMsgStream) Produce(context.Context, *msgstream.MsgPack) error { return nil }
|
||||
func (ms *FailMsgStream) Broadcast(context.Context, *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
|
||||
if ms.errBroadcast {
|
||||
return nil, errors.New("broadcast error")
|
||||
}
|
||||
|
|
|
@ -42,8 +42,8 @@ func TestInputNode(t *testing.T) {
|
|||
|
||||
msgPack := generateMsgPack()
|
||||
produceStream, _ := factory.NewMsgStream(context.TODO())
|
||||
produceStream.AsProducer(channels)
|
||||
produceStream.Produce(&msgPack)
|
||||
produceStream.AsProducer(context.TODO(), channels)
|
||||
produceStream.Produce(context.TODO(), &msgPack)
|
||||
|
||||
nodeName := "input_node"
|
||||
inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "")
|
||||
|
@ -84,7 +84,7 @@ func Test_InputNodeSkipMode(t *testing.T) {
|
|||
msgStream.AsConsumer(context.Background(), channels, "sub", common.SubscriptionPositionEarliest)
|
||||
|
||||
produceStream, _ := factory.NewMsgStream(context.TODO())
|
||||
produceStream.AsProducer(channels)
|
||||
produceStream.AsProducer(context.TODO(), channels)
|
||||
closeCh := make(chan struct{})
|
||||
outputCh := make(chan bool)
|
||||
|
||||
|
@ -110,7 +110,7 @@ func Test_InputNodeSkipMode(t *testing.T) {
|
|||
defer close(closeCh)
|
||||
|
||||
msgPack := generateMsgPack()
|
||||
produceStream.Produce(&msgPack)
|
||||
produceStream.Produce(context.TODO(), &msgPack)
|
||||
log.Info("produce empty ttmsg")
|
||||
<-outputCh
|
||||
assert.Equal(t, 1, outputCount)
|
||||
|
@ -118,7 +118,7 @@ func Test_InputNodeSkipMode(t *testing.T) {
|
|||
|
||||
time.Sleep(3 * time.Second)
|
||||
assert.Equal(t, false, inputNode.skipMode)
|
||||
produceStream.Produce(&msgPack)
|
||||
produceStream.Produce(context.TODO(), &msgPack)
|
||||
log.Info("after 3 seconds with no active msg receive, input node will turn on skip mode")
|
||||
<-outputCh
|
||||
assert.Equal(t, 2, outputCount)
|
||||
|
@ -126,13 +126,13 @@ func Test_InputNodeSkipMode(t *testing.T) {
|
|||
|
||||
log.Info("some ttmsg will be skipped in skip mode")
|
||||
// this msg will be skipped
|
||||
produceStream.Produce(&msgPack)
|
||||
produceStream.Produce(context.TODO(), &msgPack)
|
||||
<-outputCh
|
||||
assert.Equal(t, 2, outputCount)
|
||||
assert.Equal(t, true, inputNode.skipMode)
|
||||
|
||||
// this msg will be consumed
|
||||
produceStream.Produce(&msgPack)
|
||||
produceStream.Produce(context.TODO(), &msgPack)
|
||||
<-outputCh
|
||||
assert.Equal(t, 3, outputCount)
|
||||
assert.Equal(t, true, inputNode.skipMode)
|
||||
|
|
|
@ -80,13 +80,13 @@ func TestNodeManager_Start(t *testing.T) {
|
|||
msgStream.AsConsumer(context.TODO(), channels, "sub", common.SubscriptionPositionEarliest)
|
||||
|
||||
produceStream, _ := factory.NewMsgStream(context.TODO())
|
||||
produceStream.AsProducer(channels)
|
||||
produceStream.AsProducer(context.TODO(), channels)
|
||||
|
||||
msgPack := generateMsgPack()
|
||||
produceStream.Produce(&msgPack)
|
||||
produceStream.Produce(context.TODO(), &msgPack)
|
||||
time.Sleep(time.Millisecond * 2)
|
||||
msgPack = generateMsgPack()
|
||||
produceStream.Produce(&msgPack)
|
||||
produceStream.Produce(context.TODO(), &msgPack)
|
||||
|
||||
nodeName := "input_node"
|
||||
inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "")
|
||||
|
|
|
@ -226,7 +226,7 @@ func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64)
|
|||
insNum := rand.Intn(10)
|
||||
for j := 0; j < insNum; j++ {
|
||||
vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string)
|
||||
err := suite.producer.Produce(&msgstream.MsgPack{
|
||||
err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{
|
||||
Msgs: []msgstream.TsMsg{genInsertMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)},
|
||||
})
|
||||
assert.NoError(suite.T(), err)
|
||||
|
@ -237,7 +237,7 @@ func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64)
|
|||
delNum := rand.Intn(2)
|
||||
for j := 0; j < delNum; j++ {
|
||||
vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string)
|
||||
err := suite.producer.Produce(&msgstream.MsgPack{
|
||||
err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{
|
||||
Msgs: []msgstream.TsMsg{genDeleteMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)},
|
||||
})
|
||||
assert.NoError(suite.T(), err)
|
||||
|
@ -247,7 +247,7 @@ func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64)
|
|||
// produce random ddl
|
||||
ddlNum := rand.Intn(2)
|
||||
for j := 0; j < ddlNum; j++ {
|
||||
err := suite.producer.Produce(&msgstream.MsgPack{
|
||||
err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{
|
||||
Msgs: []msgstream.TsMsg{genDDLMsg(commonpb.MsgType_DropCollection, collectionID)},
|
||||
})
|
||||
assert.NoError(suite.T(), err)
|
||||
|
@ -257,7 +257,7 @@ func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64)
|
|||
}
|
||||
// produce time tick
|
||||
ts := uint64(i * 100)
|
||||
err := suite.producer.Produce(&msgstream.MsgPack{
|
||||
err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{
|
||||
Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
|
||||
})
|
||||
assert.NoError(suite.T(), err)
|
||||
|
@ -305,7 +305,7 @@ func (suite *SimulationSuite) produceTimeTickOnly(ctx context.Context) {
|
|||
return
|
||||
case <-ticker.C:
|
||||
ts := uint64(tt * 1000)
|
||||
err := suite.producer.Produce(&msgstream.MsgPack{
|
||||
err := suite.producer.Produce(ctx, &msgstream.MsgPack{
|
||||
Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
|
||||
})
|
||||
assert.NoError(suite.T(), err)
|
||||
|
|
|
@ -55,7 +55,7 @@ func newMockProducer(factory msgstream.Factory, pchannel string) (msgstream.MsgS
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream.AsProducer([]string{pchannel})
|
||||
stream.AsProducer(context.TODO(), []string{pchannel})
|
||||
stream.SetRepackFunc(defaultInsertRepackFunc)
|
||||
return stream, nil
|
||||
}
|
||||
|
|
|
@ -173,11 +173,11 @@ func testTimeTickerAndInsert(t *testing.T, f []Factory) {
|
|||
defer consumer.Close()
|
||||
|
||||
var err error
|
||||
_, err = producer.Broadcast(&msgPack0)
|
||||
_, err = producer.Broadcast(ctx, &msgPack0)
|
||||
assert.NoError(t, err)
|
||||
err = producer.Produce(&msgPack1)
|
||||
err = producer.Produce(ctx, &msgPack1)
|
||||
assert.NoError(t, err)
|
||||
_, err = producer.Broadcast(&msgPack2)
|
||||
_, err = producer.Broadcast(ctx, &msgPack2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
receiveAndValidateMsg(ctx, consumer, len(msgPack1.Msgs))
|
||||
|
@ -210,17 +210,17 @@ func testTimeTickerNoSeek(t *testing.T, f []Factory) {
|
|||
defer producer.Close()
|
||||
|
||||
var err error
|
||||
_, err = producer.Broadcast(&msgPack0)
|
||||
_, err = producer.Broadcast(ctx, &msgPack0)
|
||||
assert.NoError(t, err)
|
||||
err = producer.Produce(&msgPack1)
|
||||
err = producer.Produce(ctx, &msgPack1)
|
||||
assert.NoError(t, err)
|
||||
_, err = producer.Broadcast(&msgPack2)
|
||||
_, err = producer.Broadcast(ctx, &msgPack2)
|
||||
assert.NoError(t, err)
|
||||
err = producer.Produce(&msgPack3)
|
||||
err = producer.Produce(ctx, &msgPack3)
|
||||
assert.NoError(t, err)
|
||||
_, err = producer.Broadcast(&msgPack4)
|
||||
_, err = producer.Broadcast(ctx, &msgPack4)
|
||||
assert.NoError(t, err)
|
||||
_, err = producer.Broadcast(&msgPack5)
|
||||
_, err = producer.Broadcast(ctx, &msgPack5)
|
||||
assert.NoError(t, err)
|
||||
|
||||
o1 := consume(ctx, consumer)
|
||||
|
@ -259,7 +259,7 @@ func testSeekToLast(t *testing.T, f []Factory) {
|
|||
}
|
||||
|
||||
// produce test data
|
||||
err := producer.Produce(msgPack)
|
||||
err := producer.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// pick a seekPosition
|
||||
|
@ -346,21 +346,21 @@ func testTimeTickerSeek(t *testing.T, f []Factory) {
|
|||
defer producer.Close()
|
||||
|
||||
// Send message
|
||||
_, err := producer.Broadcast(&msgPack0)
|
||||
_, err := producer.Broadcast(ctx, &msgPack0)
|
||||
assert.NoError(t, err)
|
||||
err = producer.Produce(&msgPack1)
|
||||
err = producer.Produce(ctx, &msgPack1)
|
||||
assert.NoError(t, err)
|
||||
_, err = producer.Broadcast(&msgPack2)
|
||||
_, err = producer.Broadcast(ctx, &msgPack2)
|
||||
assert.NoError(t, err)
|
||||
err = producer.Produce(&msgPack3)
|
||||
err = producer.Produce(ctx, &msgPack3)
|
||||
assert.NoError(t, err)
|
||||
_, err = producer.Broadcast(&msgPack4)
|
||||
_, err = producer.Broadcast(ctx, &msgPack4)
|
||||
assert.NoError(t, err)
|
||||
err = producer.Produce(&msgPack5)
|
||||
err = producer.Produce(ctx, &msgPack5)
|
||||
assert.NoError(t, err)
|
||||
_, err = producer.Broadcast(&msgPack6)
|
||||
_, err = producer.Broadcast(ctx, &msgPack6)
|
||||
assert.NoError(t, err)
|
||||
_, err = producer.Broadcast(&msgPack7)
|
||||
_, err = producer.Broadcast(ctx, &msgPack7)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test received message
|
||||
|
@ -434,13 +434,13 @@ func testTimeTickUnmarshalHeader(t *testing.T, f []Factory) {
|
|||
defer producer.Close()
|
||||
defer consumer.Close()
|
||||
|
||||
_, err := producer.Broadcast(&msgPack0)
|
||||
_, err := producer.Broadcast(ctx, &msgPack0)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
|
||||
|
||||
err = producer.Produce(&msgPack1)
|
||||
err = producer.Produce(ctx, &msgPack1)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
|
||||
|
||||
_, err = producer.Broadcast(&msgPack2)
|
||||
_, err = producer.Broadcast(ctx, &msgPack2)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
|
||||
|
||||
receiveAndValidateMsg(ctx, consumer, len(msgPack1.Msgs))
|
||||
|
@ -571,7 +571,7 @@ func testMqMsgStreamSeek(t *testing.T, f []Factory) {
|
|||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
}
|
||||
|
||||
err := producer.Produce(msgPack)
|
||||
err := producer.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
var seekPosition *msgpb.MsgPosition
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -605,7 +605,7 @@ func testMqMsgStreamSeekInvalidMessage(t *testing.T, f []Factory, pg positionGen
|
|||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
}
|
||||
|
||||
err := producer.Produce(msgPack)
|
||||
err := producer.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
var seekPosition *msgpb.MsgPosition
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -622,7 +622,7 @@ func testMqMsgStreamSeekInvalidMessage(t *testing.T, f []Factory, pg positionGen
|
|||
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
|
||||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
}
|
||||
err = producer.Produce(msgPack)
|
||||
err = producer.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
result := consume(ctx, consumer2)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(1))
|
||||
|
@ -642,7 +642,7 @@ func testMqMsgStreamSeekLatest(t *testing.T, f []Factory) {
|
|||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
}
|
||||
|
||||
err := producer.Produce(msgPack)
|
||||
err := producer.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
consumer2 := createLatestConsumer(ctx, t, f[1].NewMsgStream, channels)
|
||||
defer consumer2.Close()
|
||||
|
@ -653,7 +653,7 @@ func testMqMsgStreamSeekLatest(t *testing.T, f []Factory) {
|
|||
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
|
||||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
}
|
||||
err = producer.Produce(msgPack)
|
||||
err = producer.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for i := 10; i < 20; i++ {
|
||||
|
@ -673,7 +673,7 @@ func testBroadcastMark(t *testing.T, f []Factory) {
|
|||
msgPack0 := MsgPack{}
|
||||
msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0))
|
||||
|
||||
ids, err := producer.Broadcast(&msgPack0)
|
||||
ids, err := producer.Broadcast(ctx, &msgPack0)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, ids)
|
||||
assert.Equal(t, len(channels), len(ids))
|
||||
|
@ -687,7 +687,7 @@ func testBroadcastMark(t *testing.T, f []Factory) {
|
|||
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
|
||||
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3))
|
||||
|
||||
ids, err = producer.Broadcast(&msgPack1)
|
||||
ids, err = producer.Broadcast(ctx, &msgPack1)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, ids)
|
||||
assert.Equal(t, len(channels), len(ids))
|
||||
|
@ -698,12 +698,12 @@ func testBroadcastMark(t *testing.T, f []Factory) {
|
|||
}
|
||||
|
||||
// edge cases
|
||||
_, err = producer.Broadcast(nil)
|
||||
_, err = producer.Broadcast(ctx, nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
msgPack2 := MsgPack{}
|
||||
msgPack2.Msgs = append(msgPack2.Msgs, &MarshalFailTsMsg{})
|
||||
_, err = producer.Broadcast(&msgPack2)
|
||||
_, err = producer.Broadcast(ctx, &msgPack2)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
|
@ -712,7 +712,7 @@ func applyBroadCastAndConsume(t *testing.T, msgPack *MsgPack, newer []streamNewe
|
|||
defer producer.Close()
|
||||
defer consumer.Close()
|
||||
|
||||
_, err := producer.Broadcast(msgPack)
|
||||
_, err := producer.Broadcast(context.TODO(), msgPack)
|
||||
assert.NoError(t, err)
|
||||
receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs)*channelNum)
|
||||
}
|
||||
|
@ -728,7 +728,7 @@ func applyProduceAndConsumeWithRepack(
|
|||
defer producer.Close()
|
||||
defer consumer.Close()
|
||||
|
||||
err := producer.Produce(msgPack)
|
||||
err := producer.Produce(context.TODO(), msgPack)
|
||||
assert.NoError(t, err)
|
||||
receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs))
|
||||
}
|
||||
|
@ -743,7 +743,7 @@ func applyProduceAndConsume(
|
|||
defer producer.Close()
|
||||
defer consumer.Close()
|
||||
|
||||
err := producer.Produce(msgPack)
|
||||
err := producer.Produce(context.TODO(), msgPack)
|
||||
assert.NoError(t, err)
|
||||
receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs))
|
||||
}
|
||||
|
@ -774,7 +774,7 @@ func createAndSeekConsumer(ctx context.Context, t *testing.T, newer streamNewer,
|
|||
func createProducer(ctx context.Context, t *testing.T, newer streamNewer, channels []string) MsgStream {
|
||||
producer, err := newer(ctx)
|
||||
assert.NoError(t, err)
|
||||
producer.AsProducer(channels)
|
||||
producer.AsProducer(ctx, channels)
|
||||
return producer
|
||||
}
|
||||
|
||||
|
@ -798,7 +798,7 @@ func createStream(ctx context.Context, t *testing.T, newer []streamNewer, channe
|
|||
assert.NotEmpty(t, channels)
|
||||
producer, err := newer[0](ctx)
|
||||
assert.NoError(t, err)
|
||||
producer.AsProducer(channels)
|
||||
producer.AsProducer(ctx, channels)
|
||||
|
||||
consumer, err := newer[1](ctx)
|
||||
assert.NoError(t, err)
|
||||
|
|
|
@ -74,9 +74,9 @@ func (_c *MockMsgStream_AsConsumer_Call) RunAndReturn(run func(context.Context,
|
|||
return _c
|
||||
}
|
||||
|
||||
// AsProducer provides a mock function with given fields: channels
|
||||
func (_m *MockMsgStream) AsProducer(channels []string) {
|
||||
_m.Called(channels)
|
||||
// AsProducer provides a mock function with given fields: ctx, channels
|
||||
func (_m *MockMsgStream) AsProducer(ctx context.Context, channels []string) {
|
||||
_m.Called(ctx, channels)
|
||||
}
|
||||
|
||||
// MockMsgStream_AsProducer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AsProducer'
|
||||
|
@ -85,14 +85,15 @@ type MockMsgStream_AsProducer_Call struct {
|
|||
}
|
||||
|
||||
// AsProducer is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - channels []string
|
||||
func (_e *MockMsgStream_Expecter) AsProducer(channels interface{}) *MockMsgStream_AsProducer_Call {
|
||||
return &MockMsgStream_AsProducer_Call{Call: _e.mock.On("AsProducer", channels)}
|
||||
func (_e *MockMsgStream_Expecter) AsProducer(ctx interface{}, channels interface{}) *MockMsgStream_AsProducer_Call {
|
||||
return &MockMsgStream_AsProducer_Call{Call: _e.mock.On("AsProducer", ctx, channels)}
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_AsProducer_Call) Run(run func(channels []string)) *MockMsgStream_AsProducer_Call {
|
||||
func (_c *MockMsgStream_AsProducer_Call) Run(run func(ctx context.Context, channels []string)) *MockMsgStream_AsProducer_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]string))
|
||||
run(args[0].(context.Context), args[1].([]string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
@ -102,14 +103,14 @@ func (_c *MockMsgStream_AsProducer_Call) Return() *MockMsgStream_AsProducer_Call
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_AsProducer_Call) RunAndReturn(run func([]string)) *MockMsgStream_AsProducer_Call {
|
||||
func (_c *MockMsgStream_AsProducer_Call) RunAndReturn(run func(context.Context, []string)) *MockMsgStream_AsProducer_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Broadcast provides a mock function with given fields: _a0
|
||||
func (_m *MockMsgStream) Broadcast(_a0 *MsgPack) (map[string][]common.MessageID, error) {
|
||||
ret := _m.Called(_a0)
|
||||
// Broadcast provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockMsgStream) Broadcast(_a0 context.Context, _a1 *MsgPack) (map[string][]common.MessageID, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Broadcast")
|
||||
|
@ -117,19 +118,19 @@ func (_m *MockMsgStream) Broadcast(_a0 *MsgPack) (map[string][]common.MessageID,
|
|||
|
||||
var r0 map[string][]common.MessageID
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(*MsgPack) (map[string][]common.MessageID, error)); ok {
|
||||
return rf(_a0)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *MsgPack) (map[string][]common.MessageID, error)); ok {
|
||||
return rf(_a0, _a1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(*MsgPack) map[string][]common.MessageID); ok {
|
||||
r0 = rf(_a0)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *MsgPack) map[string][]common.MessageID); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(map[string][]common.MessageID)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(*MsgPack) error); ok {
|
||||
r1 = rf(_a0)
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *MsgPack) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
@ -143,14 +144,15 @@ type MockMsgStream_Broadcast_Call struct {
|
|||
}
|
||||
|
||||
// Broadcast is a helper method to define mock.On call
|
||||
// - _a0 *MsgPack
|
||||
func (_e *MockMsgStream_Expecter) Broadcast(_a0 interface{}) *MockMsgStream_Broadcast_Call {
|
||||
return &MockMsgStream_Broadcast_Call{Call: _e.mock.On("Broadcast", _a0)}
|
||||
// - _a0 context.Context
|
||||
// - _a1 *MsgPack
|
||||
func (_e *MockMsgStream_Expecter) Broadcast(_a0 interface{}, _a1 interface{}) *MockMsgStream_Broadcast_Call {
|
||||
return &MockMsgStream_Broadcast_Call{Call: _e.mock.On("Broadcast", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_Broadcast_Call) Run(run func(_a0 *MsgPack)) *MockMsgStream_Broadcast_Call {
|
||||
func (_c *MockMsgStream_Broadcast_Call) Run(run func(_a0 context.Context, _a1 *MsgPack)) *MockMsgStream_Broadcast_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(*MsgPack))
|
||||
run(args[0].(context.Context), args[1].(*MsgPack))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
@ -160,7 +162,7 @@ func (_c *MockMsgStream_Broadcast_Call) Return(_a0 map[string][]common.MessageID
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_Broadcast_Call) RunAndReturn(run func(*MsgPack) (map[string][]common.MessageID, error)) *MockMsgStream_Broadcast_Call {
|
||||
func (_c *MockMsgStream_Broadcast_Call) RunAndReturn(run func(context.Context, *MsgPack) (map[string][]common.MessageID, error)) *MockMsgStream_Broadcast_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
@ -428,17 +430,17 @@ func (_c *MockMsgStream_GetProduceChannels_Call) RunAndReturn(run func() []strin
|
|||
return _c
|
||||
}
|
||||
|
||||
// Produce provides a mock function with given fields: _a0
|
||||
func (_m *MockMsgStream) Produce(_a0 *MsgPack) error {
|
||||
ret := _m.Called(_a0)
|
||||
// Produce provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockMsgStream) Produce(_a0 context.Context, _a1 *MsgPack) error {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Produce")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(*MsgPack) error); ok {
|
||||
r0 = rf(_a0)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *MsgPack) error); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
@ -452,14 +454,15 @@ type MockMsgStream_Produce_Call struct {
|
|||
}
|
||||
|
||||
// Produce is a helper method to define mock.On call
|
||||
// - _a0 *MsgPack
|
||||
func (_e *MockMsgStream_Expecter) Produce(_a0 interface{}) *MockMsgStream_Produce_Call {
|
||||
return &MockMsgStream_Produce_Call{Call: _e.mock.On("Produce", _a0)}
|
||||
// - _a0 context.Context
|
||||
// - _a1 *MsgPack
|
||||
func (_e *MockMsgStream_Expecter) Produce(_a0 interface{}, _a1 interface{}) *MockMsgStream_Produce_Call {
|
||||
return &MockMsgStream_Produce_Call{Call: _e.mock.On("Produce", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_Produce_Call) Run(run func(_a0 *MsgPack)) *MockMsgStream_Produce_Call {
|
||||
func (_c *MockMsgStream_Produce_Call) Run(run func(_a0 context.Context, _a1 *MsgPack)) *MockMsgStream_Produce_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(*MsgPack))
|
||||
run(args[0].(context.Context), args[1].(*MsgPack))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
@ -469,7 +472,7 @@ func (_c *MockMsgStream_Produce_Call) Return(_a0 error) *MockMsgStream_Produce_C
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_Produce_Call) RunAndReturn(run func(*MsgPack) error) *MockMsgStream_Produce_Call {
|
||||
func (_c *MockMsgStream_Produce_Call) RunAndReturn(run func(context.Context, *MsgPack) error) *MockMsgStream_Produce_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
|
|
@ -123,7 +123,7 @@ func TestStream_KafkaMsgStream_SeekToLast(t *testing.T) {
|
|||
}
|
||||
|
||||
// produce test data
|
||||
err := inputStream.Produce(msgPack)
|
||||
err := inputStream.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// pick a seekPosition
|
||||
|
@ -219,21 +219,21 @@ func TestStream_KafkaTtMsgStream_Seek(t *testing.T) {
|
|||
inputStream := getKafkaInputStream(ctx, kafkaAddress, producerChannels)
|
||||
outputStream := getKafkaTtOutputStream(ctx, kafkaAddress, consumerChannels, consumerSubName)
|
||||
|
||||
_, err := inputStream.Broadcast(&msgPack0)
|
||||
_, err := inputStream.Broadcast(ctx, &msgPack0)
|
||||
assert.NoError(t, err)
|
||||
err = inputStream.Produce(&msgPack1)
|
||||
err = inputStream.Produce(ctx, &msgPack1)
|
||||
assert.NoError(t, err)
|
||||
_, err = inputStream.Broadcast(&msgPack2)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack2)
|
||||
assert.NoError(t, err)
|
||||
err = inputStream.Produce(&msgPack3)
|
||||
err = inputStream.Produce(ctx, &msgPack3)
|
||||
assert.NoError(t, err)
|
||||
_, err = inputStream.Broadcast(&msgPack4)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack4)
|
||||
assert.NoError(t, err)
|
||||
err = inputStream.Produce(&msgPack5)
|
||||
err = inputStream.Produce(ctx, &msgPack5)
|
||||
assert.NoError(t, err)
|
||||
_, err = inputStream.Broadcast(&msgPack6)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack6)
|
||||
assert.NoError(t, err)
|
||||
_, err = inputStream.Broadcast(&msgPack7)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack7)
|
||||
assert.NoError(t, err)
|
||||
|
||||
receivedMsg := consumer(ctx, outputStream)
|
||||
|
@ -450,7 +450,7 @@ func getKafkaInputStream(ctx context.Context, kafkaAddress string, producerChann
|
|||
}
|
||||
kafkaClient := kafkawrapper.NewKafkaClientInstanceWithConfigMap(config, nil, nil)
|
||||
inputStream, _ := NewMqMsgStream(ctx, 100, 100, kafkaClient, factory.NewUnmarshalDispatcher())
|
||||
inputStream.AsProducer(producerChannels)
|
||||
inputStream.AsProducer(ctx, producerChannels)
|
||||
for _, opt := range opts {
|
||||
inputStream.SetRepackFunc(opt)
|
||||
}
|
||||
|
|
|
@ -121,7 +121,7 @@ func NewMqMsgStream(ctx context.Context,
|
|||
}
|
||||
|
||||
// AsProducer create producer to send message to channels
|
||||
func (ms *mqMsgStream) AsProducer(channels []string) {
|
||||
func (ms *mqMsgStream) AsProducer(ctx context.Context, channels []string) {
|
||||
for _, channel := range channels {
|
||||
if len(channel) == 0 {
|
||||
log.Error("MsgStream asProducer's channel is an empty string")
|
||||
|
@ -129,7 +129,7 @@ func (ms *mqMsgStream) AsProducer(channels []string) {
|
|||
}
|
||||
|
||||
fn := func() error {
|
||||
pp, err := ms.client.CreateProducer(common.ProducerOptions{Topic: channel, EnableCompression: true})
|
||||
pp, err := ms.client.CreateProducer(ctx, common.ProducerOptions{Topic: channel, EnableCompression: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -176,7 +176,7 @@ func (ms *mqMsgStream) AsConsumer(ctx context.Context, channels []string, subNam
|
|||
continue
|
||||
}
|
||||
fn := func() error {
|
||||
pc, err := ms.client.Subscribe(mqwrapper.ConsumerOptions{
|
||||
pc, err := ms.client.Subscribe(ctx, mqwrapper.ConsumerOptions{
|
||||
Topic: channel,
|
||||
SubscriptionName: subName,
|
||||
SubscriptionInitialPosition: position,
|
||||
|
@ -273,7 +273,7 @@ func (ms *mqMsgStream) isEnabledProduce() bool {
|
|||
return ms.enableProduce.Load().(bool)
|
||||
}
|
||||
|
||||
func (ms *mqMsgStream) Produce(msgPack *MsgPack) error {
|
||||
func (ms *mqMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error {
|
||||
if !ms.isEnabledProduce() {
|
||||
log.Warn("can't produce the msg in the backup instance", zap.Stack("stack"))
|
||||
return merr.ErrDenyProduceMsg
|
||||
|
@ -346,7 +346,7 @@ func (ms *mqMsgStream) Produce(msgPack *MsgPack) error {
|
|||
|
||||
// BroadcastMark broadcast msg pack to all producers and returns corresponding msg id
|
||||
// the returned message id serves as marking
|
||||
func (ms *mqMsgStream) Broadcast(msgPack *MsgPack) (map[string][]MessageID, error) {
|
||||
func (ms *mqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) (map[string][]MessageID, error) {
|
||||
ids := make(map[string][]MessageID)
|
||||
if msgPack == nil || len(msgPack.Msgs) <= 0 {
|
||||
return ids, errors.New("empty msgs")
|
||||
|
@ -581,7 +581,7 @@ func (ms *MqTtMsgStream) AsConsumer(ctx context.Context, channels []string, subN
|
|||
continue
|
||||
}
|
||||
fn := func() error {
|
||||
pc, err := ms.client.Subscribe(mqwrapper.ConsumerOptions{
|
||||
pc, err := ms.client.Subscribe(ctx, mqwrapper.ConsumerOptions{
|
||||
Topic: channel,
|
||||
SubscriptionName: subName,
|
||||
SubscriptionInitialPosition: position,
|
||||
|
|
|
@ -130,12 +130,12 @@ func TestStream_PulsarMsgStream_Insert(t *testing.T) {
|
|||
|
||||
{
|
||||
inputStream.EnableProduce(false)
|
||||
err := inputStream.Produce(&msgPack)
|
||||
err := inputStream.Produce(ctx, &msgPack)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
inputStream.EnableProduce(true)
|
||||
err := inputStream.Produce(&msgPack)
|
||||
err := inputStream.Produce(ctx, &msgPack)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
|
||||
|
||||
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
|
||||
|
@ -156,7 +156,7 @@ func TestStream_PulsarMsgStream_Delete(t *testing.T) {
|
|||
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
|
||||
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
|
||||
|
||||
err := inputStream.Produce(&msgPack)
|
||||
err := inputStream.Produce(ctx, &msgPack)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
|
||||
|
||||
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
|
||||
|
@ -178,7 +178,7 @@ func TestStream_PulsarMsgStream_TimeTick(t *testing.T) {
|
|||
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
|
||||
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
|
||||
|
||||
err := inputStream.Produce(&msgPack)
|
||||
err := inputStream.Produce(ctx, &msgPack)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
|
||||
|
||||
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
|
||||
|
@ -203,12 +203,12 @@ func TestStream_PulsarMsgStream_BroadCast(t *testing.T) {
|
|||
|
||||
{
|
||||
inputStream.EnableProduce(false)
|
||||
_, err := inputStream.Broadcast(&msgPack)
|
||||
_, err := inputStream.Broadcast(ctx, &msgPack)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
inputStream.EnableProduce(true)
|
||||
_, err := inputStream.Broadcast(&msgPack)
|
||||
_, err := inputStream.Broadcast(ctx, &msgPack)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
|
||||
|
||||
receiveMsg(ctx, outputStream, len(consumerChannels)*len(msgPack.Msgs))
|
||||
|
@ -230,7 +230,7 @@ func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels, repackFunc)
|
||||
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
|
||||
err := inputStream.Produce(&msgPack)
|
||||
err := inputStream.Produce(ctx, &msgPack)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
|
||||
|
||||
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
|
||||
|
@ -277,14 +277,14 @@ func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
|
||||
inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
|
||||
inputStream.AsProducer(producerChannels)
|
||||
inputStream.AsProducer(ctx, producerChannels)
|
||||
|
||||
pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
|
||||
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
|
||||
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest)
|
||||
var output MsgStream = outputStream
|
||||
|
||||
err := (*inputStream).Produce(&msgPack)
|
||||
err := (*inputStream).Produce(ctx, &msgPack)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
|
||||
|
||||
receiveMsg(ctx, output, len(msgPack.Msgs)*2)
|
||||
|
@ -328,14 +328,14 @@ func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
|
||||
inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
|
||||
inputStream.AsProducer(producerChannels)
|
||||
inputStream.AsProducer(ctx, producerChannels)
|
||||
|
||||
pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
|
||||
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
|
||||
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest)
|
||||
var output MsgStream = outputStream
|
||||
|
||||
err := (*inputStream).Produce(&msgPack)
|
||||
err := (*inputStream).Produce(ctx, &msgPack)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
|
||||
|
||||
receiveMsg(ctx, output, len(msgPack.Msgs)*1)
|
||||
|
@ -360,14 +360,14 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
|
||||
inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
|
||||
inputStream.AsProducer(producerChannels)
|
||||
inputStream.AsProducer(ctx, producerChannels)
|
||||
|
||||
pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
|
||||
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
|
||||
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest)
|
||||
var output MsgStream = outputStream
|
||||
|
||||
err := (*inputStream).Produce(&msgPack)
|
||||
err := (*inputStream).Produce(ctx, &msgPack)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
|
||||
|
||||
receiveMsg(ctx, output, len(msgPack.Msgs))
|
||||
|
@ -395,13 +395,13 @@ func TestStream_PulsarTtMsgStream_Insert(t *testing.T) {
|
|||
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
|
||||
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
|
||||
|
||||
_, err := inputStream.Broadcast(&msgPack0)
|
||||
_, err := inputStream.Broadcast(ctx, &msgPack0)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
|
||||
|
||||
err = inputStream.Produce(&msgPack1)
|
||||
err = inputStream.Produce(ctx, &msgPack1)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
|
||||
|
||||
_, err = inputStream.Broadcast(&msgPack2)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack2)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
|
||||
|
||||
receiveMsg(ctx, outputStream, len(msgPack1.Msgs))
|
||||
|
@ -440,17 +440,17 @@ func TestStream_PulsarTtMsgStream_NoSeek(t *testing.T) {
|
|||
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
|
||||
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
|
||||
|
||||
_, err := inputStream.Broadcast(&msgPack0)
|
||||
_, err := inputStream.Broadcast(ctx, &msgPack0)
|
||||
assert.NoError(t, err)
|
||||
err = inputStream.Produce(&msgPack1)
|
||||
err = inputStream.Produce(ctx, &msgPack1)
|
||||
assert.NoError(t, err)
|
||||
_, err = inputStream.Broadcast(&msgPack2)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack2)
|
||||
assert.NoError(t, err)
|
||||
err = inputStream.Produce(&msgPack3)
|
||||
err = inputStream.Produce(ctx, &msgPack3)
|
||||
assert.NoError(t, err)
|
||||
_, err = inputStream.Broadcast(&msgPack4)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack4)
|
||||
assert.NoError(t, err)
|
||||
_, err = inputStream.Broadcast(&msgPack5)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack5)
|
||||
assert.NoError(t, err)
|
||||
|
||||
o1 := consumer(ctx, outputStream)
|
||||
|
@ -495,7 +495,7 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) {
|
|||
}
|
||||
|
||||
// produce test data
|
||||
err := inputStream.Produce(msgPack)
|
||||
err := inputStream.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// pick a seekPosition
|
||||
|
@ -617,21 +617,21 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
|
|||
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
|
||||
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
|
||||
|
||||
_, err := inputStream.Broadcast(&msgPack0)
|
||||
_, err := inputStream.Broadcast(ctx, &msgPack0)
|
||||
assert.NoError(t, err)
|
||||
err = inputStream.Produce(&msgPack1)
|
||||
err = inputStream.Produce(ctx, &msgPack1)
|
||||
assert.NoError(t, err)
|
||||
_, err = inputStream.Broadcast(&msgPack2)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack2)
|
||||
assert.NoError(t, err)
|
||||
err = inputStream.Produce(&msgPack3)
|
||||
err = inputStream.Produce(ctx, &msgPack3)
|
||||
assert.NoError(t, err)
|
||||
_, err = inputStream.Broadcast(&msgPack4)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack4)
|
||||
assert.NoError(t, err)
|
||||
err = inputStream.Produce(&msgPack5)
|
||||
err = inputStream.Produce(ctx, &msgPack5)
|
||||
assert.NoError(t, err)
|
||||
_, err = inputStream.Broadcast(&msgPack6)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack6)
|
||||
assert.NoError(t, err)
|
||||
_, err = inputStream.Broadcast(&msgPack7)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack7)
|
||||
assert.NoError(t, err)
|
||||
|
||||
receivedMsg := consumer(ctx, outputStream)
|
||||
|
@ -711,13 +711,13 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) {
|
|||
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
|
||||
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
|
||||
|
||||
_, err := inputStream.Broadcast(&msgPack0)
|
||||
_, err := inputStream.Broadcast(ctx, &msgPack0)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
|
||||
|
||||
err = inputStream.Produce(&msgPack1)
|
||||
err = inputStream.Produce(ctx, &msgPack1)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
|
||||
|
||||
_, err = inputStream.Broadcast(&msgPack2)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack2)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
|
||||
|
||||
receiveMsg(ctx, outputStream, len(msgPack1.Msgs))
|
||||
|
@ -748,16 +748,16 @@ func TestStream_PulsarTtMsgStream_DropCollection(t *testing.T) {
|
|||
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
|
||||
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
|
||||
|
||||
_, err := inputStream.Broadcast(&msgPack0)
|
||||
_, err := inputStream.Broadcast(ctx, &msgPack0)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
|
||||
|
||||
err = inputStream.Produce(&msgPack1)
|
||||
err = inputStream.Produce(ctx, &msgPack1)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
|
||||
|
||||
_, err = inputStream.Broadcast(&msgPack2)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack2)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
|
||||
|
||||
_, err = inputStream.Broadcast(&msgPack3)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack3)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
|
||||
|
||||
receiveMsg(ctx, outputStream, 2)
|
||||
|
@ -803,12 +803,12 @@ func sendMsgPacks(ms MsgStream, msgPacks []*MsgPack) error {
|
|||
printMsgPack(msgPacks[i])
|
||||
if i%2 == 0 {
|
||||
// insert msg use Produce
|
||||
if err := ms.Produce(msgPacks[i]); err != nil {
|
||||
if err := ms.Produce(context.TODO(), msgPacks[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// tt msg use Broadcast
|
||||
if _, err := ms.Broadcast(msgPacks[i]); err != nil {
|
||||
if _, err := ms.Broadcast(context.TODO(), msgPacks[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -971,7 +971,7 @@ func TestStream_MqMsgStream_Seek(t *testing.T) {
|
|||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
}
|
||||
|
||||
err := inputStream.Produce(msgPack)
|
||||
err := inputStream.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
var seekPosition *msgpb.MsgPosition
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -1015,7 +1015,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
|
|||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
}
|
||||
|
||||
err := inputStream.Produce(msgPack)
|
||||
err := inputStream.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
var seekPosition *msgpb.MsgPosition
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -1049,7 +1049,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
|
|||
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
|
||||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
}
|
||||
err = inputStream.Produce(msgPack)
|
||||
err = inputStream.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
result := consumer(ctx, outputStream2)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(1))
|
||||
|
@ -1074,7 +1074,7 @@ func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) {
|
|||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
}
|
||||
|
||||
err := inputStream.Produce(msgPack)
|
||||
err := inputStream.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
var seekPosition *msgpb.MsgPosition
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -1086,7 +1086,7 @@ func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) {
|
|||
// produce timetick for mqtt msgstream seek
|
||||
msgPack = &MsgPack{}
|
||||
msgPack.Msgs = append(msgPack.Msgs, getTimeTickMsg(1000))
|
||||
err = inputStream.Produce(msgPack)
|
||||
err = inputStream.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
|
||||
factory := ProtoUDFactory{}
|
||||
|
@ -1139,7 +1139,7 @@ func TestStream_MqMsgStream_SeekLatest(t *testing.T) {
|
|||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
}
|
||||
|
||||
err := inputStream.Produce(msgPack)
|
||||
err := inputStream.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
factory := ProtoUDFactory{}
|
||||
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
|
||||
|
@ -1152,7 +1152,7 @@ func TestStream_MqMsgStream_SeekLatest(t *testing.T) {
|
|||
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
|
||||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
}
|
||||
err = inputStream.Produce(msgPack)
|
||||
err = inputStream.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for i := 10; i < 20; i++ {
|
||||
|
@ -1169,6 +1169,7 @@ func TestStream_BroadcastMark(t *testing.T) {
|
|||
c1 := funcutil.RandomString(8)
|
||||
c2 := funcutil.RandomString(8)
|
||||
producerChannels := []string{c1, c2}
|
||||
ctx := context.Background()
|
||||
|
||||
factory := ProtoUDFactory{}
|
||||
pulsarClient, err := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
|
||||
|
@ -1177,12 +1178,12 @@ func TestStream_BroadcastMark(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
// add producer channels
|
||||
outputStream.AsProducer(producerChannels)
|
||||
outputStream.AsProducer(ctx, producerChannels)
|
||||
|
||||
msgPack0 := MsgPack{}
|
||||
msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0))
|
||||
|
||||
ids, err := outputStream.Broadcast(&msgPack0)
|
||||
ids, err := outputStream.Broadcast(ctx, &msgPack0)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, ids)
|
||||
assert.Equal(t, len(producerChannels), len(ids))
|
||||
|
@ -1196,7 +1197,7 @@ func TestStream_BroadcastMark(t *testing.T) {
|
|||
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
|
||||
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3))
|
||||
|
||||
ids, err = outputStream.Broadcast(&msgPack1)
|
||||
ids, err = outputStream.Broadcast(ctx, &msgPack1)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, ids)
|
||||
assert.Equal(t, len(producerChannels), len(ids))
|
||||
|
@ -1207,19 +1208,19 @@ func TestStream_BroadcastMark(t *testing.T) {
|
|||
}
|
||||
|
||||
// edge cases
|
||||
_, err = outputStream.Broadcast(nil)
|
||||
_, err = outputStream.Broadcast(ctx, nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
msgPack2 := MsgPack{}
|
||||
msgPack2.Msgs = append(msgPack2.Msgs, &MarshalFailTsMsg{})
|
||||
_, err = outputStream.Broadcast(&msgPack2)
|
||||
_, err = outputStream.Broadcast(ctx, &msgPack2)
|
||||
assert.Error(t, err)
|
||||
|
||||
// mock send fail
|
||||
for k, p := range outputStream.producers {
|
||||
outputStream.producers[k] = &mockSendFailProducer{Producer: p}
|
||||
}
|
||||
_, err = outputStream.Broadcast(&msgPack1)
|
||||
_, err = outputStream.Broadcast(ctx, &msgPack1)
|
||||
assert.Error(t, err)
|
||||
|
||||
outputStream.Close()
|
||||
|
@ -1497,7 +1498,7 @@ func getPulsarInputStream(ctx context.Context, pulsarAddress string, producerCha
|
|||
factory := ProtoUDFactory{}
|
||||
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
|
||||
inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
|
||||
inputStream.AsProducer(producerChannels)
|
||||
inputStream.AsProducer(ctx, producerChannels)
|
||||
for _, opt := range opts {
|
||||
inputStream.SetRepackFunc(opt)
|
||||
}
|
||||
|
|
|
@ -52,7 +52,7 @@ func TestMqMsgStream_AsProducer(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
// empty channel name
|
||||
m.AsProducer([]string{""})
|
||||
m.AsProducer(context.TODO(), []string{""})
|
||||
}
|
||||
|
||||
// TODO(wxyu): add a mock implement of mqwrapper.Client, then inject errors to improve coverage
|
||||
|
@ -121,7 +121,7 @@ func TestMqMsgStream_GetProduceChannels(t *testing.T) {
|
|||
assert.Equal(t, 0, len(chs))
|
||||
|
||||
// not empty after AsProducer
|
||||
m.AsProducer([]string{"a"})
|
||||
m.AsProducer(context.TODO(), []string{"a"})
|
||||
chs = m.GetProduceChannels()
|
||||
assert.Equal(t, 1, len(chs))
|
||||
}
|
||||
|
@ -160,7 +160,7 @@ func TestMqMsgStream_Produce(t *testing.T) {
|
|||
msgPack := &MsgPack{
|
||||
Msgs: []TsMsg{insertMsg},
|
||||
}
|
||||
err = m.Produce(msgPack)
|
||||
err = m.Produce(context.TODO(), msgPack)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
|
@ -173,7 +173,7 @@ func TestMqMsgStream_Broadcast(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
// Broadcast nil pointer
|
||||
_, err = m.Broadcast(nil)
|
||||
_, err = m.Broadcast(context.TODO(), nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
|
@ -241,7 +241,7 @@ func initRmqStream(ctx context.Context,
|
|||
|
||||
rmqClient, _ := rmq.NewClientWithDefaultOptions(ctx)
|
||||
inputStream, _ := NewMqMsgStream(ctx, 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
|
||||
inputStream.AsProducer(producerChannels)
|
||||
inputStream.AsProducer(ctx, producerChannels)
|
||||
for _, opt := range opts {
|
||||
inputStream.SetRepackFunc(opt)
|
||||
}
|
||||
|
@ -265,7 +265,7 @@ func initRmqTtStream(ctx context.Context,
|
|||
|
||||
rmqClient, _ := rmq.NewClientWithDefaultOptions(ctx)
|
||||
inputStream, _ := NewMqMsgStream(ctx, 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
|
||||
inputStream.AsProducer(producerChannels)
|
||||
inputStream.AsProducer(ctx, producerChannels)
|
||||
for _, opt := range opts {
|
||||
inputStream.SetRepackFunc(opt)
|
||||
}
|
||||
|
@ -290,7 +290,7 @@ func TestStream_RmqMsgStream_Insert(t *testing.T) {
|
|||
|
||||
ctx := context.Background()
|
||||
inputStream, outputStream := initRmqStream(ctx, producerChannels, consumerChannels, consumerGroupName)
|
||||
err := inputStream.Produce(&msgPack)
|
||||
err := inputStream.Produce(ctx, &msgPack)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
|
||||
|
||||
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
|
||||
|
@ -316,13 +316,13 @@ func TestStream_RmqTtMsgStream_Insert(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName)
|
||||
|
||||
_, err := inputStream.Broadcast(&msgPack0)
|
||||
_, err := inputStream.Broadcast(ctx, &msgPack0)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
|
||||
|
||||
err = inputStream.Produce(&msgPack1)
|
||||
err = inputStream.Produce(ctx, &msgPack1)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
|
||||
|
||||
_, err = inputStream.Broadcast(&msgPack2)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack2)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
|
||||
|
||||
receiveMsg(ctx, outputStream, len(msgPack1.Msgs))
|
||||
|
@ -355,13 +355,13 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName)
|
||||
|
||||
_, err := inputStream.Broadcast(&msgPack0)
|
||||
_, err := inputStream.Broadcast(ctx, &msgPack0)
|
||||
assert.NoError(t, err)
|
||||
err = inputStream.Produce(&msgPack1)
|
||||
err = inputStream.Produce(ctx, &msgPack1)
|
||||
assert.NoError(t, err)
|
||||
err = inputStream.Produce(&msgPack2)
|
||||
err = inputStream.Produce(ctx, &msgPack2)
|
||||
assert.NoError(t, err)
|
||||
_, err = inputStream.Broadcast(&msgPack3)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
receivedMsg := consumer(ctx, outputStream)
|
||||
|
@ -425,21 +425,21 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName)
|
||||
|
||||
_, err := inputStream.Broadcast(&msgPack0)
|
||||
_, err := inputStream.Broadcast(ctx, &msgPack0)
|
||||
assert.NoError(t, err)
|
||||
err = inputStream.Produce(&msgPack1)
|
||||
err = inputStream.Produce(ctx, &msgPack1)
|
||||
assert.NoError(t, err)
|
||||
_, err = inputStream.Broadcast(&msgPack2)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack2)
|
||||
assert.NoError(t, err)
|
||||
err = inputStream.Produce(&msgPack3)
|
||||
err = inputStream.Produce(ctx, &msgPack3)
|
||||
assert.NoError(t, err)
|
||||
_, err = inputStream.Broadcast(&msgPack4)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack4)
|
||||
assert.NoError(t, err)
|
||||
err = inputStream.Produce(&msgPack5)
|
||||
err = inputStream.Produce(ctx, &msgPack5)
|
||||
assert.NoError(t, err)
|
||||
_, err = inputStream.Broadcast(&msgPack6)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack6)
|
||||
assert.NoError(t, err)
|
||||
_, err = inputStream.Broadcast(&msgPack7)
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack7)
|
||||
assert.NoError(t, err)
|
||||
|
||||
receivedMsg := consumer(ctx, outputStream)
|
||||
|
@ -512,7 +512,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
|
|||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
}
|
||||
|
||||
err := inputStream.Produce(msgPack)
|
||||
err := inputStream.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
var seekPosition *msgpb.MsgPosition
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -546,7 +546,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
|
|||
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
|
||||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
}
|
||||
err = inputStream.Produce(msgPack)
|
||||
err = inputStream.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
|
||||
result := consumer(ctx, outputStream2)
|
||||
|
@ -560,27 +560,28 @@ func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) {
|
|||
producerChannels := []string{"insert1"}
|
||||
consumerChannels := []string{"insert1"}
|
||||
consumerSubName := "subInsert"
|
||||
ctx := context.Background()
|
||||
|
||||
factory := ProtoUDFactory{}
|
||||
|
||||
rmqClient, _ := rmq.NewClientWithDefaultOptions(context.Background())
|
||||
|
||||
otherInputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
|
||||
otherInputStream.AsProducer([]string{"root_timetick"})
|
||||
otherInputStream.Produce(getTimeTickMsgPack(999))
|
||||
otherInputStream.AsProducer(context.TODO(), []string{"root_timetick"})
|
||||
otherInputStream.Produce(ctx, getTimeTickMsgPack(999))
|
||||
|
||||
inputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
|
||||
inputStream.AsProducer(producerChannels)
|
||||
inputStream.AsProducer(context.TODO(), producerChannels)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
inputStream.Produce(getTimeTickMsgPack(int64(i)))
|
||||
inputStream.Produce(ctx, getTimeTickMsgPack(int64(i)))
|
||||
}
|
||||
|
||||
rmqClient2, _ := rmq.NewClientWithDefaultOptions(context.Background())
|
||||
outputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient2, factory.NewUnmarshalDispatcher())
|
||||
outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, mqcommon.SubscriptionPositionLatest)
|
||||
|
||||
inputStream.Produce(getTimeTickMsgPack(1000))
|
||||
inputStream.Produce(ctx, getTimeTickMsgPack(1000))
|
||||
pack := <-outputStream.Chan()
|
||||
assert.NotNil(t, pack)
|
||||
assert.Equal(t, 1, len(pack.Msgs))
|
||||
|
|
|
@ -17,16 +17,18 @@
|
|||
package mqwrapper
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/mq/common"
|
||||
)
|
||||
|
||||
// Client is the interface that provides operations of message queues
|
||||
type Client interface {
|
||||
// CreateProducer creates a producer instance
|
||||
CreateProducer(options common.ProducerOptions) (Producer, error)
|
||||
CreateProducer(ctx context.Context, options common.ProducerOptions) (Producer, error)
|
||||
|
||||
// Subscribe creates a consumer instance and subscribe a topic
|
||||
Subscribe(options ConsumerOptions) (Consumer, error)
|
||||
Subscribe(ctx context.Context, options ConsumerOptions) (Consumer, error)
|
||||
|
||||
// Get the earliest MessageID
|
||||
EarliestMessageID() common.MessageID
|
||||
|
|
|
@ -205,7 +205,7 @@ func (kc *kafkaClient) newConsumerConfig(group string, offset common.Subscriptio
|
|||
return newConf
|
||||
}
|
||||
|
||||
func (kc *kafkaClient) CreateProducer(options common.ProducerOptions) (mqwrapper.Producer, error) {
|
||||
func (kc *kafkaClient) CreateProducer(ctx context.Context, options common.ProducerOptions) (mqwrapper.Producer, error) {
|
||||
start := timerecord.NewTimeRecorder("create producer")
|
||||
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc()
|
||||
|
||||
|
@ -224,7 +224,7 @@ func (kc *kafkaClient) CreateProducer(options common.ProducerOptions) (mqwrapper
|
|||
return producer, nil
|
||||
}
|
||||
|
||||
func (kc *kafkaClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
|
||||
func (kc *kafkaClient) Subscribe(ctx context.Context, options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
|
||||
start := timerecord.NewTimeRecorder("create consumer")
|
||||
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc()
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ func BytesToInt(b []byte) int {
|
|||
|
||||
// Consume1 will consume random messages and record the last MessageID it received
|
||||
func Consume1(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, subName string, c chan mqcommon.MessageID, total *int) {
|
||||
consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := kc.Subscribe(ctx, mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: subName,
|
||||
BufSize: 1024,
|
||||
|
@ -103,7 +103,7 @@ func Consume1(ctx context.Context, t *testing.T, kc *kafkaClient, topic string,
|
|||
|
||||
// Consume2 will consume messages from specified MessageID
|
||||
func Consume2(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, subName string, msgID mqcommon.MessageID, total *int) {
|
||||
consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := kc.Subscribe(ctx, mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: subName,
|
||||
BufSize: 1024,
|
||||
|
@ -139,7 +139,7 @@ func Consume2(ctx context.Context, t *testing.T, kc *kafkaClient, topic string,
|
|||
}
|
||||
|
||||
func Consume3(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, subName string, total *int) {
|
||||
consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := kc.Subscribe(ctx, mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: subName,
|
||||
BufSize: 1024,
|
||||
|
@ -418,7 +418,7 @@ func createConsumer(t *testing.T,
|
|||
groupID string,
|
||||
initPosition mqcommon.SubscriptionInitialPosition,
|
||||
) mqwrapper.Consumer {
|
||||
consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := kc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: groupID,
|
||||
BufSize: 1024,
|
||||
|
@ -429,7 +429,7 @@ func createConsumer(t *testing.T,
|
|||
}
|
||||
|
||||
func createProducer(t *testing.T, kc *kafkaClient, topic string) mqwrapper.Producer {
|
||||
producer, err := kc.CreateProducer(mqcommon.ProducerOptions{Topic: topic})
|
||||
producer, err := kc.CreateProducer(context.TODO(), mqcommon.ProducerOptions{Topic: topic})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, producer)
|
||||
return producer
|
||||
|
|
|
@ -23,7 +23,7 @@ func TestKafkaProducer_SendSuccess(t *testing.T) {
|
|||
rand.Seed(time.Now().UnixNano())
|
||||
topic := fmt.Sprintf("test-topic-%d", rand.Int())
|
||||
|
||||
producer, err := kc.CreateProducer(common.ProducerOptions{Topic: topic})
|
||||
producer, err := kc.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, producer)
|
||||
|
||||
|
@ -76,7 +76,7 @@ func TestKafkaProducer_SendFailAfterClose(t *testing.T) {
|
|||
rand.Seed(time.Now().UnixNano())
|
||||
topic := fmt.Sprintf("test-topic-%d", rand.Int())
|
||||
|
||||
producer, err := kc.CreateProducer(common.ProducerOptions{Topic: topic})
|
||||
producer, err := kc.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, producer)
|
||||
|
||||
|
|
|
@ -80,7 +80,7 @@ func NewClient(url string, options ...nats.Option) (*nmqClient, error) {
|
|||
}
|
||||
|
||||
// CreateProducer creates a producer for natsmq client
|
||||
func (nc *nmqClient) CreateProducer(options common.ProducerOptions) (mqwrapper.Producer, error) {
|
||||
func (nc *nmqClient) CreateProducer(ctx context.Context, options common.ProducerOptions) (mqwrapper.Producer, error) {
|
||||
start := timerecord.NewTimeRecorder("create producer")
|
||||
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc()
|
||||
|
||||
|
@ -112,7 +112,7 @@ func (nc *nmqClient) CreateProducer(options common.ProducerOptions) (mqwrapper.P
|
|||
return &rp, nil
|
||||
}
|
||||
|
||||
func (nc *nmqClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
|
||||
func (nc *nmqClient) Subscribe(ctx context.Context, options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
|
||||
start := timerecord.NewTimeRecorder("create consumer")
|
||||
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc()
|
||||
|
||||
|
|
|
@ -86,7 +86,7 @@ func TestNmqClient_CreateProducer(t *testing.T) {
|
|||
|
||||
topic := "TestNmqClient_CreateProducer"
|
||||
proOpts := common.ProducerOptions{Topic: topic}
|
||||
producer, err := client.CreateProducer(proOpts)
|
||||
producer, err := client.CreateProducer(context.TODO(), proOpts)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, producer)
|
||||
defer producer.Close()
|
||||
|
@ -102,7 +102,7 @@ func TestNmqClient_CreateProducer(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
invalidOpts := common.ProducerOptions{Topic: ""}
|
||||
producer, e := client.CreateProducer(invalidOpts)
|
||||
producer, e := client.CreateProducer(context.TODO(), invalidOpts)
|
||||
assert.Nil(t, producer)
|
||||
assert.Error(t, e)
|
||||
}
|
||||
|
@ -114,7 +114,7 @@ func TestNmqClient_GetLatestMsg(t *testing.T) {
|
|||
|
||||
topic := fmt.Sprintf("t2GetLatestMsg-%d", rand.Int())
|
||||
proOpts := common.ProducerOptions{Topic: topic}
|
||||
producer, err := client.CreateProducer(proOpts)
|
||||
producer, err := client.CreateProducer(context.TODO(), proOpts)
|
||||
assert.NoError(t, err)
|
||||
defer producer.Close()
|
||||
|
||||
|
@ -135,7 +135,7 @@ func TestNmqClient_GetLatestMsg(t *testing.T) {
|
|||
BufSize: 1024,
|
||||
}
|
||||
|
||||
consumer, err := client.Subscribe(consumerOpts)
|
||||
consumer, err := client.Subscribe(context.TODO(), consumerOpts)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectLastMsg, err := consumer.GetLatestMsgID()
|
||||
|
@ -166,13 +166,13 @@ func TestNmqClient_IllegalSubscribe(t *testing.T) {
|
|||
assert.NotNil(t, client)
|
||||
defer client.Close()
|
||||
|
||||
sub, err := client.Subscribe(mqwrapper.ConsumerOptions{
|
||||
sub, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: "",
|
||||
})
|
||||
assert.Nil(t, sub)
|
||||
assert.Error(t, err)
|
||||
|
||||
sub, err = client.Subscribe(mqwrapper.ConsumerOptions{
|
||||
sub, err = client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: "123",
|
||||
SubscriptionName: "",
|
||||
})
|
||||
|
@ -188,7 +188,7 @@ func TestNmqClient_Subscribe(t *testing.T) {
|
|||
|
||||
topic := "TestNmqClient_Subscribe"
|
||||
proOpts := common.ProducerOptions{Topic: topic}
|
||||
producer, err := client.CreateProducer(proOpts)
|
||||
producer, err := client.CreateProducer(context.TODO(), proOpts)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, producer)
|
||||
defer producer.Close()
|
||||
|
@ -201,12 +201,12 @@ func TestNmqClient_Subscribe(t *testing.T) {
|
|||
BufSize: 1024,
|
||||
}
|
||||
|
||||
consumer, err := client.Subscribe(consumerOpts)
|
||||
consumer, err := client.Subscribe(context.TODO(), consumerOpts)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, consumer)
|
||||
|
||||
consumerOpts.Topic = topic
|
||||
consumer, err = client.Subscribe(consumerOpts)
|
||||
consumer, err = client.Subscribe(context.TODO(), consumerOpts)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, consumer)
|
||||
defer consumer.Close()
|
||||
|
|
|
@ -36,10 +36,10 @@ func TestNatsConsumer_Subscription(t *testing.T) {
|
|||
|
||||
topic := t.Name()
|
||||
proOpts := common.ProducerOptions{Topic: topic}
|
||||
_, err = client.CreateProducer(proOpts)
|
||||
_, err = client.CreateProducer(context.TODO(), proOpts)
|
||||
assert.NoError(t, err)
|
||||
|
||||
consumer, err := client.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: topic,
|
||||
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
|
||||
|
@ -69,7 +69,7 @@ func Test_BadLatestMessageID(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
consumer, err := client.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: topic,
|
||||
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
|
||||
|
@ -88,10 +88,10 @@ func TestComsumeMessage(t *testing.T) {
|
|||
defer client.Close()
|
||||
|
||||
topic := t.Name()
|
||||
p, err := client.CreateProducer(common.ProducerOptions{Topic: topic})
|
||||
p, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
|
||||
assert.NoError(t, err)
|
||||
|
||||
c, err := client.Subscribe(mqwrapper.ConsumerOptions{
|
||||
c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: topic,
|
||||
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
|
||||
|
@ -149,7 +149,7 @@ func TestNatsConsumer_Close(t *testing.T) {
|
|||
defer client.Close()
|
||||
|
||||
topic := t.Name()
|
||||
c, err := client.Subscribe(mqwrapper.ConsumerOptions{
|
||||
c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: topic,
|
||||
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
|
||||
|
@ -177,7 +177,7 @@ func TestNatsClientErrorOnUnsubscribeTwice(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
consumer, err := client.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: topic,
|
||||
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
|
||||
|
@ -199,7 +199,7 @@ func TestCheckTopicValid(t *testing.T) {
|
|||
defer client.Close()
|
||||
|
||||
topic := t.Name()
|
||||
consumer, err := client.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: topic,
|
||||
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
|
||||
|
@ -220,7 +220,7 @@ func TestCheckTopicValid(t *testing.T) {
|
|||
assert.Error(t, err)
|
||||
|
||||
// not empty topic can pass
|
||||
pub, err := client.CreateProducer(common.ProducerOptions{
|
||||
pub, err := client.CreateProducer(context.TODO(), common.ProducerOptions{
|
||||
Topic: topic,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
@ -240,7 +240,7 @@ func TestCheckTopicValid(t *testing.T) {
|
|||
func newTestConsumer(t *testing.T, topic string, position common.SubscriptionInitialPosition) (mqwrapper.Consumer, error) {
|
||||
client, err := createNmqClient()
|
||||
assert.NoError(t, err)
|
||||
return client.Subscribe(mqwrapper.ConsumerOptions{
|
||||
return client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: topic,
|
||||
SubscriptionInitialPosition: position,
|
||||
|
@ -251,7 +251,7 @@ func newTestConsumer(t *testing.T, topic string, position common.SubscriptionIni
|
|||
func newProducer(t *testing.T, topic string) (*nmqClient, mqwrapper.Producer) {
|
||||
client, err := createNmqClient()
|
||||
assert.NoError(t, err)
|
||||
producer, err := client.CreateProducer(common.ProducerOptions{Topic: topic})
|
||||
producer, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
|
||||
assert.NoError(t, err)
|
||||
return client, producer
|
||||
}
|
||||
|
@ -272,10 +272,10 @@ func TestNmqConsumer_GetLatestMsgID(t *testing.T) {
|
|||
defer client.Close()
|
||||
|
||||
topic := t.Name()
|
||||
p, err := client.CreateProducer(common.ProducerOptions{Topic: topic})
|
||||
p, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
|
||||
assert.NoError(t, err)
|
||||
|
||||
c, err := client.Subscribe(mqwrapper.ConsumerOptions{
|
||||
c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: topic,
|
||||
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
|
||||
|
@ -301,13 +301,13 @@ func TestNmqConsumer_ConsumeFromLatest(t *testing.T) {
|
|||
defer client.Close()
|
||||
|
||||
topic := t.Name()
|
||||
p, err := client.CreateProducer(common.ProducerOptions{Topic: topic})
|
||||
p, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
|
||||
assert.NoError(t, err)
|
||||
|
||||
msgs := []string{"111", "222", "333"}
|
||||
process(t, msgs, p)
|
||||
|
||||
c, err := client.Subscribe(mqwrapper.ConsumerOptions{
|
||||
c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: topic,
|
||||
SubscriptionInitialPosition: common.SubscriptionPositionLatest,
|
||||
|
@ -331,13 +331,13 @@ func TestNmqConsumer_ConsumeFromEarliest(t *testing.T) {
|
|||
defer client.Close()
|
||||
|
||||
topic := t.Name()
|
||||
p, err := client.CreateProducer(common.ProducerOptions{Topic: topic})
|
||||
p, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
|
||||
assert.NoError(t, err)
|
||||
|
||||
msgs := []string{"111", "222"}
|
||||
process(t, msgs, p)
|
||||
|
||||
c, err := client.Subscribe(mqwrapper.ConsumerOptions{
|
||||
c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: topic,
|
||||
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
|
||||
|
@ -354,7 +354,7 @@ func TestNmqConsumer_ConsumeFromEarliest(t *testing.T) {
|
|||
msg = <-c.Chan()
|
||||
assert.Equal(t, "222", string(msg.Payload()))
|
||||
|
||||
c2, err := client.Subscribe(mqwrapper.ConsumerOptions{
|
||||
c2, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: topic,
|
||||
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// "License"); you may not use this file exceapt in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
|
|
@ -33,7 +33,7 @@ func TestNatsMQProducer(t *testing.T) {
|
|||
pOpts := common.ProducerOptions{Topic: topic}
|
||||
|
||||
// Check Topic()
|
||||
p, err := c.CreateProducer(pOpts)
|
||||
p, err := c.CreateProducer(context.TODO(), pOpts)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, p.(*nmqProducer).Topic(), topic)
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package pulsar
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -66,7 +67,7 @@ func NewClient(tenant string, namespace string, opts pulsar.ClientOptions) (*pul
|
|||
}
|
||||
|
||||
// CreateProducer create a pulsar producer from options
|
||||
func (pc *pulsarClient) CreateProducer(options mqcommon.ProducerOptions) (mqwrapper.Producer, error) {
|
||||
func (pc *pulsarClient) CreateProducer(ctx context.Context, options mqcommon.ProducerOptions) (mqwrapper.Producer, error) {
|
||||
start := timerecord.NewTimeRecorder("create producer")
|
||||
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc()
|
||||
|
||||
|
@ -102,7 +103,7 @@ func (pc *pulsarClient) CreateProducer(options mqcommon.ProducerOptions) (mqwrap
|
|||
}
|
||||
|
||||
// Subscribe creates a pulsar consumer instance and subscribe a topic
|
||||
func (pc *pulsarClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
|
||||
func (pc *pulsarClient) Subscribe(ctx context.Context, options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
|
||||
start := timerecord.NewTimeRecorder("create consumer")
|
||||
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc()
|
||||
|
||||
|
|
|
@ -78,7 +78,7 @@ func BytesToInt(b []byte) int {
|
|||
}
|
||||
|
||||
func Produce(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, arr []int) {
|
||||
producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: topic})
|
||||
producer, err := pc.CreateProducer(ctx, mqcommon.ProducerOptions{Topic: topic})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, producer)
|
||||
|
||||
|
@ -110,7 +110,7 @@ func VerifyMessage(t *testing.T, msg mqcommon.Message) {
|
|||
|
||||
// Consume1 will consume random messages and record the last MessageID it received
|
||||
func Consume1(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, subName string, c chan mqcommon.MessageID, total *int) {
|
||||
consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := pc.Subscribe(ctx, mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: subName,
|
||||
BufSize: 1024,
|
||||
|
@ -147,7 +147,7 @@ func Consume1(ctx context.Context, t *testing.T, pc *pulsarClient, topic string,
|
|||
|
||||
// Consume2 will consume messages from specified MessageID
|
||||
func Consume2(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, subName string, msgID mqcommon.MessageID, total *int) {
|
||||
consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := pc.Subscribe(ctx, mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: subName,
|
||||
BufSize: 1024,
|
||||
|
@ -181,7 +181,7 @@ func Consume2(ctx context.Context, t *testing.T, pc *pulsarClient, topic string,
|
|||
}
|
||||
|
||||
func Consume3(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, subName string, total *int) {
|
||||
consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := pc.Subscribe(ctx, mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: subName,
|
||||
BufSize: 1024,
|
||||
|
@ -420,7 +420,7 @@ func TestPulsarClient_SeekPosition(t *testing.T) {
|
|||
topic := fmt.Sprintf("test-topic-%d", rand.Int())
|
||||
subName := fmt.Sprintf("test-subname-%d", rand.Int())
|
||||
|
||||
producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: topic})
|
||||
producer, err := pc.CreateProducer(ctx, mqcommon.ProducerOptions{Topic: topic})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, producer)
|
||||
|
||||
|
@ -498,7 +498,7 @@ func TestPulsarClient_SeekLatest(t *testing.T) {
|
|||
topic := fmt.Sprintf("test-topic-%d", rand.Int())
|
||||
subName := fmt.Sprintf("test-subname-%d", rand.Int())
|
||||
|
||||
producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: topic})
|
||||
producer, err := pc.CreateProducer(ctx, mqcommon.ProducerOptions{Topic: topic})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, producer)
|
||||
|
||||
|
@ -671,7 +671,7 @@ func TestPulsarClient_SubscribeExclusiveFail(t *testing.T) {
|
|||
client: &mockPulsarClient{},
|
||||
}
|
||||
|
||||
_, err := pc.Subscribe(mqwrapper.ConsumerOptions{Topic: "test_topic_name"})
|
||||
_, err := pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{Topic: "test_topic_name"})
|
||||
assert.Error(t, err)
|
||||
assert.True(t, retry.IsRecoverable(err))
|
||||
})
|
||||
|
@ -686,7 +686,7 @@ func TestPulsarClient_WithTenantAndNamespace(t *testing.T) {
|
|||
pulsarAddress := getPulsarAddress()
|
||||
pc, err := NewClient(tenant, namespace, pulsar.ClientOptions{URL: pulsarAddress})
|
||||
assert.NoError(t, err)
|
||||
producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: topic})
|
||||
producer, err := pc.CreateProducer(context.TODO(), mqcommon.ProducerOptions{Topic: topic})
|
||||
defer producer.Close()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, producer)
|
||||
|
@ -695,7 +695,7 @@ func TestPulsarClient_WithTenantAndNamespace(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.Equal(t, fullTopicName, producer.(*pulsarProducer).Topic())
|
||||
|
||||
consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: subName,
|
||||
BufSize: 1024,
|
||||
|
@ -713,7 +713,7 @@ func TestPulsarCtl(t *testing.T) {
|
|||
pulsarAddress := getPulsarAddress()
|
||||
pc, err := NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
|
||||
assert.NoError(t, err)
|
||||
consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: subName,
|
||||
BufSize: 1024,
|
||||
|
@ -723,7 +723,7 @@ func TestPulsarCtl(t *testing.T) {
|
|||
assert.NotNil(t, consumer)
|
||||
defer consumer.Close()
|
||||
|
||||
_, err = pc.Subscribe(mqwrapper.ConsumerOptions{
|
||||
_, err = pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: subName,
|
||||
BufSize: 1024,
|
||||
|
@ -732,7 +732,7 @@ func TestPulsarCtl(t *testing.T) {
|
|||
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = pc.Subscribe(mqwrapper.ConsumerOptions{
|
||||
_, err = pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: subName,
|
||||
BufSize: 1024,
|
||||
|
@ -762,7 +762,7 @@ func TestPulsarCtl(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
consumer2, err := pc.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer2, err := pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: subName,
|
||||
BufSize: 1024,
|
||||
|
|
|
@ -80,9 +80,9 @@ func TestComsumeCompressedMessage(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
defer consumer.Close()
|
||||
|
||||
producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: "TestTopics"})
|
||||
producer, err := pc.CreateProducer(context.TODO(), mqcommon.ProducerOptions{Topic: "TestTopics"})
|
||||
assert.NoError(t, err)
|
||||
compressProducer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: "TestTopics", EnableCompression: true})
|
||||
compressProducer, err := pc.CreateProducer(context.TODO(), mqcommon.ProducerOptions{Topic: "TestTopics", EnableCompression: true})
|
||||
assert.NoError(t, err)
|
||||
|
||||
msg := []byte("test message")
|
||||
|
|
|
@ -34,7 +34,7 @@ func TestPulsarProducer(t *testing.T) {
|
|||
assert.NotNil(t, pc)
|
||||
|
||||
topic := "TEST"
|
||||
producer, err := pc.CreateProducer(common.ProducerOptions{Topic: topic})
|
||||
producer, err := pc.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, producer)
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ func NewClient(opts client.Options) (*rmqClient, error) {
|
|||
}
|
||||
|
||||
// CreateProducer creates a producer for rocksmq client
|
||||
func (rc *rmqClient) CreateProducer(options common.ProducerOptions) (mqwrapper.Producer, error) {
|
||||
func (rc *rmqClient) CreateProducer(ctx context.Context, options common.ProducerOptions) (mqwrapper.Producer, error) {
|
||||
start := timerecord.NewTimeRecorder("create producer")
|
||||
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc()
|
||||
|
||||
|
@ -77,7 +77,7 @@ func (rc *rmqClient) CreateProducer(options common.ProducerOptions) (mqwrapper.P
|
|||
}
|
||||
|
||||
// Subscribe subscribes a consumer in rmq client
|
||||
func (rc *rmqClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
|
||||
func (rc *rmqClient) Subscribe(ctx context.Context, options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
|
||||
start := timerecord.NewTimeRecorder("create consumer")
|
||||
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc()
|
||||
|
||||
|
|
|
@ -65,7 +65,7 @@ func TestRmqClient_CreateProducer(t *testing.T) {
|
|||
|
||||
topic := "TestRmqClient_CreateProducer"
|
||||
proOpts := common.ProducerOptions{Topic: topic}
|
||||
producer, err := client.CreateProducer(proOpts)
|
||||
producer, err := client.CreateProducer(context.TODO(), proOpts)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, producer)
|
||||
|
||||
|
@ -83,7 +83,7 @@ func TestRmqClient_CreateProducer(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
invalidOpts := common.ProducerOptions{Topic: ""}
|
||||
producer, e := client.CreateProducer(invalidOpts)
|
||||
producer, e := client.CreateProducer(context.TODO(), invalidOpts)
|
||||
assert.Nil(t, producer)
|
||||
assert.Error(t, e)
|
||||
}
|
||||
|
@ -95,7 +95,7 @@ func TestRmqClient_GetLatestMsg(t *testing.T) {
|
|||
|
||||
topic := fmt.Sprintf("t2GetLatestMsg-%d", rand.Int())
|
||||
proOpts := common.ProducerOptions{Topic: topic}
|
||||
producer, err := client.CreateProducer(proOpts)
|
||||
producer, err := client.CreateProducer(context.TODO(), proOpts)
|
||||
assert.NoError(t, err)
|
||||
defer producer.Close()
|
||||
|
||||
|
@ -116,7 +116,7 @@ func TestRmqClient_GetLatestMsg(t *testing.T) {
|
|||
BufSize: 1024,
|
||||
}
|
||||
|
||||
consumer, err := client.Subscribe(consumerOpts)
|
||||
consumer, err := client.Subscribe(context.TODO(), consumerOpts)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectLastMsg, err := consumer.GetLatestMsgID()
|
||||
|
@ -149,7 +149,7 @@ func TestRmqClient_Subscribe(t *testing.T) {
|
|||
|
||||
topic := "TestRmqClient_Subscribe"
|
||||
proOpts := common.ProducerOptions{Topic: topic}
|
||||
producer, err := client.CreateProducer(proOpts)
|
||||
producer, err := client.CreateProducer(context.TODO(), proOpts)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, producer)
|
||||
defer producer.Close()
|
||||
|
@ -161,7 +161,7 @@ func TestRmqClient_Subscribe(t *testing.T) {
|
|||
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
|
||||
BufSize: 0,
|
||||
}
|
||||
consumer, err := client.Subscribe(consumerOpts)
|
||||
consumer, err := client.Subscribe(context.TODO(), consumerOpts)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, consumer)
|
||||
|
||||
|
@ -172,12 +172,12 @@ func TestRmqClient_Subscribe(t *testing.T) {
|
|||
BufSize: 1024,
|
||||
}
|
||||
|
||||
consumer, err = client.Subscribe(consumerOpts)
|
||||
consumer, err = client.Subscribe(context.TODO(), consumerOpts)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, consumer)
|
||||
|
||||
consumerOpts.Topic = topic
|
||||
consumer, err = client.Subscribe(consumerOpts)
|
||||
consumer, err = client.Subscribe(context.TODO(), consumerOpts)
|
||||
defer consumer.Close()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, consumer)
|
||||
|
|
|
@ -55,11 +55,11 @@ type RepackFunc func(msgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, erro
|
|||
type MsgStream interface {
|
||||
Close()
|
||||
|
||||
AsProducer(channels []string)
|
||||
Produce(*MsgPack) error
|
||||
AsProducer(ctx context.Context, channels []string)
|
||||
Produce(context.Context, *MsgPack) error
|
||||
SetRepackFunc(repackFunc RepackFunc)
|
||||
GetProduceChannels() []string
|
||||
Broadcast(*MsgPack) (map[string][]MessageID, error)
|
||||
Broadcast(context.Context, *MsgPack) (map[string][]MessageID, error)
|
||||
|
||||
AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error
|
||||
Chan() <-chan *MsgPack
|
||||
|
|
|
@ -36,7 +36,7 @@ func TestPulsarMsgUtil(t *testing.T) {
|
|||
defer msgStream.Close()
|
||||
|
||||
// create a topic
|
||||
msgStream.AsProducer([]string{"test"})
|
||||
msgStream.AsProducer(ctx, []string{"test"})
|
||||
|
||||
UnsubscribeChannels(ctx, pmsFactory, "sub", []string{"test"})
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ func benchmarkProduceAndConsume(b *testing.B, mqClient mqwrapper.Client, cases [
|
|||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
p, err := mqClient.CreateProducer(common.ProducerOptions{
|
||||
p, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
|
||||
Topic: topic,
|
||||
})
|
||||
assert.NoError(b, err)
|
||||
|
@ -55,7 +55,7 @@ func benchmarkProduceAndConsume(b *testing.B, mqClient mqwrapper.Client, cases [
|
|||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
c, _ := mqClient.Subscribe(mqwrapper.ConsumerOptions{
|
||||
c, _ := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topic,
|
||||
SubscriptionName: topic,
|
||||
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
|
||||
|
|
|
@ -40,13 +40,13 @@ func testStreamOperation(t *testing.T, mqClient mqwrapper.Client) {
|
|||
func testConcurrentStream(t *testing.T, mqClient mqwrapper.Client) {
|
||||
topics := getChannel(2)
|
||||
|
||||
producer, err := mqClient.CreateProducer(common.ProducerOptions{
|
||||
producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
|
||||
Topic: topics[0],
|
||||
})
|
||||
defer producer.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topics[0],
|
||||
SubscriptionName: funcutil.RandomString(8),
|
||||
SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
|
||||
|
@ -61,7 +61,7 @@ func testConcurrentStream(t *testing.T, mqClient mqwrapper.Client) {
|
|||
func testConcurrentStreamAndSubscribeLast(t *testing.T, mqClient mqwrapper.Client) {
|
||||
topics := getChannel(2)
|
||||
|
||||
producer, err := mqClient.CreateProducer(common.ProducerOptions{
|
||||
producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
|
||||
Topic: topics[0],
|
||||
})
|
||||
defer producer.Close()
|
||||
|
@ -69,7 +69,7 @@ func testConcurrentStreamAndSubscribeLast(t *testing.T, mqClient mqwrapper.Clien
|
|||
|
||||
ids := sendMessages(context.Background(), t, producer, generateRandMessage(1024, 1000))
|
||||
|
||||
consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topics[0],
|
||||
SubscriptionName: funcutil.RandomString(8),
|
||||
SubscriptionInitialPosition: common.SubscriptionPositionLatest,
|
||||
|
@ -90,7 +90,7 @@ func testConcurrentStreamAndSubscribeLast(t *testing.T, mqClient mqwrapper.Clien
|
|||
func testConcurrentStreamAndSeekInclusive(t *testing.T, mqClient mqwrapper.Client) {
|
||||
topics := getChannel(2)
|
||||
|
||||
producer, err := mqClient.CreateProducer(common.ProducerOptions{
|
||||
producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
|
||||
Topic: topics[0],
|
||||
})
|
||||
defer producer.Close()
|
||||
|
@ -99,7 +99,7 @@ func testConcurrentStreamAndSeekInclusive(t *testing.T, mqClient mqwrapper.Clien
|
|||
cases := generateRandMessage(1024, 1000)
|
||||
ids := sendMessages(context.Background(), t, producer, cases)
|
||||
|
||||
consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topics[0],
|
||||
SubscriptionName: funcutil.RandomString(8),
|
||||
SubscriptionInitialPosition: common.SubscriptionPositionUnknown,
|
||||
|
@ -124,7 +124,7 @@ func testConcurrentStreamAndSeekInclusive(t *testing.T, mqClient mqwrapper.Clien
|
|||
func testConcurrentStreamAndSeekNoInclusive(t *testing.T, mqClient mqwrapper.Client) {
|
||||
topics := getChannel(2)
|
||||
|
||||
producer, err := mqClient.CreateProducer(common.ProducerOptions{
|
||||
producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
|
||||
Topic: topics[0],
|
||||
})
|
||||
defer producer.Close()
|
||||
|
@ -133,7 +133,7 @@ func testConcurrentStreamAndSeekNoInclusive(t *testing.T, mqClient mqwrapper.Cli
|
|||
cases := generateRandMessage(1024, 1000)
|
||||
ids := sendMessages(context.Background(), t, producer, cases)
|
||||
|
||||
consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topics[0],
|
||||
SubscriptionName: funcutil.RandomString(8),
|
||||
SubscriptionInitialPosition: common.SubscriptionPositionUnknown,
|
||||
|
@ -158,7 +158,7 @@ func testConcurrentStreamAndSeekNoInclusive(t *testing.T, mqClient mqwrapper.Cli
|
|||
func testConcurrentStreamAndSeekToLast(t *testing.T, mqClient mqwrapper.Client) {
|
||||
topics := getChannel(2)
|
||||
|
||||
producer, err := mqClient.CreateProducer(common.ProducerOptions{
|
||||
producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
|
||||
Topic: topics[0],
|
||||
})
|
||||
defer producer.Close()
|
||||
|
@ -167,7 +167,7 @@ func testConcurrentStreamAndSeekToLast(t *testing.T, mqClient mqwrapper.Client)
|
|||
cases := generateRandMessage(1024, 1000)
|
||||
sendMessages(context.Background(), t, producer, cases)
|
||||
|
||||
consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{
|
||||
consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
|
||||
Topic: topics[0],
|
||||
SubscriptionName: funcutil.RandomString(8),
|
||||
SubscriptionInitialPosition: common.SubscriptionPositionUnknown,
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
package msgstream
|
||||
|
||||
import "context"
|
||||
|
||||
type WastedMockMsgStream struct {
|
||||
MsgStream
|
||||
AsProducerFunc func(channels []string)
|
||||
|
@ -12,11 +14,11 @@ func NewWastedMockMsgStream() *WastedMockMsgStream {
|
|||
return &WastedMockMsgStream{}
|
||||
}
|
||||
|
||||
func (m WastedMockMsgStream) AsProducer(channels []string) {
|
||||
func (m WastedMockMsgStream) AsProducer(ctx context.Context, channels []string) {
|
||||
m.AsProducerFunc(channels)
|
||||
}
|
||||
|
||||
func (m WastedMockMsgStream) Broadcast(pack *MsgPack) (map[string][]MessageID, error) {
|
||||
func (m WastedMockMsgStream) Broadcast(ctx context.Context, pack *MsgPack) (map[string][]MessageID, error) {
|
||||
return m.BroadcastMarkFunc(pack)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue