From 494600c37cf821ebca8149cce675ef45180dac57 Mon Sep 17 00:00:00 2001 From: Xiangyu Wang Date: Fri, 10 Sep 2021 14:24:02 +0800 Subject: [PATCH] Add unit tests for msg.go (#7691) Signed-off-by: Xiangyu Wang --- internal/msgstream/msg_test.go | 177 +++++++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) diff --git a/internal/msgstream/msg_test.go b/internal/msgstream/msg_test.go index 6a8301bbe9..458a821757 100644 --- a/internal/msgstream/msg_test.go +++ b/internal/msgstream/msg_test.go @@ -16,6 +16,7 @@ import ( "testing" "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/stretchr/testify/assert" @@ -362,6 +363,93 @@ func TestRetrieveResultMsg_Unmarshal_IllegalParameter(t *testing.T) { assert.Nil(t, tsMsg) } +func TestTimeTickMsg(t *testing.T) { + timeTickMsg := &TimeTickMsg{ + BaseMsg: generateBaseMsg(), + TimeTickMsg: internalpb.TimeTickMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_TimeTick, + MsgID: 1, + Timestamp: 2, + SourceID: 3, + }, + }, + } + + assert.NotNil(t, timeTickMsg.TraceCtx()) + + ctx := context.Background() + timeTickMsg.SetTraceCtx(ctx) + assert.Equal(t, ctx, timeTickMsg.TraceCtx()) + + assert.Equal(t, int64(1), timeTickMsg.ID()) + assert.Equal(t, commonpb.MsgType_TimeTick, timeTickMsg.Type()) + assert.Equal(t, int64(3), timeTickMsg.SourceID()) + + bytes, err := timeTickMsg.Marshal(timeTickMsg) + assert.Nil(t, err) + + tsMsg, err := timeTickMsg.Unmarshal(bytes) + assert.Nil(t, err) + + timeTickMsg2, ok := tsMsg.(*TimeTickMsg) + assert.True(t, ok) + assert.Equal(t, int64(1), timeTickMsg2.ID()) + assert.Equal(t, commonpb.MsgType_TimeTick, timeTickMsg2.Type()) + assert.Equal(t, int64(3), timeTickMsg2.SourceID()) +} + +func TestTimeTickMsg_Unmarshal_IllegalParameter(t *testing.T) { + timeTickMsg := &TimeTickMsg{} + tsMsg, err := timeTickMsg.Unmarshal(10) + assert.NotNil(t, err) + assert.Nil(t, tsMsg) +} + +func TestSegmentStatisticsMsg(t *testing.T) { + segmentStatisticsMsg := &SegmentStatisticsMsg{ + BaseMsg: generateBaseMsg(), + SegmentStatistics: internalpb.SegmentStatistics{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_SegmentStatistics, + MsgID: 1, + Timestamp: 2, + SourceID: 3, + }, + SegStats: []*internalpb.SegmentStatisticsUpdates{}, + }, + } + + assert.NotNil(t, segmentStatisticsMsg.TraceCtx()) + + ctx := context.Background() + segmentStatisticsMsg.SetTraceCtx(ctx) + assert.Equal(t, ctx, segmentStatisticsMsg.TraceCtx()) + + assert.Equal(t, int64(1), segmentStatisticsMsg.ID()) + assert.Equal(t, commonpb.MsgType_SegmentStatistics, segmentStatisticsMsg.Type()) + assert.Equal(t, int64(3), segmentStatisticsMsg.SourceID()) + + bytes, err := segmentStatisticsMsg.Marshal(segmentStatisticsMsg) + assert.Nil(t, err) + + tsMsg, err := segmentStatisticsMsg.Unmarshal(bytes) + assert.Nil(t, err) + + segmentStatisticsMsg2, ok := tsMsg.(*SegmentStatisticsMsg) + assert.True(t, ok) + assert.Equal(t, int64(1), segmentStatisticsMsg2.ID()) + assert.Equal(t, commonpb.MsgType_SegmentStatistics, segmentStatisticsMsg2.Type()) + assert.Equal(t, int64(3), segmentStatisticsMsg2.SourceID()) +} + +func TestSegmentStatisticsMsg_Unmarshal_IllegalParameter(t *testing.T) { + segmentStatisticsMsg := &SegmentStatisticsMsg{} + tsMsg, err := segmentStatisticsMsg.Unmarshal(10) + assert.NotNil(t, err) + assert.Nil(t, tsMsg) +} + func TestCreateCollectionMsg(t *testing.T) { createCollectionMsg := &CreateCollectionMsg{ BaseMsg: generateBaseMsg(), @@ -558,3 +646,92 @@ func TestDropPartitionMsg_Unmarshal_IllegalParameter(t *testing.T) { assert.NotNil(t, err) assert.Nil(t, tsMsg) } + +func TestLoadBalanceSegmentsMsg(t *testing.T) { + loadBalanceSegmentsMsg := &LoadBalanceSegmentsMsg{ + BaseMsg: generateBaseMsg(), + LoadBalanceSegmentsRequest: internalpb.LoadBalanceSegmentsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadBalanceSegments, + MsgID: 1, + Timestamp: 2, + SourceID: 3, + }, + SegmentIDs: []int64{}, + }, + } + + assert.NotNil(t, loadBalanceSegmentsMsg.TraceCtx()) + + ctx := context.Background() + loadBalanceSegmentsMsg.SetTraceCtx(ctx) + assert.Equal(t, ctx, loadBalanceSegmentsMsg.TraceCtx()) + + assert.Equal(t, int64(1), loadBalanceSegmentsMsg.ID()) + assert.Equal(t, commonpb.MsgType_LoadBalanceSegments, loadBalanceSegmentsMsg.Type()) + assert.Equal(t, int64(3), loadBalanceSegmentsMsg.SourceID()) + + bytes, err := loadBalanceSegmentsMsg.Marshal(loadBalanceSegmentsMsg) + assert.Nil(t, err) + + tsMsg, err := loadBalanceSegmentsMsg.Unmarshal(bytes) + assert.Nil(t, err) + + loadBalanceSegmentsMsg2, ok := tsMsg.(*LoadBalanceSegmentsMsg) + assert.True(t, ok) + assert.Equal(t, int64(1), loadBalanceSegmentsMsg2.ID()) + assert.Equal(t, commonpb.MsgType_LoadBalanceSegments, loadBalanceSegmentsMsg2.Type()) + assert.Equal(t, int64(3), loadBalanceSegmentsMsg2.SourceID()) +} + +func TestLoadBalanceSegmentsMsg_Unmarshal_IllegalParameter(t *testing.T) { + loadBalanceSegmentsMsg := &LoadBalanceSegmentsMsg{} + tsMsg, err := loadBalanceSegmentsMsg.Unmarshal(10) + assert.NotNil(t, err) + assert.Nil(t, tsMsg) +} + +func TestDataNodeTtMsg(t *testing.T) { + dataNodeTtMsg := &DataNodeTtMsg{ + BaseMsg: generateBaseMsg(), + DataNodeTtMsg: datapb.DataNodeTtMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DataNodeTt, + MsgID: 1, + Timestamp: 2, + SourceID: 3, + }, + ChannelName: "test-channel", + Timestamp: 4, + }, + } + + assert.NotNil(t, dataNodeTtMsg.TraceCtx()) + + ctx := context.Background() + dataNodeTtMsg.SetTraceCtx(ctx) + assert.Equal(t, ctx, dataNodeTtMsg.TraceCtx()) + + assert.Equal(t, int64(1), dataNodeTtMsg.ID()) + assert.Equal(t, commonpb.MsgType_DataNodeTt, dataNodeTtMsg.Type()) + assert.Equal(t, int64(3), dataNodeTtMsg.SourceID()) + + bytes, err := dataNodeTtMsg.Marshal(dataNodeTtMsg) + assert.Nil(t, err) + + tsMsg, err := dataNodeTtMsg.Unmarshal(bytes) + assert.Nil(t, err) + + dataNodeTtMsg2, ok := tsMsg.(*DataNodeTtMsg) + assert.True(t, ok) + assert.Equal(t, int64(1), dataNodeTtMsg2.ID()) + assert.Equal(t, commonpb.MsgType_DataNodeTt, dataNodeTtMsg2.Type()) + assert.Equal(t, int64(3), dataNodeTtMsg2.SourceID()) +} + +func TestDataNodeTtMsg_Unmarshal_IllegalParameter(t *testing.T) { + dataNodeTtMsg := &DataNodeTtMsg{} + tsMsg, err := dataNodeTtMsg.Unmarshal(10) + assert.NotNil(t, err) + assert.Nil(t, tsMsg) +}