mirror of https://github.com/milvus-io/milvus.git
242 lines
6.3 KiB
Go
242 lines
6.3 KiB
Go
package broker
|
|
|
|
import (
|
|
"context"
|
|
"math/rand"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
"github.com/samber/lo"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/suite"
|
|
"google.golang.org/grpc"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus/internal/mocks"
|
|
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
|
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
|
)
|
|
|
|
type rootCoordSuite struct {
|
|
suite.Suite
|
|
|
|
rc *mocks.MockRootCoordClient
|
|
broker Broker
|
|
}
|
|
|
|
func (s *rootCoordSuite) SetupSuite() {
|
|
paramtable.Init()
|
|
}
|
|
|
|
func (s *rootCoordSuite) SetupTest() {
|
|
s.rc = mocks.NewMockRootCoordClient(s.T())
|
|
s.broker = NewCoordBroker(s.rc, nil)
|
|
}
|
|
|
|
func (s *rootCoordSuite) resetMock() {
|
|
s.rc.AssertExpectations(s.T())
|
|
s.rc.ExpectedCalls = nil
|
|
}
|
|
|
|
func (s *rootCoordSuite) TestDescribeCollection() {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
collectionID := int64(100)
|
|
timestamp := tsoutil.ComposeTSByTime(time.Now(), 0)
|
|
|
|
s.Run("normal_case", func() {
|
|
collName := "test_collection_name"
|
|
|
|
s.rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).
|
|
Run(func(_ context.Context, req *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) {
|
|
s.Equal(collectionID, req.GetCollectionID())
|
|
s.Equal(timestamp, req.GetTimeStamp())
|
|
}).
|
|
Return(&milvuspb.DescribeCollectionResponse{
|
|
Status: merr.Status(nil),
|
|
CollectionID: collectionID,
|
|
CollectionName: collName,
|
|
}, nil)
|
|
|
|
resp, err := s.broker.DescribeCollection(ctx, collectionID, timestamp)
|
|
s.NoError(err)
|
|
s.Equal(collectionID, resp.GetCollectionID())
|
|
s.Equal(collName, resp.GetCollectionName())
|
|
s.resetMock()
|
|
})
|
|
|
|
s.Run("rootcoord_return_error", func() {
|
|
s.rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).
|
|
Return(nil, errors.New("mock"))
|
|
|
|
_, err := s.broker.DescribeCollection(ctx, collectionID, timestamp)
|
|
s.Error(err)
|
|
s.resetMock()
|
|
})
|
|
|
|
s.Run("rootcoord_return_failure_status", func() {
|
|
s.rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).
|
|
Return(&milvuspb.DescribeCollectionResponse{
|
|
Status: merr.Status(errors.New("mocked")),
|
|
}, nil)
|
|
|
|
_, err := s.broker.DescribeCollection(ctx, collectionID, timestamp)
|
|
s.Error(err)
|
|
s.resetMock()
|
|
})
|
|
}
|
|
|
|
func (s *rootCoordSuite) TestShowPartitions() {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
dbName := "defaultDB"
|
|
collName := "testCollection"
|
|
|
|
s.Run("normal_case", func() {
|
|
partitions := map[string]int64{
|
|
"part1": 1001,
|
|
"part2": 1002,
|
|
"part3": 1003,
|
|
}
|
|
|
|
names := lo.Keys(partitions)
|
|
ids := lo.Map(names, func(name string, _ int) int64 {
|
|
return partitions[name]
|
|
})
|
|
|
|
s.rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).
|
|
Run(func(_ context.Context, req *milvuspb.ShowPartitionsRequest, _ ...grpc.CallOption) {
|
|
s.Equal(dbName, req.GetDbName())
|
|
s.Equal(collName, req.GetCollectionName())
|
|
}).
|
|
Return(&milvuspb.ShowPartitionsResponse{
|
|
Status: merr.Status(nil),
|
|
PartitionIDs: ids,
|
|
PartitionNames: names,
|
|
}, nil)
|
|
partNameIDs, err := s.broker.ShowPartitions(ctx, dbName, collName)
|
|
s.NoError(err)
|
|
s.Equal(len(partitions), len(partNameIDs))
|
|
for name, id := range partitions {
|
|
result, ok := partNameIDs[name]
|
|
s.True(ok)
|
|
s.Equal(id, result)
|
|
}
|
|
s.resetMock()
|
|
})
|
|
|
|
s.Run("rootcoord_return_error", func() {
|
|
s.rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).
|
|
Return(nil, errors.New("mock"))
|
|
|
|
_, err := s.broker.ShowPartitions(ctx, dbName, collName)
|
|
s.Error(err)
|
|
s.resetMock()
|
|
})
|
|
|
|
s.Run("partition_id_name_not_match", func() {
|
|
s.rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).
|
|
Return(&milvuspb.ShowPartitionsResponse{
|
|
Status: merr.Status(nil),
|
|
PartitionIDs: []int64{1, 2},
|
|
PartitionNames: []string{"part1"},
|
|
}, nil)
|
|
|
|
_, err := s.broker.ShowPartitions(ctx, dbName, collName)
|
|
s.Error(err)
|
|
s.resetMock()
|
|
})
|
|
}
|
|
|
|
func (s *rootCoordSuite) TestAllocTimestamp() {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
s.Run("normal_case", func() {
|
|
num := rand.Intn(10) + 1
|
|
ts := tsoutil.ComposeTSByTime(time.Now(), 0)
|
|
s.rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).
|
|
Run(func(_ context.Context, req *rootcoordpb.AllocTimestampRequest, _ ...grpc.CallOption) {
|
|
s.EqualValues(num, req.GetCount())
|
|
}).
|
|
Return(&rootcoordpb.AllocTimestampResponse{
|
|
Status: merr.Status(nil),
|
|
Timestamp: ts,
|
|
Count: uint32(num),
|
|
}, nil)
|
|
|
|
timestamp, cnt, err := s.broker.AllocTimestamp(ctx, uint32(num))
|
|
s.NoError(err)
|
|
s.Equal(ts, timestamp)
|
|
s.EqualValues(num, cnt)
|
|
s.resetMock()
|
|
})
|
|
|
|
s.Run("rootcoord_return_error", func() {
|
|
s.rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).
|
|
Return(nil, errors.New("mock"))
|
|
_, _, err := s.broker.AllocTimestamp(ctx, 1)
|
|
s.Error(err)
|
|
s.resetMock()
|
|
})
|
|
|
|
s.Run("rootcoord_return_failure_status", func() {
|
|
s.rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).
|
|
Return(&rootcoordpb.AllocTimestampResponse{Status: merr.Status(errors.New("mock"))}, nil)
|
|
_, _, err := s.broker.AllocTimestamp(ctx, 1)
|
|
s.Error(err)
|
|
s.resetMock()
|
|
})
|
|
}
|
|
|
|
func (s *rootCoordSuite) TestReportImport() {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
taskID := rand.Int63()
|
|
|
|
req := &rootcoordpb.ImportResult{
|
|
Status: merr.Status(nil),
|
|
TaskId: taskID,
|
|
}
|
|
|
|
s.Run("normal_case", func() {
|
|
s.rc.EXPECT().ReportImport(mock.Anything, mock.Anything).
|
|
Run(func(_ context.Context, req *rootcoordpb.ImportResult, _ ...grpc.CallOption) {
|
|
s.Equal(taskID, req.GetTaskId())
|
|
}).
|
|
Return(merr.Status(nil), nil)
|
|
|
|
err := s.broker.ReportImport(ctx, req)
|
|
s.NoError(err)
|
|
s.resetMock()
|
|
})
|
|
|
|
s.Run("rootcoord_return_error", func() {
|
|
s.rc.EXPECT().ReportImport(mock.Anything, mock.Anything).
|
|
Return(nil, errors.New("mock"))
|
|
|
|
err := s.broker.ReportImport(ctx, req)
|
|
s.Error(err)
|
|
s.resetMock()
|
|
})
|
|
|
|
s.Run("rootcoord_return_failure_status", func() {
|
|
s.rc.EXPECT().ReportImport(mock.Anything, mock.Anything).
|
|
Return(merr.Status(errors.New("mock")), nil)
|
|
|
|
err := s.broker.ReportImport(ctx, req)
|
|
s.Error(err)
|
|
s.resetMock()
|
|
})
|
|
}
|
|
|
|
func TestRootCoordBroker(t *testing.T) {
|
|
suite.Run(t, new(rootCoordSuite))
|
|
}
|