diff --git a/internal/distributed/rootcoord/service.go b/internal/distributed/rootcoord/service.go index 7fbf18e1cf..983986cb1c 100644 --- a/internal/distributed/rootcoord/service.go +++ b/internal/distributed/rootcoord/service.go @@ -347,6 +347,12 @@ func (s *Server) Stop() (err error) { defer s.tikvCli.Close() } + if s.rootCoord != nil { + log.Info("graceful stop rootCoord") + s.rootCoord.GracefulStop() + log.Info("graceful stop rootCoord done") + } + if s.grpcServer != nil { utils.GracefulStopGRPCServer(s.grpcServer) } diff --git a/internal/distributed/rootcoord/service_test.go b/internal/distributed/rootcoord/service_test.go index 9e7b5115b7..917fa9d836 100644 --- a/internal/distributed/rootcoord/service_test.go +++ b/internal/distributed/rootcoord/service_test.go @@ -118,6 +118,9 @@ func (m *mockCore) Stop() error { return fmt.Errorf("stop error") } +func (m *mockCore) GracefulStop() { +} + func TestRun(t *testing.T) { paramtable.Init() parameters := []string{"tikv", "etcd"} diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index a6301b3e9a..0181c2bc55 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -826,15 +826,18 @@ func (c *Core) revokeSession() { } } +func (c *Core) GracefulStop() { + if c.streamingCoord != nil { + c.streamingCoord.Stop() + } +} + // Stop stops rootCoord. 
func (c *Core) Stop() error { c.UpdateStateCode(commonpb.StateCode_Abnormal) c.stopExecutor() c.stopScheduler() - if c.streamingCoord != nil { - c.streamingCoord.Stop() - } if c.proxyWatcher != nil { c.proxyWatcher.Stop() } diff --git a/internal/streamingcoord/client/broadcast/watcher_resuming.go b/internal/streamingcoord/client/broadcast/watcher_resuming.go index 2f99c238d1..077f3786c5 100644 --- a/internal/streamingcoord/client/broadcast/watcher_resuming.go +++ b/internal/streamingcoord/client/broadcast/watcher_resuming.go @@ -67,7 +67,8 @@ func (r *resumingWatcher) Close() { func (r *resumingWatcher) execute(backoffConfig *typeutil.BackoffTimerConfig) { backoff := typeutil.NewBackoffTimer(backoffConfig) - nextTimer := time.After(0) + var nextTimer <-chan time.Time + var initialized bool var watcher Watcher defer func() { if watcher != nil { @@ -92,6 +93,12 @@ func (r *resumingWatcher) execute(backoffConfig *typeutil.BackoffTimerConfig) { watcher = nil } } + if !initialized { + // try to initialize watcher in next loop. + // avoid creating a grpc stream channel if the watch operation is not used. 
+ nextTimer = time.After(0) + initialized = true + } case ev, ok := <-eventChan: if !ok { watcher.Close() @@ -101,15 +108,15 @@ func (r *resumingWatcher) execute(backoffConfig *typeutil.BackoffTimerConfig) { r.evs.Notify(ev) case <-nextTimer: var err error + nextTimer = nil if watcher, err = r.createNewWatcher(); err != nil { r.Logger().Warn("create new watcher failed", zap.Error(err)) break } r.Logger().Info("create new watcher successful") backoff.DisableBackoff() - nextTimer = nil } - if watcher == nil { + if watcher == nil && nextTimer == nil { backoff.EnableBackoff() var interval time.Duration nextTimer, interval = backoff.NextTimer() diff --git a/internal/streamingcoord/server/balancer/balancer.go b/internal/streamingcoord/server/balancer/balancer.go index abe35d51ec..98d2bd85bc 100644 --- a/internal/streamingcoord/server/balancer/balancer.go +++ b/internal/streamingcoord/server/balancer/balancer.go @@ -3,11 +3,16 @@ package balancer import ( "context" + "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/util/typeutil" ) -var _ Balancer = (*balancerImpl)(nil) +var ( + _ Balancer = (*balancerImpl)(nil) + ErrBalancerClosed = errors.New("balancer is closed") +) // Balancer is a load balancer to balance the load of log node. // Given the balance result to assign or remove channels to corresponding log node. 
diff --git a/internal/streamingcoord/server/balancer/balancer_impl.go b/internal/streamingcoord/server/balancer/balancer_impl.go index 7a263a2100..7d827121cc 100644 --- a/internal/streamingcoord/server/balancer/balancer_impl.go +++ b/internal/streamingcoord/server/balancer/balancer_impl.go @@ -13,6 +13,7 @@ import ( "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/contextutil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/syncutil" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -30,7 +31,10 @@ func RecoverBalancer( if err != nil { return nil, errors.Wrap(err, "fail to recover channel manager") } + ctx, cancel := context.WithCancelCause(context.Background()) b := &balancerImpl{ + ctx: ctx, + cancel: cancel, lifetime: typeutil.NewLifetime(), logger: resource.Resource().Logger().With(log.FieldComponent("balancer"), zap.String("policy", policy)), channelMetaManager: manager, @@ -44,6 +48,8 @@ func RecoverBalancer( // balancerImpl is a implementation of Balancer. type balancerImpl struct { + ctx context.Context + cancel context.CancelCauseFunc lifetime *typeutil.Lifetime logger *log.MLogger channelMetaManager *channel.ChannelManager @@ -58,6 +64,9 @@ func (b *balancerImpl) WatchChannelAssignments(ctx context.Context, cb func(vers return status.NewOnShutdownError("balancer is closing") } defer b.lifetime.Done() + + ctx, cancel := contextutil.MergeContext(ctx, b.ctx) + defer cancel() return b.channelMetaManager.WatchAssignmentResult(ctx, cb) } @@ -67,6 +76,8 @@ func (b *balancerImpl) MarkAsUnavailable(ctx context.Context, pChannels []types. 
} defer b.lifetime.Done() + ctx, cancel := contextutil.MergeContext(ctx, b.ctx) + defer cancel() return b.sendRequestAndWaitFinish(ctx, newOpMarkAsUnavailable(ctx, pChannels)) } @@ -77,6 +88,8 @@ func (b *balancerImpl) Trigger(ctx context.Context) error { } defer b.lifetime.Done() + ctx, cancel := contextutil.MergeContext(ctx, b.ctx) + defer cancel() return b.sendRequestAndWaitFinish(ctx, newOpTrigger(ctx)) } @@ -93,6 +106,8 @@ func (b *balancerImpl) sendRequestAndWaitFinish(ctx context.Context, newReq *req // Close close the balancer. func (b *balancerImpl) Close() { b.lifetime.SetState(typeutil.LifetimeStateStopped) + // cancel all watch operations by context. + b.cancel(ErrBalancerClosed) b.lifetime.Wait() b.backgroundTaskNotifier.Cancel() @@ -216,7 +231,7 @@ func (b *balancerImpl) applyBalanceResultToStreamingNode(ctx context.Context, mo // assign the channel to the target node. if err := resource.Resource().StreamingNodeManagerClient().Assign(ctx, channel.CurrentAssignment()); err != nil { - b.logger.Warn("fail to assign channel", zap.Any("assignment", channel.CurrentAssignment())) + b.logger.Warn("fail to assign channel", zap.Any("assignment", channel.CurrentAssignment()), zap.Error(err)) return err } b.logger.Info("assign channel success", zap.Any("assignment", channel.CurrentAssignment())) diff --git a/internal/streamingcoord/server/balancer/balancer_test.go b/internal/streamingcoord/server/balancer/balancer_test.go index b794527ca7..427d993f50 100644 --- a/internal/streamingcoord/server/balancer/balancer_test.go +++ b/internal/streamingcoord/server/balancer/balancer_test.go @@ -3,6 +3,7 @@ package balancer_test import ( "context" "testing" + "time" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" @@ -16,6 +17,7 @@ import ( "github.com/milvus-io/milvus/pkg/proto/streamingpb" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" 
"github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -91,7 +93,6 @@ func TestBalancer(t *testing.T) { b, err := balancer.RecoverBalancer(ctx, "pchannel_count_fair") assert.NoError(t, err) assert.NotNil(t, b) - defer b.Close() b.MarkAsUnavailable(ctx, []types.PChannelInfo{{ Name: "test-channel-1", @@ -113,4 +114,18 @@ func TestBalancer(t *testing.T) { return nil }) assert.ErrorIs(t, err, doneErr) + + // create an infinitely blocking watcher that can be interrupted by closing the balancer. + f := syncutil.NewFuture[error]() + go func() { + err := b.WatchChannelAssignments(context.Background(), func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error { + return nil + }) + f.Set(err) + }() + time.Sleep(20 * time.Millisecond) + assert.False(t, f.Ready()) + + b.Close() + assert.ErrorIs(t, f.Get(), balancer.ErrBalancerClosed) } diff --git a/internal/types/types.go b/internal/types/types.go index 2d52a8e6dc..d694bae3c9 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -213,6 +213,8 @@ type RootCoordComponent interface { GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) RegisterStreamingCoordGRPCService(server *grpc.Server) + + GracefulStop() } // ProxyClient is the client interface for proxy server diff --git a/tests/_helm/values/e2e/distributed b/tests/_helm/values/e2e/distributed index a906781c71..e189b0bbe4 100644 --- a/tests/_helm/values/e2e/distributed +++ b/tests/_helm/values/e2e/distributed @@ -18,7 +18,7 @@ dataCoordinator: dataNode: resources: limits: - cpu: "2" + cpu: "1" requests: cpu: "0.5" memory: 500Mi @@ -247,7 +247,21 @@ queryNode: cpu: "2" requests: cpu: "0.5" - memory: 500Mi + memory: 512Mi +streamingNode: + resources: + limits: + cpu: "2" + requests: + cpu: "0.5" + memory: 512Mi +mixCoordinator: + resources: + limits: + cpu: "1" + requests: + cpu: "0.2" + memory: 256Mi rootCoordinator: resources: limits: diff --git 
a/tests/_helm/values/e2e/distributed-streaming-service b/tests/_helm/values/e2e/distributed-streaming-service index 2175becb18..2b8208009c 100644 --- a/tests/_helm/values/e2e/distributed-streaming-service +++ b/tests/_helm/values/e2e/distributed-streaming-service @@ -20,7 +20,7 @@ dataCoordinator: dataNode: resources: limits: - cpu: "2" + cpu: "1" requests: cpu: "0.5" memory: 500Mi @@ -249,7 +249,21 @@ queryNode: cpu: "2" requests: cpu: "0.5" - memory: 500Mi + memory: 512Mi +streamingNode: + resources: + limits: + cpu: "2" + requests: + cpu: "0.5" + memory: 512Mi +mixCoordinator: + resources: + limits: + cpu: "1" + requests: + cpu: "0.2" + memory: 256Mi rootCoordinator: resources: limits: diff --git a/tests/go_client/testcases/helper/helper.go b/tests/go_client/testcases/helper/helper.go index c2d6fe7b94..ce0171d0b5 100644 --- a/tests/go_client/testcases/helper/helper.go +++ b/tests/go_client/testcases/helper/helper.go @@ -123,6 +123,12 @@ func (chainTask *CollectionPrepare) CreateCollection(ctx context.Context, t *tes common.CheckErr(t, err, true) t.Cleanup(func() { + // The collection will be cleaned up after the test, + // but some ctx is set with a timeout covering only part of the unit test, + // which would cause the drop collection to fail with a timeout. + ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), time.Second*10) + defer cancel() + err := mc.DropCollection(ctx, clientv2.NewDropCollectionOption(schema.CollectionName)) common.CheckErr(t, err, true) })