milvus/internal/datanode/broker/rootcoord_test.go

242 lines
6.3 KiB
Go

package broker
import (
"context"
"math/rand"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
)
type rootCoordSuite struct {
suite.Suite
rc *mocks.MockRootCoordClient
broker Broker
}
func (s *rootCoordSuite) SetupSuite() {
paramtable.Init()
}
func (s *rootCoordSuite) SetupTest() {
s.rc = mocks.NewMockRootCoordClient(s.T())
s.broker = NewCoordBroker(s.rc, nil)
}
func (s *rootCoordSuite) resetMock() {
s.rc.AssertExpectations(s.T())
s.rc.ExpectedCalls = nil
}
func (s *rootCoordSuite) TestDescribeCollection() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
collectionID := int64(100)
timestamp := tsoutil.ComposeTSByTime(time.Now(), 0)
s.Run("normal_case", func() {
collName := "test_collection_name"
s.rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).
Run(func(_ context.Context, req *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) {
s.Equal(collectionID, req.GetCollectionID())
s.Equal(timestamp, req.GetTimeStamp())
}).
Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Status(nil),
CollectionID: collectionID,
CollectionName: collName,
}, nil)
resp, err := s.broker.DescribeCollection(ctx, collectionID, timestamp)
s.NoError(err)
s.Equal(collectionID, resp.GetCollectionID())
s.Equal(collName, resp.GetCollectionName())
s.resetMock()
})
s.Run("rootcoord_return_error", func() {
s.rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).
Return(nil, errors.New("mock"))
_, err := s.broker.DescribeCollection(ctx, collectionID, timestamp)
s.Error(err)
s.resetMock()
})
s.Run("rootcoord_return_failure_status", func() {
s.rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).
Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Status(errors.New("mocked")),
}, nil)
_, err := s.broker.DescribeCollection(ctx, collectionID, timestamp)
s.Error(err)
s.resetMock()
})
}
func (s *rootCoordSuite) TestShowPartitions() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
dbName := "defaultDB"
collName := "testCollection"
s.Run("normal_case", func() {
partitions := map[string]int64{
"part1": 1001,
"part2": 1002,
"part3": 1003,
}
names := lo.Keys(partitions)
ids := lo.Map(names, func(name string, _ int) int64 {
return partitions[name]
})
s.rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).
Run(func(_ context.Context, req *milvuspb.ShowPartitionsRequest, _ ...grpc.CallOption) {
s.Equal(dbName, req.GetDbName())
s.Equal(collName, req.GetCollectionName())
}).
Return(&milvuspb.ShowPartitionsResponse{
Status: merr.Status(nil),
PartitionIDs: ids,
PartitionNames: names,
}, nil)
partNameIDs, err := s.broker.ShowPartitions(ctx, dbName, collName)
s.NoError(err)
s.Equal(len(partitions), len(partNameIDs))
for name, id := range partitions {
result, ok := partNameIDs[name]
s.True(ok)
s.Equal(id, result)
}
s.resetMock()
})
s.Run("rootcoord_return_error", func() {
s.rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).
Return(nil, errors.New("mock"))
_, err := s.broker.ShowPartitions(ctx, dbName, collName)
s.Error(err)
s.resetMock()
})
s.Run("partition_id_name_not_match", func() {
s.rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).
Return(&milvuspb.ShowPartitionsResponse{
Status: merr.Status(nil),
PartitionIDs: []int64{1, 2},
PartitionNames: []string{"part1"},
}, nil)
_, err := s.broker.ShowPartitions(ctx, dbName, collName)
s.Error(err)
s.resetMock()
})
}
func (s *rootCoordSuite) TestAllocTimestamp() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("normal_case", func() {
num := rand.Intn(10) + 1
ts := tsoutil.ComposeTSByTime(time.Now(), 0)
s.rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).
Run(func(_ context.Context, req *rootcoordpb.AllocTimestampRequest, _ ...grpc.CallOption) {
s.EqualValues(num, req.GetCount())
}).
Return(&rootcoordpb.AllocTimestampResponse{
Status: merr.Status(nil),
Timestamp: ts,
Count: uint32(num),
}, nil)
timestamp, cnt, err := s.broker.AllocTimestamp(ctx, uint32(num))
s.NoError(err)
s.Equal(ts, timestamp)
s.EqualValues(num, cnt)
s.resetMock()
})
s.Run("rootcoord_return_error", func() {
s.rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).
Return(nil, errors.New("mock"))
_, _, err := s.broker.AllocTimestamp(ctx, 1)
s.Error(err)
s.resetMock()
})
s.Run("rootcoord_return_failure_status", func() {
s.rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).
Return(&rootcoordpb.AllocTimestampResponse{Status: merr.Status(errors.New("mock"))}, nil)
_, _, err := s.broker.AllocTimestamp(ctx, 1)
s.Error(err)
s.resetMock()
})
}
func (s *rootCoordSuite) TestReportImport() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
taskID := rand.Int63()
req := &rootcoordpb.ImportResult{
Status: merr.Status(nil),
TaskId: taskID,
}
s.Run("normal_case", func() {
s.rc.EXPECT().ReportImport(mock.Anything, mock.Anything).
Run(func(_ context.Context, req *rootcoordpb.ImportResult, _ ...grpc.CallOption) {
s.Equal(taskID, req.GetTaskId())
}).
Return(merr.Status(nil), nil)
err := s.broker.ReportImport(ctx, req)
s.NoError(err)
s.resetMock()
})
s.Run("rootcoord_return_error", func() {
s.rc.EXPECT().ReportImport(mock.Anything, mock.Anything).
Return(nil, errors.New("mock"))
err := s.broker.ReportImport(ctx, req)
s.Error(err)
s.resetMock()
})
s.Run("rootcoord_return_failure_status", func() {
s.rc.EXPECT().ReportImport(mock.Anything, mock.Anything).
Return(merr.Status(errors.New("mock")), nil)
err := s.broker.ReportImport(ctx, req)
s.Error(err)
s.resetMock()
})
}
func TestRootCoordBroker(t *testing.T) {
suite.Run(t, new(rootCoordSuite))
}