Add broker for datanode grpc operations (#27631)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/27649/head
congqixia 2023-10-11 17:03:34 +08:00 committed by GitHub
parent 722e3db6b8
commit cbb350c552
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 1392 additions and 0 deletions

View File

@ -422,6 +422,7 @@ generate-mockery-datacoord: getdeps
generate-mockery-datanode: getdeps
$(INSTALL_PATH)/mockery --name=Allocator --dir=$(PWD)/internal/datanode/allocator --output=$(PWD)/internal/datanode/allocator --filename=mock_allocator.go --with-expecter --structname=MockAllocator --outpkg=allocator --inpackage
$(INSTALL_PATH)/mockery --name=Broker --dir=$(PWD)/internal/datanode/broker --output=$(PWD)/internal/datanode/broker/ --filename=mock_broker.go --with-expecter --structname=MockBroker --outpkg=broker --inpackage
generate-mockery-metastore: getdeps
$(INSTALL_PATH)/mockery --name=RootCoordCatalog --dir=$(PWD)/internal/metastore --output=$(PWD)/internal/metastore/mocks --filename=mock_rootcoord_catalog.go --with-expecter --structname=RootCoordCatalog --outpkg=mocks

View File

@ -0,0 +1,52 @@
package broker
import (
"context"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// Broker is the interface for datanode to interact with other components.
type Broker interface {
RootCoord
DataCoord
}
type coordBroker struct {
*rootCoordBroker
*dataCoordBroker
}
func NewCoordBroker(rc types.RootCoordClient, dc types.DataCoordClient) Broker {
return &coordBroker{
rootCoordBroker: &rootCoordBroker{
client: rc,
},
dataCoordBroker: &dataCoordBroker{
client: dc,
},
}
}
// RootCoord is the interface wraps `RootCoord` grpc call
type RootCoord interface {
DescribeCollection(ctx context.Context, collectionID typeutil.UniqueID, ts typeutil.Timestamp) (*milvuspb.DescribeCollectionResponse, error)
ShowPartitions(ctx context.Context, dbName, collectionName string) (map[string]int64, error)
ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) error
AllocTimestamp(ctx context.Context, num uint32) (ts uint64, count uint32, err error)
}
// DataCoord is the interface wraps `DataCoord` grpc call
type DataCoord interface {
AssignSegmentID(ctx context.Context, reqs ...*datapb.SegmentIDRequest) ([]typeutil.UniqueID, error)
ReportTimeTick(ctx context.Context, msgs []*msgpb.DataNodeTtMsg) error
GetSegmentInfo(ctx context.Context, segmentIDs []int64) ([]*datapb.SegmentInfo, error)
UpdateChannelCheckpoint(ctx context.Context, channelName string, cp *msgpb.MsgPosition) error
SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) error
DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) error
}

View File

@ -0,0 +1,133 @@
package broker
import (
"context"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type dataCoordBroker struct {
client types.DataCoordClient
}
func (dc *dataCoordBroker) AssignSegmentID(ctx context.Context, reqs ...*datapb.SegmentIDRequest) ([]typeutil.UniqueID, error) {
req := &datapb.AssignSegmentIDRequest{
NodeID: paramtable.GetNodeID(),
PeerRole: typeutil.ProxyRole,
SegmentIDRequests: reqs,
}
resp, err := dc.client.AssignSegmentID(ctx, req)
if err := merr.CheckRPCCall(resp, err); err != nil {
log.Warn("failed to call datacoord AssignSegmentID", zap.Error(err))
return nil, err
}
return lo.Map(resp.GetSegIDAssignments(), func(result *datapb.SegmentIDAssignment, _ int) typeutil.UniqueID {
return result.GetSegID()
}), nil
}
func (dc *dataCoordBroker) ReportTimeTick(ctx context.Context, msgs []*msgpb.DataNodeTtMsg) error {
log := log.Ctx(ctx)
req := &datapb.ReportDataNodeTtMsgsRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_DataNodeTt),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
Msgs: msgs,
}
resp, err := dc.client.ReportDataNodeTtMsgs(ctx, req)
if err := merr.CheckRPCCall(resp, err); err != nil {
log.Warn("failed to report datanodeTtMsgs", zap.Error(err))
return err
}
return nil
}
func (dc *dataCoordBroker) GetSegmentInfo(ctx context.Context, segmentIDs []int64) ([]*datapb.SegmentInfo, error) {
log := log.Ctx(ctx).With(
zap.Int64s("segmentIDs", segmentIDs),
)
infoResp, err := dc.client.GetSegmentInfo(ctx, &datapb.GetSegmentInfoRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_SegmentInfo),
commonpbutil.WithMsgID(0),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
SegmentIDs: segmentIDs,
IncludeUnHealthy: true,
})
if err := merr.CheckRPCCall(infoResp, err); err != nil {
log.Warn("Fail to get SegmentInfo by ids from datacoord", zap.Error(err))
return nil, err
}
return infoResp.Infos, nil
}
func (dc *dataCoordBroker) UpdateChannelCheckpoint(ctx context.Context, channelName string, cp *msgpb.MsgPosition) error {
channelCPTs, _ := tsoutil.ParseTS(cp.GetTimestamp())
log := log.Ctx(ctx).With(
zap.String("channelName", channelName),
zap.Time("channelCheckpointTime", channelCPTs),
)
req := &datapb.UpdateChannelCheckpointRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
VChannel: channelName,
Position: cp,
}
resp, err := dc.client.UpdateChannelCheckpoint(ctx, req)
if err := merr.CheckRPCCall(resp, err); err != nil {
log.Warn("failed to update channel checkpoint", zap.Error(err))
return err
}
return nil
}
func (dc *dataCoordBroker) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) error {
log := log.Ctx(ctx)
resp, err := dc.client.SaveBinlogPaths(ctx, req)
if err := merr.CheckRPCCall(resp, err); err != nil {
log.Warn("failed to SaveBinlogPaths", zap.Error(err))
return err
}
return nil
}
func (dc *dataCoordBroker) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) error {
log := log.Ctx(ctx)
resp, err := dc.client.DropVirtualChannel(ctx, req)
if err := merr.CheckRPCCall(resp, err); err != nil {
if resp.GetStatus().GetErrorCode() == commonpb.ErrorCode_MetaFailed {
err = merr.WrapErrChannelNotFound(req.GetChannelName())
}
log.Warn("failed to SaveBinlogPaths", zap.Error(err))
return err
}
return nil
}

View File

@ -0,0 +1,296 @@
package broker
import (
"context"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
)
type dataCoordSuite struct {
suite.Suite
dc *mocks.MockDataCoordClient
broker Broker
}
func (s *dataCoordSuite) SetupSuite() {
paramtable.Init()
}
func (s *dataCoordSuite) SetupTest() {
s.dc = mocks.NewMockDataCoordClient(s.T())
s.broker = NewCoordBroker(nil, s.dc)
}
func (s *dataCoordSuite) resetMock() {
s.dc.AssertExpectations(s.T())
s.dc.ExpectedCalls = nil
}
func (s *dataCoordSuite) TestAssignSegmentID() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
reqs := []*datapb.SegmentIDRequest{
{CollectionID: 100, Count: 1000},
{CollectionID: 100, Count: 2000},
}
s.Run("normal_case", func() {
s.dc.EXPECT().AssignSegmentID(mock.Anything, mock.Anything).
Return(&datapb.AssignSegmentIDResponse{
Status: merr.Status(nil),
SegIDAssignments: lo.Map(reqs, func(req *datapb.SegmentIDRequest, _ int) *datapb.SegmentIDAssignment {
return &datapb.SegmentIDAssignment{
Status: merr.Status(nil),
SegID: 10001,
Count: req.GetCount(),
}
}),
}, nil)
segmentIDs, err := s.broker.AssignSegmentID(ctx, reqs...)
s.NoError(err)
s.Equal(len(segmentIDs), len(reqs))
s.resetMock()
})
s.Run("datacoord_return_error", func() {
s.dc.EXPECT().AssignSegmentID(mock.Anything, mock.Anything).
Return(nil, errors.New("mock"))
_, err := s.broker.AssignSegmentID(ctx, reqs...)
s.Error(err)
s.resetMock()
})
s.Run("datacoord_return_failure_status", func() {
s.dc.EXPECT().AssignSegmentID(mock.Anything, mock.Anything).
Return(&datapb.AssignSegmentIDResponse{
Status: merr.Status(errors.New("mock")),
}, nil)
_, err := s.broker.AssignSegmentID(ctx, reqs...)
s.Error(err)
s.resetMock()
})
}
func (s *dataCoordSuite) TestReportTimeTick() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
msgs := []*msgpb.DataNodeTtMsg{
{Timestamp: 1000, ChannelName: "dml_0"},
{Timestamp: 2000, ChannelName: "dml_1"},
}
s.Run("normal_case", func() {
s.dc.EXPECT().ReportDataNodeTtMsgs(mock.Anything, mock.Anything).
Run(func(_ context.Context, req *datapb.ReportDataNodeTtMsgsRequest, _ ...grpc.CallOption) {
s.Equal(msgs, req.GetMsgs())
}).
Return(merr.Status(nil), nil)
err := s.broker.ReportTimeTick(ctx, msgs)
s.NoError(err)
s.resetMock()
})
s.Run("datacoord_return_error", func() {
s.dc.EXPECT().ReportDataNodeTtMsgs(mock.Anything, mock.Anything).
Return(merr.Status(errors.New("mock")), nil)
err := s.broker.ReportTimeTick(ctx, msgs)
s.Error(err)
s.resetMock()
})
}
func (s *dataCoordSuite) TestGetSegmentInfo() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
segmentIDs := []int64{1, 2, 3}
s.Run("normal_case", func() {
s.dc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).
Run(func(_ context.Context, req *datapb.GetSegmentInfoRequest, _ ...grpc.CallOption) {
s.ElementsMatch(segmentIDs, req.GetSegmentIDs())
s.True(req.GetIncludeUnHealthy())
}).
Return(&datapb.GetSegmentInfoResponse{
Status: merr.Status(nil),
Infos: lo.Map(segmentIDs, func(id int64, _ int) *datapb.SegmentInfo {
return &datapb.SegmentInfo{ID: id}
}),
}, nil)
infos, err := s.broker.GetSegmentInfo(ctx, segmentIDs)
s.NoError(err)
s.ElementsMatch(segmentIDs, lo.Map(infos, func(info *datapb.SegmentInfo, _ int) int64 { return info.GetID() }))
s.resetMock()
})
s.Run("datacoord_return_error", func() {
s.dc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).
Return(nil, errors.New("mock"))
_, err := s.broker.GetSegmentInfo(ctx, segmentIDs)
s.Error(err)
s.resetMock()
})
s.Run("datacoord_return_failure_status", func() {
s.dc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).
Return(&datapb.GetSegmentInfoResponse{
Status: merr.Status(errors.New("mock")),
}, nil)
_, err := s.broker.GetSegmentInfo(ctx, segmentIDs)
s.Error(err)
s.resetMock()
})
}
func (s *dataCoordSuite) TestUpdateChannelCheckpoint() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
channelName := "dml_0"
checkpoint := &msgpb.MsgPosition{
ChannelName: channelName,
MsgID: []byte{1, 2, 3},
Timestamp: tsoutil.ComposeTSByTime(time.Now(), 0),
}
s.Run("normal_case", func() {
s.dc.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything).
Run(func(_ context.Context, req *datapb.UpdateChannelCheckpointRequest, _ ...grpc.CallOption) {
s.Equal(channelName, req.GetVChannel())
cp := req.GetPosition()
s.Equal(checkpoint.MsgID, cp.GetMsgID())
s.Equal(checkpoint.ChannelName, cp.GetChannelName())
s.Equal(checkpoint.Timestamp, cp.GetTimestamp())
}).
Return(merr.Status(nil), nil)
err := s.broker.UpdateChannelCheckpoint(ctx, channelName, checkpoint)
s.NoError(err)
s.resetMock()
})
s.Run("datacoord_return_error", func() {
s.dc.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything).
Return(nil, errors.New("mock"))
err := s.broker.UpdateChannelCheckpoint(ctx, channelName, checkpoint)
s.Error(err)
s.resetMock()
})
s.Run("datacoord_return_failure_status", func() {
s.dc.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything).
Return(merr.Status(errors.New("mock")), nil)
err := s.broker.UpdateChannelCheckpoint(ctx, channelName, checkpoint)
s.Error(err)
s.resetMock()
})
}
func (s *dataCoordSuite) TestSaveBinlogPaths() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
req := &datapb.SaveBinlogPathsRequest{
Channel: "dml_0",
}
s.Run("normal_case", func() {
s.dc.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).
Run(func(_ context.Context, req *datapb.SaveBinlogPathsRequest, _ ...grpc.CallOption) {
s.Equal("dml_0", req.GetChannel())
}).
Return(merr.Status(nil), nil)
err := s.broker.SaveBinlogPaths(ctx, req)
s.NoError(err)
s.resetMock()
})
s.Run("datacoord_return_error", func() {
s.dc.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).
Return(nil, errors.New("mock"))
err := s.broker.SaveBinlogPaths(ctx, req)
s.Error(err)
s.resetMock()
})
s.Run("datacoord_return_failure_status", func() {
s.dc.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).
Return(merr.Status(errors.New("mock")), nil)
err := s.broker.SaveBinlogPaths(ctx, req)
s.Error(err)
s.resetMock()
})
}
func (s *dataCoordSuite) TestDropVirtualChannel() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
req := &datapb.DropVirtualChannelRequest{
ChannelName: "dml_0",
}
s.Run("normal_case", func() {
s.dc.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything).
Run(func(_ context.Context, req *datapb.DropVirtualChannelRequest, _ ...grpc.CallOption) {
s.Equal("dml_0", req.GetChannelName())
}).
Return(&datapb.DropVirtualChannelResponse{Status: merr.Status(nil)}, nil)
err := s.broker.DropVirtualChannel(ctx, req)
s.NoError(err)
s.resetMock()
})
s.Run("datacoord_return_error", func() {
s.dc.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything).
Return(nil, errors.New("mock"))
err := s.broker.DropVirtualChannel(ctx, req)
s.Error(err)
s.resetMock()
})
s.Run("datacoord_return_failure_status", func() {
s.dc.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything).
Return(&datapb.DropVirtualChannelResponse{Status: merr.Status(errors.New("mock"))}, nil)
err := s.broker.DropVirtualChannel(ctx, req)
s.Error(err)
s.resetMock()
})
s.Run("datacoord_return_legacy_MetaFailed", func() {
s.dc.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything).
Return(&datapb.DropVirtualChannelResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_MetaFailed}}, nil)
err := s.broker.DropVirtualChannel(ctx, req)
s.Error(err)
s.ErrorIs(err, merr.ErrChannelNotFound)
s.resetMock()
})
}
func TestDataCoordBroker(t *testing.T) {
suite.Run(t, new(dataCoordSuite))
}

View File

@ -0,0 +1,555 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.
package broker
import (
context "context"
milvuspb "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
datapb "github.com/milvus-io/milvus/internal/proto/datapb"
mock "github.com/stretchr/testify/mock"
msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
rootcoordpb "github.com/milvus-io/milvus/internal/proto/rootcoordpb"
)
// MockBroker is an autogenerated mock type for the Broker type
type MockBroker struct {
mock.Mock
}
type MockBroker_Expecter struct {
mock *mock.Mock
}
func (_m *MockBroker) EXPECT() *MockBroker_Expecter {
return &MockBroker_Expecter{mock: &_m.Mock}
}
// AllocTimestamp provides a mock function with given fields: ctx, num
func (_m *MockBroker) AllocTimestamp(ctx context.Context, num uint32) (uint64, uint32, error) {
ret := _m.Called(ctx, num)
var r0 uint64
var r1 uint32
var r2 error
if rf, ok := ret.Get(0).(func(context.Context, uint32) (uint64, uint32, error)); ok {
return rf(ctx, num)
}
if rf, ok := ret.Get(0).(func(context.Context, uint32) uint64); ok {
r0 = rf(ctx, num)
} else {
r0 = ret.Get(0).(uint64)
}
if rf, ok := ret.Get(1).(func(context.Context, uint32) uint32); ok {
r1 = rf(ctx, num)
} else {
r1 = ret.Get(1).(uint32)
}
if rf, ok := ret.Get(2).(func(context.Context, uint32) error); ok {
r2 = rf(ctx, num)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// MockBroker_AllocTimestamp_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllocTimestamp'
type MockBroker_AllocTimestamp_Call struct {
*mock.Call
}
// AllocTimestamp is a helper method to define mock.On call
// - ctx context.Context
// - num uint32
func (_e *MockBroker_Expecter) AllocTimestamp(ctx interface{}, num interface{}) *MockBroker_AllocTimestamp_Call {
return &MockBroker_AllocTimestamp_Call{Call: _e.mock.On("AllocTimestamp", ctx, num)}
}
func (_c *MockBroker_AllocTimestamp_Call) Run(run func(ctx context.Context, num uint32)) *MockBroker_AllocTimestamp_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(uint32))
})
return _c
}
func (_c *MockBroker_AllocTimestamp_Call) Return(ts uint64, count uint32, err error) *MockBroker_AllocTimestamp_Call {
_c.Call.Return(ts, count, err)
return _c
}
func (_c *MockBroker_AllocTimestamp_Call) RunAndReturn(run func(context.Context, uint32) (uint64, uint32, error)) *MockBroker_AllocTimestamp_Call {
_c.Call.Return(run)
return _c
}
// AssignSegmentID provides a mock function with given fields: ctx, reqs
func (_m *MockBroker) AssignSegmentID(ctx context.Context, reqs ...*datapb.SegmentIDRequest) ([]int64, error) {
_va := make([]interface{}, len(reqs))
for _i := range reqs {
_va[_i] = reqs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 []int64
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, ...*datapb.SegmentIDRequest) ([]int64, error)); ok {
return rf(ctx, reqs...)
}
if rf, ok := ret.Get(0).(func(context.Context, ...*datapb.SegmentIDRequest) []int64); ok {
r0 = rf(ctx, reqs...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int64)
}
}
if rf, ok := ret.Get(1).(func(context.Context, ...*datapb.SegmentIDRequest) error); ok {
r1 = rf(ctx, reqs...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockBroker_AssignSegmentID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AssignSegmentID'
type MockBroker_AssignSegmentID_Call struct {
*mock.Call
}
// AssignSegmentID is a helper method to define mock.On call
// - ctx context.Context
// - reqs ...*datapb.SegmentIDRequest
func (_e *MockBroker_Expecter) AssignSegmentID(ctx interface{}, reqs ...interface{}) *MockBroker_AssignSegmentID_Call {
return &MockBroker_AssignSegmentID_Call{Call: _e.mock.On("AssignSegmentID",
append([]interface{}{ctx}, reqs...)...)}
}
func (_c *MockBroker_AssignSegmentID_Call) Run(run func(ctx context.Context, reqs ...*datapb.SegmentIDRequest)) *MockBroker_AssignSegmentID_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]*datapb.SegmentIDRequest, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(*datapb.SegmentIDRequest)
}
}
run(args[0].(context.Context), variadicArgs...)
})
return _c
}
func (_c *MockBroker_AssignSegmentID_Call) Return(_a0 []int64, _a1 error) *MockBroker_AssignSegmentID_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockBroker_AssignSegmentID_Call) RunAndReturn(run func(context.Context, ...*datapb.SegmentIDRequest) ([]int64, error)) *MockBroker_AssignSegmentID_Call {
_c.Call.Return(run)
return _c
}
// DescribeCollection provides a mock function with given fields: ctx, collectionID, ts
func (_m *MockBroker) DescribeCollection(ctx context.Context, collectionID int64, ts uint64) (*milvuspb.DescribeCollectionResponse, error) {
ret := _m.Called(ctx, collectionID, ts)
var r0 *milvuspb.DescribeCollectionResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, int64, uint64) (*milvuspb.DescribeCollectionResponse, error)); ok {
return rf(ctx, collectionID, ts)
}
if rf, ok := ret.Get(0).(func(context.Context, int64, uint64) *milvuspb.DescribeCollectionResponse); ok {
r0 = rf(ctx, collectionID, ts)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*milvuspb.DescribeCollectionResponse)
}
}
if rf, ok := ret.Get(1).(func(context.Context, int64, uint64) error); ok {
r1 = rf(ctx, collectionID, ts)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockBroker_DescribeCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeCollection'
type MockBroker_DescribeCollection_Call struct {
*mock.Call
}
// DescribeCollection is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - ts uint64
func (_e *MockBroker_Expecter) DescribeCollection(ctx interface{}, collectionID interface{}, ts interface{}) *MockBroker_DescribeCollection_Call {
return &MockBroker_DescribeCollection_Call{Call: _e.mock.On("DescribeCollection", ctx, collectionID, ts)}
}
func (_c *MockBroker_DescribeCollection_Call) Run(run func(ctx context.Context, collectionID int64, ts uint64)) *MockBroker_DescribeCollection_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64), args[2].(uint64))
})
return _c
}
func (_c *MockBroker_DescribeCollection_Call) Return(_a0 *milvuspb.DescribeCollectionResponse, _a1 error) *MockBroker_DescribeCollection_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockBroker_DescribeCollection_Call) RunAndReturn(run func(context.Context, int64, uint64) (*milvuspb.DescribeCollectionResponse, error)) *MockBroker_DescribeCollection_Call {
_c.Call.Return(run)
return _c
}
// DropVirtualChannel provides a mock function with given fields: ctx, req
func (_m *MockBroker) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) error {
ret := _m.Called(ctx, req)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *datapb.DropVirtualChannelRequest) error); ok {
r0 = rf(ctx, req)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockBroker_DropVirtualChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropVirtualChannel'
type MockBroker_DropVirtualChannel_Call struct {
*mock.Call
}
// DropVirtualChannel is a helper method to define mock.On call
// - ctx context.Context
// - req *datapb.DropVirtualChannelRequest
func (_e *MockBroker_Expecter) DropVirtualChannel(ctx interface{}, req interface{}) *MockBroker_DropVirtualChannel_Call {
return &MockBroker_DropVirtualChannel_Call{Call: _e.mock.On("DropVirtualChannel", ctx, req)}
}
func (_c *MockBroker_DropVirtualChannel_Call) Run(run func(ctx context.Context, req *datapb.DropVirtualChannelRequest)) *MockBroker_DropVirtualChannel_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*datapb.DropVirtualChannelRequest))
})
return _c
}
func (_c *MockBroker_DropVirtualChannel_Call) Return(_a0 error) *MockBroker_DropVirtualChannel_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroker_DropVirtualChannel_Call) RunAndReturn(run func(context.Context, *datapb.DropVirtualChannelRequest) error) *MockBroker_DropVirtualChannel_Call {
_c.Call.Return(run)
return _c
}
// GetSegmentInfo provides a mock function with given fields: ctx, segmentIDs
func (_m *MockBroker) GetSegmentInfo(ctx context.Context, segmentIDs []int64) ([]*datapb.SegmentInfo, error) {
ret := _m.Called(ctx, segmentIDs)
var r0 []*datapb.SegmentInfo
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, []int64) ([]*datapb.SegmentInfo, error)); ok {
return rf(ctx, segmentIDs)
}
if rf, ok := ret.Get(0).(func(context.Context, []int64) []*datapb.SegmentInfo); ok {
r0 = rf(ctx, segmentIDs)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*datapb.SegmentInfo)
}
}
if rf, ok := ret.Get(1).(func(context.Context, []int64) error); ok {
r1 = rf(ctx, segmentIDs)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockBroker_GetSegmentInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSegmentInfo'
type MockBroker_GetSegmentInfo_Call struct {
*mock.Call
}
// GetSegmentInfo is a helper method to define mock.On call
// - ctx context.Context
// - segmentIDs []int64
func (_e *MockBroker_Expecter) GetSegmentInfo(ctx interface{}, segmentIDs interface{}) *MockBroker_GetSegmentInfo_Call {
return &MockBroker_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", ctx, segmentIDs)}
}
func (_c *MockBroker_GetSegmentInfo_Call) Run(run func(ctx context.Context, segmentIDs []int64)) *MockBroker_GetSegmentInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].([]int64))
})
return _c
}
func (_c *MockBroker_GetSegmentInfo_Call) Return(_a0 []*datapb.SegmentInfo, _a1 error) *MockBroker_GetSegmentInfo_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockBroker_GetSegmentInfo_Call) RunAndReturn(run func(context.Context, []int64) ([]*datapb.SegmentInfo, error)) *MockBroker_GetSegmentInfo_Call {
_c.Call.Return(run)
return _c
}
// ReportImport provides a mock function with given fields: ctx, req
func (_m *MockBroker) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) error {
ret := _m.Called(ctx, req)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.ImportResult) error); ok {
r0 = rf(ctx, req)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockBroker_ReportImport_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReportImport'
type MockBroker_ReportImport_Call struct {
*mock.Call
}
// ReportImport is a helper method to define mock.On call
// - ctx context.Context
// - req *rootcoordpb.ImportResult
func (_e *MockBroker_Expecter) ReportImport(ctx interface{}, req interface{}) *MockBroker_ReportImport_Call {
return &MockBroker_ReportImport_Call{Call: _e.mock.On("ReportImport", ctx, req)}
}
func (_c *MockBroker_ReportImport_Call) Run(run func(ctx context.Context, req *rootcoordpb.ImportResult)) *MockBroker_ReportImport_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*rootcoordpb.ImportResult))
})
return _c
}
func (_c *MockBroker_ReportImport_Call) Return(_a0 error) *MockBroker_ReportImport_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroker_ReportImport_Call) RunAndReturn(run func(context.Context, *rootcoordpb.ImportResult) error) *MockBroker_ReportImport_Call {
_c.Call.Return(run)
return _c
}
// ReportTimeTick provides a mock function with given fields: ctx, msgs
func (_m *MockBroker) ReportTimeTick(ctx context.Context, msgs []*msgpb.DataNodeTtMsg) error {
ret := _m.Called(ctx, msgs)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, []*msgpb.DataNodeTtMsg) error); ok {
r0 = rf(ctx, msgs)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockBroker_ReportTimeTick_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReportTimeTick'
type MockBroker_ReportTimeTick_Call struct {
*mock.Call
}
// ReportTimeTick is a helper method to define mock.On call
// - ctx context.Context
// - msgs []*msgpb.DataNodeTtMsg
func (_e *MockBroker_Expecter) ReportTimeTick(ctx interface{}, msgs interface{}) *MockBroker_ReportTimeTick_Call {
return &MockBroker_ReportTimeTick_Call{Call: _e.mock.On("ReportTimeTick", ctx, msgs)}
}
func (_c *MockBroker_ReportTimeTick_Call) Run(run func(ctx context.Context, msgs []*msgpb.DataNodeTtMsg)) *MockBroker_ReportTimeTick_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].([]*msgpb.DataNodeTtMsg))
})
return _c
}
func (_c *MockBroker_ReportTimeTick_Call) Return(_a0 error) *MockBroker_ReportTimeTick_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroker_ReportTimeTick_Call) RunAndReturn(run func(context.Context, []*msgpb.DataNodeTtMsg) error) *MockBroker_ReportTimeTick_Call {
_c.Call.Return(run)
return _c
}
// SaveBinlogPaths provides a mock function with given fields: ctx, req
func (_m *MockBroker) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) error {
ret := _m.Called(ctx, req)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *datapb.SaveBinlogPathsRequest) error); ok {
r0 = rf(ctx, req)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockBroker_SaveBinlogPaths_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveBinlogPaths'
type MockBroker_SaveBinlogPaths_Call struct {
*mock.Call
}
// SaveBinlogPaths is a helper method to define mock.On call
// - ctx context.Context
// - req *datapb.SaveBinlogPathsRequest
func (_e *MockBroker_Expecter) SaveBinlogPaths(ctx interface{}, req interface{}) *MockBroker_SaveBinlogPaths_Call {
return &MockBroker_SaveBinlogPaths_Call{Call: _e.mock.On("SaveBinlogPaths", ctx, req)}
}
func (_c *MockBroker_SaveBinlogPaths_Call) Run(run func(ctx context.Context, req *datapb.SaveBinlogPathsRequest)) *MockBroker_SaveBinlogPaths_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*datapb.SaveBinlogPathsRequest))
})
return _c
}
func (_c *MockBroker_SaveBinlogPaths_Call) Return(_a0 error) *MockBroker_SaveBinlogPaths_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroker_SaveBinlogPaths_Call) RunAndReturn(run func(context.Context, *datapb.SaveBinlogPathsRequest) error) *MockBroker_SaveBinlogPaths_Call {
_c.Call.Return(run)
return _c
}
// ShowPartitions provides a mock function with given fields: ctx, dbName, collectionName
func (_m *MockBroker) ShowPartitions(ctx context.Context, dbName string, collectionName string) (map[string]int64, error) {
ret := _m.Called(ctx, dbName, collectionName)
var r0 map[string]int64
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) (map[string]int64, error)); ok {
return rf(ctx, dbName, collectionName)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) map[string]int64); ok {
r0 = rf(ctx, dbName, collectionName)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string]int64)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, dbName, collectionName)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockBroker_ShowPartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowPartitions'
type MockBroker_ShowPartitions_Call struct {
*mock.Call
}
// ShowPartitions is a helper method to define mock.On call
// - ctx context.Context
// - dbName string
// - collectionName string
func (_e *MockBroker_Expecter) ShowPartitions(ctx interface{}, dbName interface{}, collectionName interface{}) *MockBroker_ShowPartitions_Call {
return &MockBroker_ShowPartitions_Call{Call: _e.mock.On("ShowPartitions", ctx, dbName, collectionName)}
}
func (_c *MockBroker_ShowPartitions_Call) Run(run func(ctx context.Context, dbName string, collectionName string)) *MockBroker_ShowPartitions_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string))
})
return _c
}
func (_c *MockBroker_ShowPartitions_Call) Return(_a0 map[string]int64, _a1 error) *MockBroker_ShowPartitions_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockBroker_ShowPartitions_Call) RunAndReturn(run func(context.Context, string, string) (map[string]int64, error)) *MockBroker_ShowPartitions_Call {
_c.Call.Return(run)
return _c
}
// UpdateChannelCheckpoint provides a mock function with given fields: ctx, channelName, cp
func (_m *MockBroker) UpdateChannelCheckpoint(ctx context.Context, channelName string, cp *msgpb.MsgPosition) error {
ret := _m.Called(ctx, channelName, cp)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, *msgpb.MsgPosition) error); ok {
r0 = rf(ctx, channelName, cp)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockBroker_UpdateChannelCheckpoint_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateChannelCheckpoint'
type MockBroker_UpdateChannelCheckpoint_Call struct {
*mock.Call
}
// UpdateChannelCheckpoint is a helper method to define mock.On call
// - ctx context.Context
// - channelName string
// - cp *msgpb.MsgPosition
func (_e *MockBroker_Expecter) UpdateChannelCheckpoint(ctx interface{}, channelName interface{}, cp interface{}) *MockBroker_UpdateChannelCheckpoint_Call {
return &MockBroker_UpdateChannelCheckpoint_Call{Call: _e.mock.On("UpdateChannelCheckpoint", ctx, channelName, cp)}
}
func (_c *MockBroker_UpdateChannelCheckpoint_Call) Run(run func(ctx context.Context, channelName string, cp *msgpb.MsgPosition)) *MockBroker_UpdateChannelCheckpoint_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(*msgpb.MsgPosition))
})
return _c
}
func (_c *MockBroker_UpdateChannelCheckpoint_Call) Return(_a0 error) *MockBroker_UpdateChannelCheckpoint_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroker_UpdateChannelCheckpoint_Call) RunAndReturn(run func(context.Context, string, *msgpb.MsgPosition) error) *MockBroker_UpdateChannelCheckpoint_Call {
_c.Call.Return(run)
return _c
}
// NewMockBroker creates a new instance of MockBroker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockBroker(t interface {
mock.TestingT
Cleanup(func())
}) *MockBroker {
mock := &MockBroker{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -0,0 +1,114 @@
package broker
import (
"context"
"fmt"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type rootCoordBroker struct {
client types.RootCoordClient
}
func (rc *rootCoordBroker) DescribeCollection(ctx context.Context, collectionID typeutil.UniqueID, timestamp typeutil.Timestamp) (*milvuspb.DescribeCollectionResponse, error) {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", collectionID),
zap.Uint64("timestamp", timestamp),
)
req := &milvuspb.DescribeCollectionRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
// please do not specify the collection name alone after database feature.
CollectionID: collectionID,
TimeStamp: timestamp,
}
resp, err := rc.client.DescribeCollectionInternal(ctx, req)
if err := merr.CheckRPCCall(resp, err); err != nil {
log.Warn("failed to DescribeCollectionInternal", zap.Error(err))
return nil, err
}
return resp, nil
}
func (rc *rootCoordBroker) ShowPartitions(ctx context.Context, dbName, collectionName string) (map[string]int64, error) {
req := &milvuspb.ShowPartitionsRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions),
),
DbName: dbName,
CollectionName: collectionName,
}
log := log.Ctx(ctx).With(
zap.String("dbName", dbName),
zap.String("collectionName", collectionName),
)
resp, err := rc.client.ShowPartitions(ctx, req)
if err := merr.CheckRPCCall(resp, err); err != nil {
log.Warn("failed to get partitions of collection", zap.Error(err))
return nil, err
}
partitionNames := resp.GetPartitionNames()
partitionIDs := resp.GetPartitionIDs()
if len(partitionNames) != len(partitionIDs) {
log.Warn("partition names and ids are unequal",
zap.Int("partitionNameNumber", len(partitionNames)),
zap.Int("partitionIDNumber", len(partitionIDs)))
return nil, fmt.Errorf("partition names and ids are unequal, number of names: %d, number of ids: %d",
len(partitionNames), len(partitionIDs))
}
partitions := make(map[string]int64)
for i := 0; i < len(partitionNames); i++ {
partitions[partitionNames[i]] = partitionIDs[i]
}
return partitions, nil
}
func (rc *rootCoordBroker) AllocTimestamp(ctx context.Context, num uint32) (uint64, uint32, error) {
log := log.Ctx(ctx)
req := &rootcoordpb.AllocTimestampRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_RequestTSO),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
Count: num,
}
resp, err := rc.client.AllocTimestamp(ctx, req)
if err := merr.CheckRPCCall(resp, err); err != nil {
log.Warn("failed to AllocTimestamp", zap.Error(err))
return 0, 0, err
}
return resp.GetTimestamp(), resp.GetCount(), nil
}
func (rc *rootCoordBroker) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) error {
log := log.Ctx(ctx)
resp, err := rc.client.ReportImport(ctx, req)
if err := merr.CheckRPCCall(resp, err); err != nil {
log.Warn("failed to ReportImport", zap.Error(err))
return err
}
return nil
}

View File

@ -0,0 +1,241 @@
package broker
import (
"context"
"math/rand"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
)
type rootCoordSuite struct {
suite.Suite
rc *mocks.MockRootCoordClient
broker Broker
}
func (s *rootCoordSuite) SetupSuite() {
paramtable.Init()
}
func (s *rootCoordSuite) SetupTest() {
s.rc = mocks.NewMockRootCoordClient(s.T())
s.broker = NewCoordBroker(s.rc, nil)
}
func (s *rootCoordSuite) resetMock() {
s.rc.AssertExpectations(s.T())
s.rc.ExpectedCalls = nil
}
func (s *rootCoordSuite) TestDescribeCollection() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
collectionID := int64(100)
timestamp := tsoutil.ComposeTSByTime(time.Now(), 0)
s.Run("normal_case", func() {
collName := "test_collection_name"
s.rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).
Run(func(_ context.Context, req *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) {
s.Equal(collectionID, req.GetCollectionID())
s.Equal(timestamp, req.GetTimeStamp())
}).
Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Status(nil),
CollectionID: collectionID,
CollectionName: collName,
}, nil)
resp, err := s.broker.DescribeCollection(ctx, collectionID, timestamp)
s.NoError(err)
s.Equal(collectionID, resp.GetCollectionID())
s.Equal(collName, resp.GetCollectionName())
s.resetMock()
})
s.Run("rootcoord_return_error", func() {
s.rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).
Return(nil, errors.New("mock"))
_, err := s.broker.DescribeCollection(ctx, collectionID, timestamp)
s.Error(err)
s.resetMock()
})
s.Run("rootcoord_return_failure_status", func() {
s.rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).
Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Status(errors.New("mocked")),
}, nil)
_, err := s.broker.DescribeCollection(ctx, collectionID, timestamp)
s.Error(err)
s.resetMock()
})
}
func (s *rootCoordSuite) TestShowPartitions() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
dbName := "defaultDB"
collName := "testCollection"
s.Run("normal_case", func() {
partitions := map[string]int64{
"part1": 1001,
"part2": 1002,
"part3": 1003,
}
names := lo.Keys(partitions)
ids := lo.Map(names, func(name string, _ int) int64 {
return partitions[name]
})
s.rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).
Run(func(_ context.Context, req *milvuspb.ShowPartitionsRequest, _ ...grpc.CallOption) {
s.Equal(dbName, req.GetDbName())
s.Equal(collName, req.GetCollectionName())
}).
Return(&milvuspb.ShowPartitionsResponse{
Status: merr.Status(nil),
PartitionIDs: ids,
PartitionNames: names,
}, nil)
partNameIDs, err := s.broker.ShowPartitions(ctx, dbName, collName)
s.NoError(err)
s.Equal(len(partitions), len(partNameIDs))
for name, id := range partitions {
result, ok := partNameIDs[name]
s.True(ok)
s.Equal(id, result)
}
s.resetMock()
})
s.Run("rootcoord_return_error", func() {
s.rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).
Return(nil, errors.New("mock"))
_, err := s.broker.ShowPartitions(ctx, dbName, collName)
s.Error(err)
s.resetMock()
})
s.Run("partition_id_name_not_match", func() {
s.rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).
Return(&milvuspb.ShowPartitionsResponse{
Status: merr.Status(nil),
PartitionIDs: []int64{1, 2},
PartitionNames: []string{"part1"},
}, nil)
_, err := s.broker.ShowPartitions(ctx, dbName, collName)
s.Error(err)
s.resetMock()
})
}
func (s *rootCoordSuite) TestAllocTimestamp() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("normal_case", func() {
num := rand.Intn(10) + 1
ts := tsoutil.ComposeTSByTime(time.Now(), 0)
s.rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).
Run(func(_ context.Context, req *rootcoordpb.AllocTimestampRequest, _ ...grpc.CallOption) {
s.EqualValues(num, req.GetCount())
}).
Return(&rootcoordpb.AllocTimestampResponse{
Status: merr.Status(nil),
Timestamp: ts,
Count: uint32(num),
}, nil)
timestamp, cnt, err := s.broker.AllocTimestamp(ctx, uint32(num))
s.NoError(err)
s.Equal(ts, timestamp)
s.EqualValues(num, cnt)
s.resetMock()
})
s.Run("rootcoord_return_error", func() {
s.rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).
Return(nil, errors.New("mock"))
_, _, err := s.broker.AllocTimestamp(ctx, 1)
s.Error(err)
s.resetMock()
})
s.Run("rootcoord_return_failure_status", func() {
s.rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).
Return(&rootcoordpb.AllocTimestampResponse{Status: merr.Status(errors.New("mock"))}, nil)
_, _, err := s.broker.AllocTimestamp(ctx, 1)
s.Error(err)
s.resetMock()
})
}
func (s *rootCoordSuite) TestReportImport() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
taskID := rand.Int63()
req := &rootcoordpb.ImportResult{
Status: merr.Status(nil),
TaskId: taskID,
}
s.Run("normal_case", func() {
s.rc.EXPECT().ReportImport(mock.Anything, mock.Anything).
Run(func(_ context.Context, req *rootcoordpb.ImportResult, _ ...grpc.CallOption) {
s.Equal(taskID, req.GetTaskId())
}).
Return(merr.Status(nil), nil)
err := s.broker.ReportImport(ctx, req)
s.NoError(err)
s.resetMock()
})
s.Run("rootcoord_return_error", func() {
s.rc.EXPECT().ReportImport(mock.Anything, mock.Anything).
Return(nil, errors.New("mock"))
err := s.broker.ReportImport(ctx, req)
s.Error(err)
s.resetMock()
})
s.Run("rootcoord_return_failure_status", func() {
s.rc.EXPECT().ReportImport(mock.Anything, mock.Anything).
Return(merr.Status(errors.New("mock")), nil)
err := s.broker.ReportImport(ctx, req)
s.Error(err)
s.resetMock()
})
}
func TestRootCoordBroker(t *testing.T) {
suite.Run(t, new(rootCoordSuite))
}