diff --git a/internal/datacoord/session/datanode_manager.go b/internal/datacoord/session/datanode_manager.go index 132137a924..e65f2cb959 100644 --- a/internal/datacoord/session/datanode_manager.go +++ b/internal/datacoord/session/datanode_manager.go @@ -59,7 +59,7 @@ type DataNodeManager interface { Flush(ctx context.Context, nodeID int64, req *datapb.FlushSegmentsRequest) FlushChannels(ctx context.Context, nodeID int64, req *datapb.FlushChannelsRequest) error Compaction(ctx context.Context, nodeID int64, plan *datapb.CompactionPlan) error - SyncSegments(nodeID int64, req *datapb.SyncSegmentsRequest) error + SyncSegments(ctx context.Context, nodeID int64, req *datapb.SyncSegmentsRequest) error GetCompactionPlanResult(nodeID int64, planID int64) (*datapb.CompactionPlanResult, error) GetCompactionPlansResults() (map[int64]*typeutil.Pair[int64, *datapb.CompactionPlanResult], error) NotifyChannelOperation(ctx context.Context, nodeID int64, req *datapb.ChannelOperationsRequest) error @@ -224,7 +224,7 @@ func (c *DataNodeManagerImpl) Compaction(ctx context.Context, nodeID int64, plan } // SyncSegments is a grpc interface. It will send request to DataNode with provided `nodeID` synchronously. -func (c *DataNodeManagerImpl) SyncSegments(nodeID int64, req *datapb.SyncSegmentsRequest) error { +func (c *DataNodeManagerImpl) SyncSegments(ctx context.Context, nodeID int64, req *datapb.SyncSegmentsRequest) error { log := log.With( zap.Int64("nodeID", nodeID), zap.Int64("planID", req.GetPlanID()), @@ -237,9 +237,9 @@ func (c *DataNodeManagerImpl) SyncSegments(nodeID int64, req *datapb.SyncSegment return err } - err = retry.Do(context.Background(), func() error { + err = retry.Do(ctx, func() error { // doesn't set timeout - resp, err := cli.SyncSegments(context.Background(), req) + resp, err := cli.SyncSegments(ctx, req) if err := merr.CheckRPCCall(resp, err); err != nil { log.Warn("failed to sync segments", zap.Error(err)) return err diff --git a/internal/datacoord/session/mock_datanode_manager.go b/internal/datacoord/session/mock_datanode_manager.go index 60da9d2f2b..9bd42f2847 100644 --- a/internal/datacoord/session/mock_datanode_manager.go +++ b/internal/datacoord/session/mock_datanode_manager.go @@ -1039,17 +1039,17 @@ func (_c *MockDataNodeManager_QuerySlot_Call) RunAndReturn(run func(int64) (*dat return _c } -// SyncSegments provides a mock function with given fields: nodeID, req -func (_m *MockDataNodeManager) SyncSegments(nodeID int64, req *datapb.SyncSegmentsRequest) error { - ret := _m.Called(nodeID, req) +// SyncSegments provides a mock function with given fields: ctx, nodeID, req +func (_m *MockDataNodeManager) SyncSegments(ctx context.Context, nodeID int64, req *datapb.SyncSegmentsRequest) error { + ret := _m.Called(ctx, nodeID, req) if len(ret) == 0 { panic("no return value specified for SyncSegments") } var r0 error - if rf, ok := ret.Get(0).(func(int64, *datapb.SyncSegmentsRequest) error); ok { - r0 = rf(nodeID, req) + if rf, ok := ret.Get(0).(func(context.Context, int64, *datapb.SyncSegmentsRequest) error); ok { + r0 = rf(ctx, nodeID, req) } else { r0 = ret.Error(0) } @@ -1063,15 +1063,16 @@ type MockDataNodeManager_SyncSegments_Call struct { } // SyncSegments is a helper method to define mock.On call +// - ctx context.Context // - nodeID int64 // - req *datapb.SyncSegmentsRequest -func (_e *MockDataNodeManager_Expecter) SyncSegments(nodeID interface{}, req interface{}) *MockDataNodeManager_SyncSegments_Call { - return &MockDataNodeManager_SyncSegments_Call{Call: _e.mock.On("SyncSegments", nodeID, req)} +func (_e *MockDataNodeManager_Expecter) SyncSegments(ctx interface{}, nodeID interface{}, req interface{}) *MockDataNodeManager_SyncSegments_Call { + return &MockDataNodeManager_SyncSegments_Call{Call: _e.mock.On("SyncSegments", ctx, nodeID, req)} } -func (_c *MockDataNodeManager_SyncSegments_Call) Run(run func(nodeID int64, req *datapb.SyncSegmentsRequest)) *MockDataNodeManager_SyncSegments_Call { +func (_c *MockDataNodeManager_SyncSegments_Call) Run(run func(ctx context.Context, nodeID int64, req *datapb.SyncSegmentsRequest)) *MockDataNodeManager_SyncSegments_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64), args[1].(*datapb.SyncSegmentsRequest)) + run(args[0].(context.Context), args[1].(int64), args[2].(*datapb.SyncSegmentsRequest)) }) return _c } @@ -1081,7 +1082,7 @@ func (_c *MockDataNodeManager_SyncSegments_Call) Return(_a0 error) *MockDataNode return _c } -func (_c *MockDataNodeManager_SyncSegments_Call) RunAndReturn(run func(int64, *datapb.SyncSegmentsRequest) error) *MockDataNodeManager_SyncSegments_Call { +func (_c *MockDataNodeManager_SyncSegments_Call) RunAndReturn(run func(context.Context, int64, *datapb.SyncSegmentsRequest) error) *MockDataNodeManager_SyncSegments_Call { _c.Call.Return(run) return _c } diff --git a/internal/datacoord/sync_segments_scheduler.go b/internal/datacoord/sync_segments_scheduler.go index 4f7b5ed8d0..fd626ea333 100644 --- a/internal/datacoord/sync_segments_scheduler.go +++ b/internal/datacoord/sync_segments_scheduler.go @@ -17,6 +17,7 @@ package datacoord import ( + "context" "sync" "time" @@ -31,8 +32,10 @@ import ( ) type SyncSegmentsScheduler struct { - quit chan struct{} - wg sync.WaitGroup + ctx context.Context + cancelFunc context.CancelFunc + quit chan struct{} + wg sync.WaitGroup meta *meta channelManager ChannelManager @@ -40,7 +43,10 @@ type SyncSegmentsScheduler struct { } func newSyncSegmentsScheduler(m *meta, channelManager ChannelManager, sessions session.DataNodeManager) *SyncSegmentsScheduler { + ctx, cancel := context.WithCancel(context.Background()) return &SyncSegmentsScheduler{ + ctx: ctx, + cancelFunc: cancel, quit: make(chan struct{}), wg: sync.WaitGroup{}, meta: m, @@ -65,7 +71,7 @@ func (sss *SyncSegmentsScheduler) Start() { ticker.Stop() return case <-ticker.C: - sss.SyncSegmentsForCollections() + sss.SyncSegmentsForCollections(sss.ctx) } } }() @@ -73,11 +79,12 @@ func (sss *SyncSegmentsScheduler) Start() { } func (sss *SyncSegmentsScheduler) Stop() { + sss.cancelFunc() close(sss.quit) sss.wg.Wait() } -func (sss *SyncSegmentsScheduler) SyncSegmentsForCollections() { +func (sss *SyncSegmentsScheduler) SyncSegmentsForCollections(ctx context.Context) { collIDs := sss.meta.ListCollections() for _, collID := range collIDs { collInfo := sss.meta.GetCollection(collID) @@ -99,7 +106,7 @@ func (sss *SyncSegmentsScheduler) SyncSegmentsForCollections() { continue } for _, partitionID := range collInfo.Partitions { - if err := sss.SyncSegments(collID, partitionID, channelName, nodeID, pkField.GetFieldID()); err != nil { + if err := sss.SyncSegments(ctx, collID, partitionID, channelName, nodeID, pkField.GetFieldID()); err != nil { log.Warn("sync segment with channel failed, retry next ticker", zap.Int64("collectionID", collID), zap.Int64("partitionID", partitionID), @@ -112,7 +119,7 @@ func (sss *SyncSegmentsScheduler) SyncSegmentsForCollections() { } } -func (sss *SyncSegmentsScheduler) SyncSegments(collectionID, partitionID int64, channelName string, nodeID, pkFieldID int64) error { +func (sss *SyncSegmentsScheduler) SyncSegments(ctx context.Context, collectionID, partitionID int64, channelName string, nodeID, pkFieldID int64) error { log := log.With(zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.String("channelName", channelName), zap.Int64("nodeID", nodeID)) // sync all healthy segments, but only check flushed segments on datanode. Because L0 growing segments may not in datacoord's meta. @@ -147,7 +154,7 @@ func (sss *SyncSegmentsScheduler) SyncSegments(collectionID, partitionID int64, } } - if err := sss.sessions.SyncSegments(nodeID, req); err != nil { + if err := sss.sessions.SyncSegments(ctx, nodeID, req); err != nil { log.Warn("fail to sync segments with node", zap.Error(err)) return err } diff --git a/internal/datacoord/sync_segments_scheduler_test.go b/internal/datacoord/sync_segments_scheduler_test.go index f6d321acc9..d9a0043bd2 100644 --- a/internal/datacoord/sync_segments_scheduler_test.go +++ b/internal/datacoord/sync_segments_scheduler_test.go @@ -17,6 +17,7 @@ package datacoord import ( + "context" "sync/atomic" "testing" @@ -323,7 +324,7 @@ func (s *SyncSegmentsSchedulerSuite) Test_newSyncSegmentsScheduler() { cm.EXPECT().FindWatcher(mock.Anything).Return(100, nil) sm := session.NewMockDataNodeManager(s.T()) - sm.EXPECT().SyncSegments(mock.Anything, mock.Anything).RunAndReturn(func(i int64, request *datapb.SyncSegmentsRequest) error { + sm.EXPECT().SyncSegments(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, i int64, request *datapb.SyncSegmentsRequest) error { for _, seg := range request.GetSegmentInfos() { if seg.GetState() == commonpb.SegmentState_Flushed { s.new.Add(1) @@ -352,21 +353,22 @@ func (s *SyncSegmentsSchedulerSuite) Test_SyncSegmentsFail() { sm := session.NewMockDataNodeManager(s.T()) sss := newSyncSegmentsScheduler(s.m, cm, sm) + ctx := context.Background() s.Run("pk not found", func() { sss.meta.collections[1].Schema.Fields[0].IsPrimaryKey = false - sss.SyncSegmentsForCollections() + sss.SyncSegmentsForCollections(ctx) sss.meta.collections[1].Schema.Fields[0].IsPrimaryKey = true }) s.Run("find watcher failed", func() { cm.EXPECT().FindWatcher(mock.Anything).Return(0, errors.New("mock error")).Twice() - sss.SyncSegmentsForCollections() + sss.SyncSegmentsForCollections(ctx) }) s.Run("sync segment failed", func() { cm.EXPECT().FindWatcher(mock.Anything).Return(100, nil) - sm.EXPECT().SyncSegments(mock.Anything, mock.Anything).Return(errors.New("mock error")) - sss.SyncSegmentsForCollections() + sm.EXPECT().SyncSegments(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock error")) + sss.SyncSegmentsForCollections(ctx) }) }