diff --git a/internal/proxy/impl_test.go b/internal/proxy/impl_test.go index 947fd658fb..3dd8a7062e 100644 --- a/internal/proxy/impl_test.go +++ b/internal/proxy/impl_test.go @@ -34,7 +34,6 @@ import ( "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/sessionutil" - "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -45,11 +44,8 @@ func TestProxy_InvalidateCollectionMetaCache_remove_stream(t *testing.T) { globalMetaCache = nil defer func() { globalMetaCache = cache }() - chMgr := newMockChannelsMgr() - chMgr.removeDMLStreamFuncType = func(collectionID UniqueID) error { - log.Debug("TestProxy_InvalidateCollectionMetaCache_remove_stream, remove dml stream") - return nil - } + chMgr := NewMockChannelsMgr(t) + chMgr.EXPECT().removeDMLStream(mock.Anything).Return() node := &Proxy{chMgr: chMgr} node.stateCode.Store(commonpb.StateCode_Healthy) diff --git a/internal/proxy/mock_channels_mgr_test.go b/internal/proxy/mock_channels_mgr_test.go deleted file mode 100644 index 88bfe35236..0000000000 --- a/internal/proxy/mock_channels_mgr_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package proxy - -type ( - getVChannelsFuncType = func(collectionID UniqueID) ([]vChan, error) - removeDMLStreamFuncType = func(collectionID UniqueID) error -) - -type mockChannelsMgr struct { - channelsMgr - getChannelsFunc func(collectionID UniqueID) ([]pChan, error) - getVChannelsFuncType - removeDMLStreamFuncType -} - -func (m *mockChannelsMgr) getChannels(collectionID UniqueID) ([]pChan, error) { - if m.getChannelsFunc != nil { - return m.getChannelsFunc(collectionID) - } - return nil, nil -} - -func (m *mockChannelsMgr) getVChannels(collectionID UniqueID) ([]vChan, error) { - if m.getVChannelsFuncType != nil { - return m.getVChannelsFuncType(collectionID) - } - return nil, nil -} - -func (m *mockChannelsMgr) removeDMLStream(collectionID UniqueID) { - if m.removeDMLStreamFuncType != nil { - m.removeDMLStreamFuncType(collectionID) - } -} - -func newMockChannelsMgr() *mockChannelsMgr { - return &mockChannelsMgr{} -} diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 97b44e3711..be3b151f30 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -4113,7 +4113,7 @@ func TestProxy_Import(t *testing.T) { defer wg.Done() proxy := &Proxy{} proxy.UpdateStateCode(commonpb.StateCode_Healthy) - chMgr := newMockChannelsMgr() + chMgr := NewMockChannelsMgr(t) proxy.chMgr = chMgr rc := newMockRootCoord() rc.ImportFunc = func(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { @@ -4133,7 +4133,7 @@ func TestProxy_Import(t *testing.T) { defer wg.Done() proxy := &Proxy{} proxy.UpdateStateCode(commonpb.StateCode_Healthy) - chMgr := newMockChannelsMgr() + chMgr := NewMockChannelsMgr(t) proxy.chMgr = chMgr rc := newMockRootCoord() rc.ImportFunc = func(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { @@ -4153,7 +4153,7 @@ func TestProxy_Import(t *testing.T) { defer wg.Done() proxy := &Proxy{} proxy.UpdateStateCode(commonpb.StateCode_Healthy) - chMgr := newMockChannelsMgr() + chMgr := NewMockChannelsMgr(t) proxy.chMgr = chMgr rc := newMockRootCoord() rc.ImportFunc = func(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { diff --git a/internal/proxy/task_delete_test.go b/internal/proxy/task_delete_test.go index 2862a0500e..2a2edd07b5 100644 --- a/internal/proxy/task_delete_test.go +++ b/internal/proxy/task_delete_test.go @@ -2,7 +2,6 @@ package proxy import ( "context" - "fmt" "testing" "github.com/cockroachdb/errors" @@ -81,10 +80,8 @@ func TestDeleteTask_GetChannels(t *testing.T) { mock.AnythingOfType("string"), ).Return(collectionID, nil) globalMetaCache = cache - chMgr := newMockChannelsMgr() - chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { - return channels, nil - } + chMgr := NewMockChannelsMgr(t) + chMgr.EXPECT().getChannels(mock.Anything).Return(channels, nil) dt := deleteTask{ ctx: context.Background(), req: &milvuspb.DeleteRequest{ @@ -97,13 +94,6 @@ func TestDeleteTask_GetChannels(t *testing.T) { resChannels := dt.getChannels() assert.ElementsMatch(t, channels, resChannels) assert.ElementsMatch(t, channels, dt.pChannels) - - chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { - return nil, fmt.Errorf("mock err") - } - // get channels again, should return task's pChannels, so getChannelsFunc should not invoke again - resChannels = dt.getChannels() - assert.ElementsMatch(t, channels, resChannels) } func TestDeleteTask_PreExecute(t *testing.T) { diff --git a/internal/proxy/task_insert_test.go b/internal/proxy/task_insert_test.go index a5e95a71b4..ddc9390ea5 100644 --- a/internal/proxy/task_insert_test.go +++ b/internal/proxy/task_insert_test.go @@ -2,7 +2,6 @@ package proxy import ( "context" - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -236,10 +235,8 @@ func TestInsertTask(t *testing.T) { mock.AnythingOfType("string"), ).Return(collectionID, nil) globalMetaCache = cache - chMgr := newMockChannelsMgr() - chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { - return channels, nil - } + chMgr := NewMockChannelsMgr(t) + chMgr.EXPECT().getChannels(mock.Anything).Return(channels, nil) it := insertTask{ ctx: context.Background(), insertMsg: &msgstream.InsertMsg{ @@ -254,12 +251,5 @@ func TestInsertTask(t *testing.T) { resChannels := it.getChannels() assert.ElementsMatch(t, channels, resChannels) assert.ElementsMatch(t, channels, it.pChannels) - - chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { - return nil, fmt.Errorf("mock err") - } - // get channels again, should return task's pChannels, so getChannelsFunc should not invoke again - resChannels = it.getChannels() - assert.ElementsMatch(t, channels, resChannels) }) } diff --git a/internal/proxy/task_scheduler_test.go b/internal/proxy/task_scheduler_test.go index 13db44ba11..2a04ea3199 100644 --- a/internal/proxy/task_scheduler_test.go +++ b/internal/proxy/task_scheduler_test.go @@ -575,10 +575,8 @@ func TestTaskScheduler_concurrentPushAndPop(t *testing.T) { run := func(wg *sync.WaitGroup) { defer wg.Done() - chMgr := newMockChannelsMgr() - chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { - return channels, nil - } + chMgr := NewMockChannelsMgr(t) + chMgr.EXPECT().getChannels(mock.Anything).Return(channels, nil) it := &insertTask{ ctx: context.Background(), insertMsg: &msgstream.InsertMsg{ @@ -593,9 +591,7 @@ func TestTaskScheduler_concurrentPushAndPop(t *testing.T) { assert.NoError(t, err) task := scheduler.scheduleDmTask() scheduler.dmQueue.AddActiveTask(task) - chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { - return nil, fmt.Errorf("mock err") - } + chMgr.EXPECT().getChannels(mock.Anything).Return(nil, fmt.Errorf("mock err")) scheduler.dmQueue.PopActiveTask(task.ID()) // assert no panic } diff --git a/internal/proxy/task_upsert_test.go b/internal/proxy/task_upsert_test.go index 2d41e2c35a..dd6cfda691 100644 --- a/internal/proxy/task_upsert_test.go +++ b/internal/proxy/task_upsert_test.go @@ -17,7 +17,6 @@ package proxy import ( "context" - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -307,10 +306,8 @@ func TestUpsertTask(t *testing.T) { ).Return(collectionID, nil) globalMetaCache = cache - chMgr := newMockChannelsMgr() - chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { - return channels, nil - } + chMgr := NewMockChannelsMgr(t) + chMgr.EXPECT().getChannels(mock.Anything).Return(channels, nil) ut := upsertTask{ ctx: context.Background(), req: &milvuspb.UpsertRequest{ @@ -323,12 +320,5 @@ func TestUpsertTask(t *testing.T) { resChannels := ut.getChannels() assert.ElementsMatch(t, channels, resChannels) assert.ElementsMatch(t, channels, ut.pChannels) - - chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { - return nil, fmt.Errorf("mock err") - } - // get channels again, should return task's pChannels, so getChannelsFunc should not invoke again - resChannels = ut.getChannels() - assert.ElementsMatch(t, channels, resChannels) }) }