enhance: [cherry-pick] Use channel manager interface in server_test (… (#32211)

…#31621)

Tidy up the following test code:

    - Remove channel in newTestServer
    - Remove newTestServerWithMeta
    - Remove newTestServer2
    - Remove testDataCoordBase
    - Use the same func for handleTTmsg and handleRPCTTmsg

See also: #31620
pr: #31621
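
For orientation, here is a rough sketch (not the exact code from this PR) of the test-setup shape the tidy-up converges on: newTestServer takes functional options instead of a channel argument, so a suite can inject a mock ChannelManager directly. All types below are simplified stand-ins for the real datacoord ones.

    package main

    import "fmt"

    // ChannelManager stands in for the datacoord interface the tests mock.
    type ChannelManager interface {
        Match(nodeID int64, channel string) bool
        Close()
    }

    // Server is a trimmed stand-in for datacoord.Server.
    type Server struct {
        channelManager ChannelManager
    }

    // Option follows the functional-option style used by the new test helper.
    type Option func(*Server)

    func WithChannelManager(cm ChannelManager) Option {
        return func(s *Server) { s.channelManager = cm }
    }

    // newTestServer builds a server and applies options; no channel parameter.
    func newTestServer(opts ...Option) *Server {
        s := &Server{}
        for _, opt := range opts {
            opt(s)
        }
        return s
    }

    // fakeChannelManager is a hypothetical stub playing the mock's role.
    type fakeChannelManager struct{}

    func (fakeChannelManager) Match(int64, string) bool { return true }
    func (fakeChannelManager) Close()                   {}

    func main() {
        svr := newTestServer(WithChannelManager(fakeChannelManager{}))
        fmt.Println(svr.channelManager.Match(1, "ch-1")) // prints: true
    }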

---------

Signed-off-by: yangxuan <xuan.yang@zilliz.com>
XuanYang-cn 2024-04-15 11:57:20 +08:00 committed by GitHub
parent e50599ba10
commit 0a3a483d02
8 changed files with 515 additions and 806 deletions


@@ -92,7 +92,7 @@ func (c *ClusterImpl) Watch(ctx context.Context, ch string, collectionID UniqueI
return c.channelManager.Watch(ctx, &channelMeta{Name: ch, CollectionID: collectionID})
}
// Flush sends flush requests to dataNodes specified
// Flush sends async FlushSegments requests to dataNodes
// which also according to channels where segments are assigned to.
func (c *ClusterImpl) Flush(ctx context.Context, nodeID int64, channel string, segments []*datapb.SegmentInfo) error {
if !c.channelManager.Match(nodeID, channel) {


@@ -56,7 +56,7 @@ func (m *mockMetricIndexNodeClient) GetMetrics(ctx context.Context, req *milvusp
}
func TestGetDataNodeMetrics(t *testing.T) {
svr := newTestServer(t, nil)
svr := newTestServer(t)
defer closeTestServer(t, svr)
ctx := context.Background()
@@ -123,7 +123,7 @@ func TestGetDataNodeMetrics(t *testing.T) {
}
func TestGetIndexNodeMetrics(t *testing.T) {
svr := newTestServer(t, nil)
svr := newTestServer(t)
defer closeTestServer(t, svr)
ctx := context.Background()


@@ -53,14 +53,12 @@ import (
"github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/expr"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/logutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/retry"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@@ -666,11 +664,12 @@ func (s *Server) initIndexNodeManager() {
}
func (s *Server) startServerLoop() {
s.serverLoopWg.Add(2)
if !Params.DataNodeCfg.DataNodeTimeTickByRPC.GetAsBool() {
s.serverLoopWg.Add(1)
s.startDataNodeTtLoop(s.serverLoopCtx)
}
s.serverLoopWg.Add(2)
s.startWatchService(s.serverLoopCtx)
s.startFlushLoop(s.serverLoopCtx)
s.startIndexService(s.serverLoopCtx)
@@ -743,7 +742,7 @@ func (s *Server) handleDataNodeTimetickMsgstream(ctx context.Context, ttMsgStrea
checker.Check()
}
if err := s.handleTimetickMessage(ctx, ttMsg); err != nil {
if err := s.handleDataNodeTtMsg(ctx, &ttMsg.DataNodeTtMsg); err != nil {
log.Warn("failed to handle timetick message", zap.Error(err))
continue
}
@@ -753,62 +752,6 @@ func (s *Server) handleDataNodeTimetickMsgstream(ctx context.Context, ttMsgStrea
}
}
func (s *Server) handleTimetickMessage(ctx context.Context, ttMsg *msgstream.DataNodeTtMsg) error {
log := log.Ctx(ctx).WithRateGroup("dc.handleTimetick", 1, 60)
ch := ttMsg.GetChannelName()
ts := ttMsg.GetTimestamp()
physical, _ := tsoutil.ParseTS(ts)
if time.Since(physical).Minutes() > 1 {
// if lag behind, log every 1 mins about
log.RatedWarn(60.0, "time tick lag behind for more than 1 minutes", zap.String("channel", ch), zap.Time("timetick", physical))
}
// ignore report from a different node
if !s.channelManager.Match(ttMsg.GetBase().GetSourceID(), ch) {
log.Warn("node is not matched with channel", zap.String("channel", ch), zap.Int64("nodeID", ttMsg.GetBase().GetSourceID()))
return nil
}
sub := tsoutil.SubByNow(ts)
pChannelName := funcutil.ToPhysicalChannel(ch)
metrics.DataCoordConsumeDataNodeTimeTickLag.
WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), pChannelName).
Set(float64(sub))
s.updateSegmentStatistics(ttMsg.GetSegmentsStats())
if err := s.segmentManager.ExpireAllocations(ch, ts); err != nil {
return fmt.Errorf("expire allocations: %w", err)
}
flushableIDs, err := s.segmentManager.GetFlushableSegments(ctx, ch, ts)
if err != nil {
return fmt.Errorf("get flushable segments: %w", err)
}
flushableSegments := s.getFlushableSegmentsInfo(flushableIDs)
if len(flushableSegments) == 0 {
return nil
}
log.Info("start flushing segments",
zap.Int64s("segment IDs", flushableIDs))
// update segment last update triggered time
// it's ok to fail flushing, since next timetick after duration will re-trigger
s.setLastFlushTime(flushableSegments)
finfo := make([]*datapb.SegmentInfo, 0, len(flushableSegments))
for _, info := range flushableSegments {
finfo = append(finfo, info.SegmentInfo)
}
err = s.cluster.Flush(s.ctx, ttMsg.GetBase().GetSourceID(), ch, finfo)
if err != nil {
log.Warn("failed to handle flush", zap.Int64("source", ttMsg.GetBase().GetSourceID()), zap.Error(err))
return err
}
return nil
}
func (s *Server) updateSegmentStatistics(stats []*commonpb.SegmentStats) {
for _, stat := range stats {
segment := s.meta.GetSegment(stat.GetSegmentID())

File diff suppressed because it is too large


@@ -189,9 +189,6 @@ func (s *Server) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentI
log.Warn("cannot get collection schema", zap.Error(err))
}
// Add the channel to cluster for watching.
s.cluster.Watch(ctx, r.ChannelName, r.CollectionID)
// Have segment manager allocate and return the segment allocation info.
segmentAllocations, err := s.segmentManager.AllocSegment(ctx,
r.CollectionID, r.PartitionID, r.ChannelName, int64(r.Count))
@@ -1413,7 +1410,7 @@ func (s *Server) UpdateChannelCheckpoint(ctx context.Context, req *datapb.Update
return merr.Success(), nil
}
// ReportDataNodeTtMsgs send datenode timetick messages to dataCoord.
// ReportDataNodeTtMsgs gets timetick messages from datanode.
func (s *Server) ReportDataNodeTtMsgs(ctx context.Context, req *datapb.ReportDataNodeTtMsgsRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx)
if err := merr.CheckHealthy(s.GetStateCode()); err != nil {
@@ -1425,7 +1422,7 @@ func (s *Server) ReportDataNodeTtMsgs(ctx context.Context, req *datapb.ReportDat
metrics.DataCoordConsumeDataNodeTimeTickLag.
WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), ttMsg.GetChannelName()).
Set(float64(sub))
err := s.handleRPCTimetickMessage(ctx, ttMsg)
err := s.handleDataNodeTtMsg(ctx, ttMsg)
if err != nil {
log.Error("fail to handle Datanode Timetick Msg",
zap.Int64("sourceID", ttMsg.GetBase().GetSourceID()),
@@ -1438,49 +1435,64 @@ func (s *Server) ReportDataNodeTtMsgs(ctx context.Context, req *datapb.ReportDat
return merr.Success(), nil
}
func (s *Server) handleRPCTimetickMessage(ctx context.Context, ttMsg *msgpb.DataNodeTtMsg) error {
log := log.Ctx(ctx)
ch := ttMsg.GetChannelName()
ts := ttMsg.GetTimestamp()
func (s *Server) handleDataNodeTtMsg(ctx context.Context, ttMsg *msgpb.DataNodeTtMsg) error {
var (
channel = ttMsg.GetChannelName()
ts = ttMsg.GetTimestamp()
sourceID = ttMsg.GetBase().GetSourceID()
segmentStats = ttMsg.GetSegmentsStats()
)
// ignore to handle RPC Timetick message since it's no longer the leader
if !s.channelManager.Match(ttMsg.GetBase().GetSourceID(), ch) {
log.Warn("node is not matched with channel",
zap.String("channelName", ch),
zap.Int64("nodeID", ttMsg.GetBase().GetSourceID()),
)
physical, _ := tsoutil.ParseTS(ts)
log := log.Ctx(ctx).WithRateGroup("dc.handleTimetick", 1, 60).With(
zap.String("channel", channel),
zap.Int64("sourceID", sourceID),
zap.Any("ts", ts),
)
if time.Since(physical).Minutes() > 1 {
// if lag behind, log every 1 mins about
log.RatedWarn(60.0, "time tick lag behind for more than 1 minutes")
}
// ignore report from a different node
if !s.channelManager.Match(sourceID, channel) {
log.Warn("node is not matched with channel")
return nil
}
s.updateSegmentStatistics(ttMsg.GetSegmentsStats())
sub := tsoutil.SubByNow(ts)
pChannelName := funcutil.ToPhysicalChannel(channel)
metrics.DataCoordConsumeDataNodeTimeTickLag.
WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), pChannelName).
Set(float64(sub))
if err := s.segmentManager.ExpireAllocations(ch, ts); err != nil {
return fmt.Errorf("expire allocations: %w", err)
s.updateSegmentStatistics(segmentStats)
if err := s.segmentManager.ExpireAllocations(channel, ts); err != nil {
log.Warn("failed to expire allocations", zap.Error(err))
return err
}
flushableIDs, err := s.segmentManager.GetFlushableSegments(ctx, ch, ts)
flushableIDs, err := s.segmentManager.GetFlushableSegments(ctx, channel, ts)
if err != nil {
return fmt.Errorf("get flushable segments: %w", err)
log.Warn("failed to get flushable segments", zap.Error(err))
return err
}
flushableSegments := s.getFlushableSegmentsInfo(flushableIDs)
if len(flushableSegments) == 0 {
return nil
}
log.Info("start flushing segments",
zap.Int64s("segment IDs", flushableIDs))
log.Info("start flushing segments", zap.Int64s("segmentIDs", flushableIDs))
// update segment last update triggered time
// it's ok to fail flushing, since next timetick after duration will re-trigger
s.setLastFlushTime(flushableSegments)
finfo := make([]*datapb.SegmentInfo, 0, len(flushableSegments))
for _, info := range flushableSegments {
finfo = append(finfo, info.SegmentInfo)
}
err = s.cluster.Flush(s.ctx, ttMsg.GetBase().GetSourceID(), ch, finfo)
infos := lo.Map(flushableSegments, func(info *SegmentInfo, _ int) *datapb.SegmentInfo {
return info.SegmentInfo
})
err = s.cluster.Flush(s.ctx, sourceID, channel, infos)
if err != nil {
log.Warn("failed to handle flush", zap.Any("source", ttMsg.GetBase().GetSourceID()), zap.Error(err))
log.Warn("failed to call Flush", zap.Error(err))
return err
}

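Since the hunk above interleaves the removed handleRPCTimetickMessage with the added handleDataNodeTtMsg, here is a compact sketch (using simplified stand-in types, not the real datacoord interfaces) of the single flow that both the msgstream loop and ReportDataNodeTtMsgs now share: check channel ownership, update statistics, expire allocations, then flush whatever became flushable.

    package main

    import (
        "context"
        "fmt"
    )

    // Simplified stand-ins for the collaborators the real handler uses.
    type channelManager interface{ Match(nodeID int64, ch string) bool }

    type segmentManager interface {
        ExpireAllocations(ch string, ts uint64) error
        GetFlushableSegments(ctx context.Context, ch string, ts uint64) ([]int64, error)
    }

    type cluster interface {
        Flush(ctx context.Context, nodeID int64, ch string, segments []int64) error
    }

    type ttMsg struct {
        Channel  string
        Ts       uint64
        SourceID int64
    }

    type server struct {
        channelManager channelManager
        segmentManager segmentManager
        cluster        cluster
    }

    // handleDataNodeTtMsg sketches the consolidated code path.
    func (s *server) handleDataNodeTtMsg(ctx context.Context, msg ttMsg) error {
        // Ignore reports from a node that no longer owns the channel.
        if !s.channelManager.Match(msg.SourceID, msg.Channel) {
            return nil
        }
        if err := s.segmentManager.ExpireAllocations(msg.Channel, msg.Ts); err != nil {
            return err
        }
        flushable, err := s.segmentManager.GetFlushableSegments(ctx, msg.Channel, msg.Ts)
        if err != nil {
            return err
        }
        if len(flushable) == 0 {
            return nil
        }
        // Flushing may fail; the next timetick will simply re-trigger it.
        return s.cluster.Flush(ctx, msg.SourceID, msg.Channel, flushable)
    }

    func main() {
        fmt.Println("see handleDataNodeTtMsg above for the consolidated flow")
    }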

@@ -2,6 +2,7 @@ package datacoord
import (
"context"
"fmt"
"testing"
"time"
@@ -27,9 +28,11 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metautil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
)
type ServerSuite struct {
@@ -39,17 +42,21 @@ type ServerSuite struct {
mockChMgr *MockChannelManager
}
func WithChannelManager(cm ChannelManager) Option {
return func(svr *Server) {
svr.channelManager = cm
}
}
func (s *ServerSuite) SetupTest() {
s.testServer = newTestServer(s.T(), nil)
s.mockChMgr = NewMockChannelManager(s.T())
s.mockChMgr.EXPECT().Startup(mock.Anything, mock.Anything).Return(nil).Maybe()
s.mockChMgr.EXPECT().Close().Maybe()
s.testServer = newTestServer(s.T(), WithChannelManager(s.mockChMgr))
if s.testServer.channelManager != nil {
s.testServer.channelManager.Close()
}
s.mockChMgr = NewMockChannelManager(s.T())
s.testServer.channelManager = s.mockChMgr
if s.mockChMgr != nil {
s.mockChMgr.EXPECT().Close().Maybe()
}
}
func (s *ServerSuite) TearDownTest() {
@@ -63,6 +70,327 @@ func TestServerSuite(t *testing.T) {
suite.Run(t, new(ServerSuite))
}
func genMsg(msgType commonpb.MsgType, ch string, t Timestamp, sourceID int64) *msgstream.DataNodeTtMsg {
return &msgstream.DataNodeTtMsg{
BaseMsg: msgstream.BaseMsg{
HashValues: []uint32{0},
},
DataNodeTtMsg: msgpb.DataNodeTtMsg{
Base: &commonpb.MsgBase{
MsgType: msgType,
Timestamp: t,
SourceID: sourceID,
},
ChannelName: ch,
Timestamp: t,
SegmentsStats: []*commonpb.SegmentStats{{SegmentID: 2, NumRows: 100}},
},
}
}
func (s *ServerSuite) TestHandleDataNodeTtMsg() {
var (
chanName = "ch-1"
collID int64 = 100
sourceID int64 = 1
)
s.testServer.meta.AddCollection(&collectionInfo{
ID: collID,
Schema: newTestSchema(),
Partitions: []int64{10},
})
resp, err := s.testServer.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{
NodeID: sourceID,
SegmentIDRequests: []*datapb.SegmentIDRequest{
{
CollectionID: collID,
PartitionID: 10,
ChannelName: chanName,
Count: 100,
},
},
})
s.Require().NoError(err)
s.Require().True(merr.Ok(resp.GetStatus()))
s.Equal(1, len(resp.GetSegIDAssignments()))
assign := resp.GetSegIDAssignments()[0]
assignedSegmentID := resp.SegIDAssignments[0].SegID
segment := s.testServer.meta.GetHealthySegment(assignedSegmentID)
s.Require().NotNil(segment)
s.Equal(1, len(segment.allocations))
ts := tsoutil.AddPhysicalDurationOnTs(assign.ExpireTime, -3*time.Minute)
msg := genMsg(commonpb.MsgType_DataNodeTt, chanName, ts, sourceID)
msg.SegmentsStats = append(msg.SegmentsStats, &commonpb.SegmentStats{
SegmentID: assign.GetSegID(),
NumRows: 1,
})
mockCluster := NewMockCluster(s.T())
mockCluster.EXPECT().Close().Once()
mockCluster.EXPECT().Flush(mock.Anything, sourceID, chanName, mock.Anything).RunAndReturn(
func(ctx context.Context, nodeID int64, channel string, segments []*datapb.SegmentInfo) error {
s.EqualValues(chanName, channel)
s.EqualValues(sourceID, nodeID)
s.Equal(1, len(segments))
s.EqualValues(2, segments[0].GetID())
return fmt.Errorf("mock error")
}).Once()
s.testServer.cluster = mockCluster
s.mockChMgr.EXPECT().Match(sourceID, chanName).Return(true).Twice()
err = s.testServer.handleDataNodeTtMsg(context.TODO(), &msg.DataNodeTtMsg)
s.NoError(err)
tt := tsoutil.AddPhysicalDurationOnTs(assign.ExpireTime, 48*time.Hour)
msg = genMsg(commonpb.MsgType_DataNodeTt, chanName, tt, sourceID)
msg.SegmentsStats = append(msg.SegmentsStats, &commonpb.SegmentStats{
SegmentID: assign.GetSegID(),
NumRows: 1,
})
err = s.testServer.handleDataNodeTtMsg(context.TODO(), &msg.DataNodeTtMsg)
s.Error(err)
}
// restart the server for config DataNodeTimeTickByRPC=false
func (s *ServerSuite) initSuiteForTtChannel() {
s.testServer.serverLoopWg.Add(1)
s.testServer.startDataNodeTtLoop(s.testServer.serverLoopCtx)
s.testServer.meta.AddCollection(&collectionInfo{
ID: 1,
Schema: newTestSchema(),
Partitions: []int64{10},
})
}
func (s *ServerSuite) TestDataNodeTtChannel_ExpireAfterTt() {
s.initSuiteForTtChannel()
ctx := context.TODO()
ttMsgStream, err := s.testServer.factory.NewMsgStream(ctx)
s.Require().NoError(err)
ttMsgStream.AsProducer([]string{paramtable.Get().CommonCfg.DataCoordTimeTick.GetValue()})
defer ttMsgStream.Close()
var (
sourceID int64 = 9997
chanName = "ch-1"
signal = make(chan struct{})
collID int64 = 1
)
mockCluster := NewMockCluster(s.T())
mockCluster.EXPECT().Close().Once()
mockCluster.EXPECT().Flush(mock.Anything, sourceID, chanName, mock.Anything).RunAndReturn(
func(ctx context.Context, nodeID int64, channel string, segments []*datapb.SegmentInfo) error {
s.EqualValues(chanName, channel)
s.EqualValues(sourceID, nodeID)
s.Equal(1, len(segments))
s.EqualValues(2, segments[0].GetID())
signal <- struct{}{}
return nil
}).Once()
s.testServer.cluster = mockCluster
s.mockChMgr.EXPECT().Match(sourceID, chanName).Return(true).Once()
resp, err := s.testServer.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{
NodeID: sourceID,
SegmentIDRequests: []*datapb.SegmentIDRequest{
{
CollectionID: collID,
PartitionID: 10,
ChannelName: chanName,
Count: 100,
},
},
})
s.Require().NoError(err)
s.Require().True(merr.Ok(resp.GetStatus()))
s.Equal(1, len(resp.GetSegIDAssignments()))
assignedSegmentID := resp.SegIDAssignments[0].SegID
segment := s.testServer.meta.GetHealthySegment(assignedSegmentID)
s.Require().NotNil(segment)
s.Equal(1, len(segment.allocations))
msgPack := msgstream.MsgPack{}
tt := tsoutil.AddPhysicalDurationOnTs(resp.SegIDAssignments[0].ExpireTime, 48*time.Hour)
msg := genMsg(commonpb.MsgType_DataNodeTt, "ch-1", tt, sourceID)
msgPack.Msgs = append(msgPack.Msgs, msg)
err = ttMsgStream.Produce(&msgPack)
s.Require().NoError(err)
<-signal
segment = s.testServer.meta.GetHealthySegment(assignedSegmentID)
s.NotNil(segment)
s.Equal(0, len(segment.allocations))
}
func (s *ServerSuite) TestDataNodeTtChannel_FlushWithDiffChan() {
s.initSuiteForTtChannel()
ctx := context.TODO()
ttMsgStream, err := s.testServer.factory.NewMsgStream(ctx)
s.Require().NoError(err)
ttMsgStream.AsProducer([]string{paramtable.Get().CommonCfg.DataCoordTimeTick.GetValue()})
defer ttMsgStream.Close()
var (
sourceID int64 = 9998
chanName = "ch-1"
signal = make(chan struct{})
collID int64 = 1
)
mockCluster := NewMockCluster(s.T())
mockCluster.EXPECT().Close().Once()
mockCluster.EXPECT().Flush(mock.Anything, sourceID, chanName, mock.Anything).RunAndReturn(
func(ctx context.Context, nodeID int64, channel string, segments []*datapb.SegmentInfo) error {
s.EqualValues(chanName, channel)
s.EqualValues(sourceID, nodeID)
s.Equal(1, len(segments))
signal <- struct{}{}
return nil
}).Once()
mockCluster.EXPECT().FlushChannels(mock.Anything, sourceID, mock.Anything, []string{chanName}).Return(nil).Once()
s.testServer.cluster = mockCluster
s.mockChMgr.EXPECT().Match(sourceID, chanName).Return(true).Once()
s.mockChMgr.EXPECT().GetNodeChannelsByCollectionID(collID).Return(map[int64][]string{
sourceID: {chanName},
})
resp, err := s.testServer.AssignSegmentID(ctx, &datapb.AssignSegmentIDRequest{
NodeID: sourceID,
SegmentIDRequests: []*datapb.SegmentIDRequest{
{
CollectionID: collID,
PartitionID: 10,
ChannelName: chanName,
Count: 100,
},
{
CollectionID: collID,
PartitionID: 10,
ChannelName: "ch-2",
Count: 100,
},
},
})
s.Require().NoError(err)
s.Require().True(merr.Ok(resp.GetStatus()))
s.Equal(2, len(resp.GetSegIDAssignments()))
var assign *datapb.SegmentIDAssignment
for _, segment := range resp.SegIDAssignments {
if segment.GetChannelName() == chanName {
assign = segment
break
}
}
s.Require().NotNil(assign)
resp2, err := s.testServer.Flush(ctx, &datapb.FlushRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Flush,
SourceID: sourceID,
},
CollectionID: collID,
})
s.Require().NoError(err)
s.Require().True(merr.Ok(resp2.GetStatus()))
msgPack := msgstream.MsgPack{}
msg := genMsg(commonpb.MsgType_DataNodeTt, chanName, assign.ExpireTime, sourceID)
msg.SegmentsStats = append(msg.SegmentsStats, &commonpb.SegmentStats{
SegmentID: assign.GetSegID(),
NumRows: 1,
})
msgPack.Msgs = append(msgPack.Msgs, msg)
err = ttMsgStream.Produce(&msgPack)
s.NoError(err)
<-signal
}
func (s *ServerSuite) TestDataNodeTtChannel_SegmentFlushAfterTt() {
s.initSuiteForTtChannel()
var (
sourceID int64 = 9999
chanName = "ch-1"
signal = make(chan struct{})
collID int64 = 1
)
mockCluster := NewMockCluster(s.T())
mockCluster.EXPECT().Close().Once()
mockCluster.EXPECT().Flush(mock.Anything, sourceID, chanName, mock.Anything).RunAndReturn(
func(ctx context.Context, nodeID int64, channel string, segments []*datapb.SegmentInfo) error {
s.EqualValues(chanName, channel)
s.EqualValues(sourceID, nodeID)
s.Equal(1, len(segments))
signal <- struct{}{}
return nil
}).Once()
mockCluster.EXPECT().FlushChannels(mock.Anything, sourceID, mock.Anything, []string{chanName}).Return(nil).Once()
s.testServer.cluster = mockCluster
s.mockChMgr.EXPECT().Match(sourceID, chanName).Return(true).Once()
s.mockChMgr.EXPECT().GetNodeChannelsByCollectionID(collID).Return(map[int64][]string{
sourceID: {chanName},
})
ctx := context.TODO()
ttMsgStream, err := s.testServer.factory.NewMsgStream(ctx)
s.Require().NoError(err)
ttMsgStream.AsProducer([]string{paramtable.Get().CommonCfg.DataCoordTimeTick.GetValue()})
defer ttMsgStream.Close()
resp, err := s.testServer.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{
SegmentIDRequests: []*datapb.SegmentIDRequest{
{
CollectionID: 1,
PartitionID: 10,
ChannelName: chanName,
Count: 100,
},
},
})
s.Require().NoError(err)
s.Require().True(merr.Ok(resp.GetStatus()))
s.Require().Equal(1, len(resp.GetSegIDAssignments()))
assign := resp.GetSegIDAssignments()[0]
resp2, err := s.testServer.Flush(ctx, &datapb.FlushRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Flush,
},
CollectionID: 1,
})
s.Require().NoError(err)
s.Require().True(merr.Ok(resp2.GetStatus()))
msgPack := msgstream.MsgPack{}
msg := genMsg(commonpb.MsgType_DataNodeTt, "ch-1", assign.ExpireTime, 9999)
msg.SegmentsStats = append(msg.SegmentsStats, &commonpb.SegmentStats{
SegmentID: assign.GetSegID(),
NumRows: 1,
})
msgPack.Msgs = append(msgPack.Msgs, msg)
err = ttMsgStream.Produce(&msgPack)
s.Require().NoError(err)
<-signal
}
func (s *ServerSuite) TestGetFlushState_ByFlushTs() {
s.mockChMgr.EXPECT().GetChannelsByCollectionID(int64(0)).
Return([]RWChannel{&channelMeta{Name: "ch1", CollectionID: 0}}).Times(3)
@@ -706,7 +1034,7 @@ func TestServer_GcConfirm(t *testing.T) {
func TestGetRecoveryInfoV2(t *testing.T) {
t.Run("test get recovery info with no segments", func(t *testing.T) {
svr := newTestServer(t, nil)
svr := newTestServer(t)
defer closeTestServer(t, svr)
svr.rootCoordClientCreator = func(ctx context.Context) (types.RootCoordClient, error) {
@@ -748,7 +1076,7 @@ func TestGetRecoveryInfoV2(t *testing.T) {
}
t.Run("test get earliest position of flushed segments as seek position", func(t *testing.T) {
svr := newTestServer(t, nil)
svr := newTestServer(t)
defer closeTestServer(t, svr)
svr.rootCoordClientCreator = func(ctx context.Context) (types.RootCoordClient, error) {
@@ -856,7 +1184,7 @@ func TestGetRecoveryInfoV2(t *testing.T) {
})
t.Run("test get recovery of unflushed segments ", func(t *testing.T) {
svr := newTestServer(t, nil)
svr := newTestServer(t)
defer closeTestServer(t, svr)
svr.rootCoordClientCreator = func(ctx context.Context) (types.RootCoordClient, error) {
@@ -933,7 +1261,7 @@ func TestGetRecoveryInfoV2(t *testing.T) {
})
t.Run("test get binlogs", func(t *testing.T) {
svr := newTestServer(t, nil)
svr := newTestServer(t)
defer closeTestServer(t, svr)
svr.meta.AddCollection(&collectionInfo{
@@ -1031,7 +1359,7 @@ func TestGetRecoveryInfoV2(t *testing.T) {
assert.EqualValues(t, 0, len(resp.GetSegments()[0].GetBinlogs()))
})
t.Run("with dropped segments", func(t *testing.T) {
svr := newTestServer(t, nil)
svr := newTestServer(t)
defer closeTestServer(t, svr)
svr.rootCoordClientCreator = func(ctx context.Context) (types.RootCoordClient, error) {
@@ -1076,7 +1404,7 @@ func TestGetRecoveryInfoV2(t *testing.T) {
})
t.Run("with fake segments", func(t *testing.T) {
svr := newTestServer(t, nil)
svr := newTestServer(t)
defer closeTestServer(t, svr)
svr.rootCoordClientCreator = func(ctx context.Context) (types.RootCoordClient, error) {
@@ -1120,7 +1448,7 @@ func TestGetRecoveryInfoV2(t *testing.T) {
})
t.Run("with continuous compaction", func(t *testing.T) {
svr := newTestServer(t, nil)
svr := newTestServer(t)
defer closeTestServer(t, svr)
svr.rootCoordClientCreator = func(ctx context.Context) (types.RootCoordClient, error) {
@@ -1205,7 +1533,7 @@ func TestGetRecoveryInfoV2(t *testing.T) {
})
t.Run("with closed server", func(t *testing.T) {
svr := newTestServer(t, nil)
svr := newTestServer(t)
closeTestServer(t, svr)
resp, err := svr.GetRecoveryInfoV2(context.TODO(), &datapb.GetRecoveryInfoRequestV2{})
assert.NoError(t, err)
@@ -1408,7 +1736,7 @@ type GcControlServiceSuite struct {
}
func (s *GcControlServiceSuite) SetupTest() {
s.server = newTestServer(s.T(), nil)
s.server = newTestServer(s.T())
}
func (s *GcControlServiceSuite) TearDownTest() {


@@ -171,9 +171,10 @@ func (c *SessionManagerImpl) Flush(ctx context.Context, nodeID int64, req *datap
}
func (c *SessionManagerImpl) execFlush(ctx context.Context, nodeID int64, req *datapb.FlushSegmentsRequest) {
log := log.Ctx(ctx).With(zap.Int64("nodeID", nodeID), zap.String("channel", req.GetChannelName()))
cli, err := c.getClient(ctx, nodeID)
if err != nil {
log.Warn("failed to get dataNode client", zap.Int64("dataNode ID", nodeID), zap.Error(err))
log.Warn("failed to get dataNode client", zap.Error(err))
return
}
ctx, cancel := context.WithTimeout(ctx, flushTimeout)
@@ -181,9 +182,9 @@ func (c *SessionManagerImpl) execFlush(ctx context.Context, nodeID int64, req *d
resp, err := cli.FlushSegments(ctx, req)
if err := VerifyResponse(resp, err); err != nil {
log.Error("flush call (perhaps partially) failed", zap.Int64("dataNode ID", nodeID), zap.Error(err))
log.Error("flush call (perhaps partially) failed", zap.Error(err))
} else {
log.Info("flush call succeeded", zap.Int64("dataNode ID", nodeID))
log.Info("flush call succeeded")
}
}


@@ -25,7 +25,7 @@ type SessionManagerSuite struct {
dn *mocks.MockDataNodeClient
m SessionManager
m *SessionManagerImpl
}
func (s *SessionManagerSuite) SetupTest() {
@@ -39,6 +39,35 @@ func (s *SessionManagerSuite) SetupTest() {
s.MetricsEqual(metrics.DataCoordNumDataNodes, 1)
}
func (s *SessionManagerSuite) SetupSubTest() {
s.SetupTest()
}
func (s *SessionManagerSuite) TestExecFlush() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
req := &datapb.FlushSegmentsRequest{
CollectionID: 1,
SegmentIDs: []int64{100, 200},
ChannelName: "ch-1",
}
s.Run("no node", func() {
s.m.execFlush(ctx, 100, req)
})
s.Run("fail", func() {
s.dn.EXPECT().FlushSegments(mock.Anything, mock.Anything).Return(nil, errors.New("mock")).Once()
s.m.execFlush(ctx, 1000, req)
})
s.Run("normal", func() {
s.dn.EXPECT().FlushSegments(mock.Anything, mock.Anything).Return(merr.Status(nil), nil).Once()
s.m.execFlush(ctx, 1000, req)
})
}
func (s *SessionManagerSuite) TestNotifyChannelOperation() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -58,16 +87,14 @@ func (s *SessionManagerSuite) TestNotifyChannelOperation() {
})
s.Run("fail", func() {
s.SetupTest()
s.dn.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything).Return(nil, errors.New("mock"))
s.dn.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything).Return(nil, errors.New("mock")).Once()
err := s.m.NotifyChannelOperation(ctx, 1000, req)
s.Error(err)
})
s.Run("normal", func() {
s.SetupTest()
s.dn.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything).Return(merr.Status(nil), nil)
s.dn.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything).Return(merr.Status(nil), nil).Once()
err := s.m.NotifyChannelOperation(ctx, 1000, req)
s.NoError(err)
@@ -91,8 +118,7 @@ func (s *SessionManagerSuite) TestCheckCHannelOperationProgress() {
})
s.Run("fail", func() {
s.SetupTest()
s.dn.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("mock"))
s.dn.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("mock")).Once()
resp, err := s.m.CheckChannelOperationProgress(ctx, 1000, info)
s.Error(err)
@@ -100,16 +126,13 @@ func (s *SessionManagerSuite) TestCheckCHannelOperationProgress() {
})
s.Run("normal", func() {
s.SetupTest()
s.dn.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything).
Return(
&datapb.ChannelOperationProgressResponse{
Status: merr.Status(nil),
OpID: info.OpID,
State: info.State,
Progress: 100,
},
nil)
Return(&datapb.ChannelOperationProgressResponse{
Status: merr.Status(nil),
OpID: info.OpID,
State: info.State,
Progress: 100,
}, nil).Once()
resp, err := s.m.CheckChannelOperationProgress(ctx, 1000, info)
s.NoError(err)