mirror of https://github.com/milvus-io/milvus.git
Implement detailed lifetime control for querynode (#21851)
Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
branch: pull/21863/head
parent 43d633cfed
commit 66027790a2
@@ -47,13 +47,11 @@ import (
 )
 
 // isHealthy checks if QueryNode is healthy
-func (node *QueryNode) isHealthy() bool {
-    code := node.stateCode.Load().(commonpb.StateCode)
+func (node *QueryNode) isHealthy(code commonpb.StateCode) bool {
     return code == commonpb.StateCode_Healthy
 }
 
-func (node *QueryNode) isHealthyOrStopping() bool {
-    code := node.stateCode.Load().(commonpb.StateCode)
+func (node *QueryNode) isHealthyOrStopping(code commonpb.StateCode) bool {
     return code == commonpb.StateCode_Healthy || code == commonpb.StateCode_Stopping
 }
 
@@ -64,15 +62,7 @@ func (node *QueryNode) GetComponentStates(ctx context.Context) (*milvuspb.Compon
             ErrorCode: commonpb.ErrorCode_Success,
         },
     }
-    code, ok := node.stateCode.Load().(commonpb.StateCode)
-    if !ok {
-        errMsg := "unexpected error in type assertion"
-        stats.Status = &commonpb.Status{
-            ErrorCode: commonpb.ErrorCode_UnexpectedError,
-            Reason:    errMsg,
-        }
-        return stats, nil
-    }
+    code := node.lifetime.GetState()
     nodeID := common.NotRegisteredID
     if node.session != nil && node.session.Registered() {
         nodeID = node.GetSession().ServerID
@@ -83,7 +73,7 @@ func (node *QueryNode) GetComponentStates(ctx context.Context) (*milvuspb.Compon
         StateCode: code,
     }
     stats.State = info
-    log.Debug("Get QueryNode component state done", zap.Any("stateCode", info.StateCode))
+    log.Debug("Get QueryNode component state done", zap.String("stateCode", info.StateCode.String()))
     return stats, nil
 }
 
@@ -171,12 +161,11 @@ func (node *QueryNode) getStatisticsWithDmlChannel(ctx context.Context, req *que
         },
     }
 
-    if !node.isHealthyOrStopping() {
+    if !node.lifetime.Add(node.isHealthyOrStopping) {
         failRet.Status.Reason = msgQueryNodeIsUnhealthy(node.GetSession().ServerID)
         return failRet, nil
     }
-    node.wg.Add(1)
-    defer node.wg.Done()
+    defer node.lifetime.Done()
 
     traceID := trace.SpanFromContext(ctx).SpanContext().TraceID()
     log.Ctx(ctx).Debug("received GetStatisticRequest",
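This hunk shows the pattern the rest of the commit applies to every handler: the old code read node.stateCode and tracked in-flight requests with a separate node.wg, while the new code performs the health check and the in-flight registration in a single node.lifetime.Add call and pairs it with one deferred node.lifetime.Done. Below is a minimal, self-contained sketch of the same wiring using the lifetime package added later in this commit; the server type, its state constants, and the handler are hypothetical illustrations, and the example assumes it is compiled inside the milvus module, since the package path is internal.

package main

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/util/lifetime"
)

// State values for the toy component; the real QueryNode uses commonpb.StateCode.
const (
	healthy int32 = iota
	stopping
	abnormal
)

// server is a hypothetical component that guards its handlers like QueryNode does.
type server struct {
	lt lifetime.Lifetime[int32]
}

func isHealthyOrStopping(s int32) bool { return s == healthy || s == stopping }

// handle mirrors the handler shape above: admit or reject, then pair Add with Done.
func (s *server) handle() error {
	if !s.lt.Add(isHealthyOrStopping) {
		return fmt.Errorf("server is not ready")
	}
	defer s.lt.Done()
	// ... real request work would happen here ...
	return nil
}

// stop mirrors QueryNode.Stop: flip the state first, then drain in-flight requests.
func (s *server) stop() {
	s.lt.SetState(abnormal)
	s.lt.Wait()
}

func main() {
	s := &server{lt: lifetime.NewLifetime(healthy)}
	fmt.Println(s.handle()) // <nil>
	s.stop()
	fmt.Println(s.handle()) // server is not ready
}

Requests admitted before stop flips the state finish normally and are drained by Wait; anything arriving afterwards is rejected up front, which is what QueryNode.Stop relies on later in this diff.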
@@ -301,7 +290,7 @@ func (node *QueryNode) getStatisticsWithDmlChannel(ctx context.Context, req *que
 func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
     nodeID := node.GetSession().ServerID
     // check node healthy
-    if !node.isHealthy() {
+    if !node.lifetime.Add(node.isHealthy) {
         err := fmt.Errorf("query node %d is not ready", nodeID)
         status := &commonpb.Status{
             ErrorCode: commonpb.ErrorCode_UnexpectedError,
@@ -309,8 +298,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
         }
         return status, nil
     }
-    node.wg.Add(1)
-    defer node.wg.Done()
+    defer node.lifetime.Done()
 
     // check target matches
     if in.GetBase().GetTargetID() != nodeID {
@@ -392,7 +380,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
 func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) {
     // check node healthy
     nodeID := node.GetSession().ServerID
-    if !node.isHealthyOrStopping() {
+    if !node.lifetime.Add(node.isHealthyOrStopping) {
         err := fmt.Errorf("query node %d is not ready", nodeID)
         status := &commonpb.Status{
             ErrorCode: commonpb.ErrorCode_UnexpectedError,
@@ -400,8 +388,7 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
         }
         return status, nil
     }
-    node.wg.Add(1)
-    defer node.wg.Done()
+    defer node.lifetime.Done()
 
     // check target matches
     if req.GetBase().GetTargetID() != nodeID {
@@ -452,7 +439,7 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
 func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
     nodeID := node.GetSession().ServerID
     // check node healthy
-    if !node.isHealthy() {
+    if !node.lifetime.Add(node.isHealthy) {
         err := fmt.Errorf("query node %d is not ready", nodeID)
         status := &commonpb.Status{
             ErrorCode: commonpb.ErrorCode_UnexpectedError,
@@ -460,8 +447,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
         }
         return status, nil
     }
-    node.wg.Add(1)
-    defer node.wg.Done()
+    defer node.lifetime.Done()
 
     // check target matches
     if in.GetBase().GetTargetID() != nodeID {
@@ -538,7 +524,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
 
 // ReleaseCollection clears all data related to this collection on the querynode
 func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
-    if !node.isHealthyOrStopping() {
+    if !node.lifetime.Add(node.isHealthyOrStopping) {
         err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID)
         status := &commonpb.Status{
             ErrorCode: commonpb.ErrorCode_UnexpectedError,
@@ -546,8 +532,7 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.Releas
         }
         return status, nil
     }
-    node.wg.Add(1)
-    defer node.wg.Done()
+    defer node.lifetime.Done()
 
     dct := &releaseCollectionTask{
         baseTask: baseTask{
@@ -586,7 +571,7 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.Releas
 
 // ReleasePartitions clears all data related to this partition on the querynode
 func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
-    if !node.isHealthyOrStopping() {
+    if !node.lifetime.Add(node.isHealthyOrStopping) {
         err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID)
         status := &commonpb.Status{
             ErrorCode: commonpb.ErrorCode_UnexpectedError,
@@ -594,8 +579,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.Releas
         }
         return status, nil
     }
-    node.wg.Add(1)
-    defer node.wg.Done()
+    defer node.lifetime.Done()
 
     dct := &releasePartitionsTask{
         baseTask: baseTask{
@@ -635,7 +619,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.Releas
 // ReleaseSegments remove the specified segments from query node according segmentIDs, partitionIDs, and collectionID
 func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
     nodeID := node.GetSession().ServerID
-    if !node.isHealthyOrStopping() {
+    if !node.lifetime.Add(node.isHealthyOrStopping) {
         err := fmt.Errorf("query node %d is not ready", nodeID)
         status := &commonpb.Status{
             ErrorCode: commonpb.ErrorCode_UnexpectedError,
@@ -643,8 +627,7 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseS
         }
         return status, nil
     }
-    node.wg.Add(1)
-    defer node.wg.Done()
+    defer node.lifetime.Done()
 
     // check target matches
     if in.GetBase().GetTargetID() != nodeID {
@@ -684,7 +667,7 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseS
 
 // GetSegmentInfo returns segment information of the collection on the queryNode, and the information includes memSize, numRow, indexName, indexID ...
 func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
-    if !node.isHealthyOrStopping() {
+    if !node.lifetime.Add(node.isHealthyOrStopping) {
         err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID)
         res := &querypb.GetSegmentInfoResponse{
             Status: &commonpb.Status{
@@ -694,8 +677,7 @@ func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmen
         }
         return res, nil
     }
-    node.wg.Add(1)
-    defer node.wg.Done()
+    defer node.lifetime.Done()
 
     var segmentInfos []*querypb.SegmentInfo
 
@@ -828,12 +810,11 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se
             metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.FailLabel).Inc()
         }
     }()
-    if !node.isHealthyOrStopping() {
+    if !node.lifetime.Add(node.isHealthyOrStopping) {
         failRet.Status.Reason = msgQueryNodeIsUnhealthy(nodeID)
         return failRet, nil
     }
-    node.wg.Add(1)
-    defer node.wg.Done()
+    defer node.lifetime.Done()
 
     if node.queryShardService == nil {
         failRet.Status.Reason = "queryShardService is nil"
@@ -979,12 +960,11 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
             metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.FailLabel).Inc()
         }
     }()
-    if !node.isHealthyOrStopping() {
+    if !node.lifetime.Add(node.isHealthyOrStopping) {
         failRet.Status.Reason = msgQueryNodeIsUnhealthy(nodeID)
         return failRet, nil
     }
-    node.wg.Add(1)
-    defer node.wg.Done()
+    defer node.lifetime.Done()
 
     log.Ctx(ctx).Debug("queryWithDmlChannel receives query request",
         zap.Bool("fromShardLeader", req.GetFromShardLeader()),
@@ -1197,14 +1177,13 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
 
 // SyncReplicaSegments syncs replica node & segments states
 func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error) {
-    if !node.isHealthy() {
+    if !node.lifetime.Add(node.isHealthy) {
         return &commonpb.Status{
             ErrorCode: commonpb.ErrorCode_UnexpectedError,
             Reason:    msgQueryNodeIsUnhealthy(node.GetSession().ServerID),
         }, nil
     }
-    node.wg.Add(1)
-    defer node.wg.Done()
+    defer node.lifetime.Done()
 
     log.Info("Received SyncReplicaSegments request", zap.String("vchannelName", req.GetVchannelName()))
 
@@ -1225,7 +1204,7 @@ func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.Syn
 // ShowConfigurations returns the configurations of queryNode matching req.Pattern
 func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
     nodeID := node.GetSession().ServerID
-    if !node.isHealthyOrStopping() {
+    if !node.lifetime.Add(node.isHealthyOrStopping) {
         log.Warn("QueryNode.ShowConfigurations failed",
             zap.Int64("nodeId", nodeID),
             zap.String("req", req.Pattern),
@@ -1239,8 +1218,7 @@ func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.S
             Configuations: nil,
         }, nil
     }
-    node.wg.Add(1)
-    defer node.wg.Done()
+    defer node.lifetime.Done()
 
     configList := make([]*commonpb.KeyValuePair, 0)
     for key, value := range Params.GetComponentConfigurations("querynode", req.Pattern) {
@@ -1263,7 +1241,7 @@ func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.S
 // GetMetrics return system infos of the query node, such as total memory, memory usage, cpu usage ...
 func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
     nodeID := node.GetSession().ServerID
-    if !node.isHealthyOrStopping() {
+    if !node.lifetime.Add(node.isHealthyOrStopping) {
         log.Ctx(ctx).Warn("QueryNode.GetMetrics failed",
             zap.Int64("nodeId", nodeID),
             zap.String("req", req.Request),
@@ -1277,8 +1255,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR
             Response: "",
         }, nil
     }
-    node.wg.Add(1)
-    defer node.wg.Done()
+    defer node.lifetime.Done()
 
     metricType, err := metricsinfo.ParseMetricType(req.Request)
     if err != nil {
@@ -1333,7 +1310,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
         zap.Int64("msg-id", req.GetBase().GetMsgID()),
         zap.Int64("node-id", nodeID),
     )
-    if !node.isHealthyOrStopping() {
+    if !node.lifetime.Add(node.isHealthyOrStopping) {
         log.Warn("QueryNode.GetMetrics failed",
             zap.Error(errQueryNodeIsUnhealthy(nodeID)))
 
@@ -1344,8 +1321,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
             },
         }, nil
     }
-    node.wg.Add(1)
-    defer node.wg.Done()
+    defer node.lifetime.Done()
 
     // check target matches
     if req.GetBase().GetTargetID() != nodeID {
@@ -1426,7 +1402,7 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi
     log := log.Ctx(ctx).With(zap.Int64("collectionID", req.GetCollectionID()), zap.String("channel", req.GetChannel()))
     nodeID := node.GetSession().ServerID
     // check node healthy
-    if !node.isHealthyOrStopping() {
+    if !node.lifetime.Add(node.isHealthyOrStopping) {
         err := fmt.Errorf("query node %d is not ready", nodeID)
         status := &commonpb.Status{
             ErrorCode: commonpb.ErrorCode_UnexpectedError,
@@ -1434,8 +1410,7 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi
         }
         return status, nil
     }
-    node.wg.Add(1)
-    defer node.wg.Done()
+    defer node.lifetime.Done()
 
     // check target matches
     if req.GetBase().GetTargetID() != nodeID {
@@ -22,7 +22,6 @@ import (
     "math/rand"
     "runtime"
     "sync"
-    "sync/atomic"
     "testing"
 
     "github.com/milvus-io/milvus-proto/go-api/commonpb"
@@ -49,21 +48,18 @@ func TestImpl_GetComponentStates(t *testing.T) {
     assert.NoError(t, err)
 
     node.session.UpdateRegistered(true)
     node.UpdateStateCode(commonpb.StateCode_Healthy)
 
     rsp, err := node.GetComponentStates(ctx)
     assert.NoError(t, err)
     assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
+    assert.Equal(t, commonpb.StateCode_Healthy, rsp.GetState().GetStateCode())
 
     node.UpdateStateCode(commonpb.StateCode_Abnormal)
     rsp, err = node.GetComponentStates(ctx)
     assert.NoError(t, err)
     assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
 
-    node.stateCode = atomic.Value{}
-    node.stateCode.Store("invalid")
-    rsp, err = node.GetComponentStates(ctx)
-    assert.NoError(t, err)
-    assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode)
+    assert.Equal(t, commonpb.StateCode_Abnormal, rsp.GetState().GetStateCode())
 }
 
 func TestImpl_GetTimeTickChannel(t *testing.T) {
@@ -519,8 +515,7 @@ func TestImpl_isHealthy(t *testing.T) {
     node, err := genSimpleQueryNode(ctx)
     assert.NoError(t, err)
 
-    isHealthy := node.isHealthy()
-    assert.True(t, isHealthy)
+    assert.True(t, node.isHealthy(node.lifetime.GetState()))
 }
 
 func TestImpl_ShowConfigurations(t *testing.T) {
@@ -35,7 +35,6 @@ import (
     "runtime"
     "runtime/debug"
     "sync"
-    "sync/atomic"
    "syscall"
     "time"
     "unsafe"
@@ -50,6 +49,7 @@ import (
     "github.com/milvus-io/milvus/internal/util/gc"
     "github.com/milvus-io/milvus/internal/util/hardware"
     "github.com/milvus-io/milvus/internal/util/initcore"
+    "github.com/milvus-io/milvus/internal/util/lifetime"
     "github.com/milvus-io/milvus/internal/util/metricsinfo"
     "github.com/milvus-io/milvus/internal/util/paramtable"
     "github.com/milvus-io/milvus/internal/util/sessionutil"
@@ -83,10 +83,9 @@ type QueryNode struct {
     queryNodeLoopCtx    context.Context
     queryNodeLoopCancel context.CancelFunc
 
-    wg sync.WaitGroup
+    lifetime lifetime.Lifetime[commonpb.StateCode]
 
-    stateCode atomic.Value
-    stopOnce  sync.Once
+    stopOnce sync.Once
 
     //call once
     initOnce sync.Once
@@ -143,11 +142,11 @@ func NewQueryNode(ctx context.Context, factory dependency.Factory) *QueryNode {
         queryNodeLoopCancel: cancel,
         factory:             factory,
         IsStandAlone:        os.Getenv(metricsinfo.DeployModeEnvKey) == metricsinfo.StandaloneDeployMode,
+        lifetime:            lifetime.NewLifetime(commonpb.StateCode_Abnormal),
     }
 
     queryNode.tSafeReplica = newTSafeReplica()
     queryNode.scheduler = newTaskScheduler(ctx1, queryNode.tSafeReplica)
-    queryNode.UpdateStateCode(commonpb.StateCode_Abnormal)
 
     return queryNode
 }
@@ -355,7 +354,7 @@ func (node *QueryNode) Stop() error {
     }
 
     node.UpdateStateCode(commonpb.StateCode_Abnormal)
-    node.wg.Wait()
+    node.lifetime.Wait()
     node.queryNodeLoopCancel()
 
     // close services
@@ -383,7 +382,7 @@ func (node *QueryNode) Stop() error {
 
 // UpdateStateCode updata the state of query node, which can be initializing, healthy, and abnormal
 func (node *QueryNode) UpdateStateCode(code commonpb.StateCode) {
-    node.stateCode.Store(code)
+    node.lifetime.SetState(code)
 }
 
 // SetEtcdClient assigns parameter client to its member etcdCli
@@ -0,0 +1,101 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// package lifetime provides common component lifetime control logic.
package lifetime

import (
    "sync"
)

// Lifetime interface for lifetime control.
type Lifetime[T any] interface {
    // SetState is the method to change lifetime state.
    SetState(state T)
    // GetState returns current state.
    GetState() T
    // Add records a task is running, returns false if the lifetime is not healthy.
    Add(isHealthy IsHealthy[T]) bool
    // Done records a task is done.
    Done()
    // Wait waits until all tasks are done.
    Wait()
}

// IsHealthy function type for lifetime healthy check.
type IsHealthy[T any] func(T) bool

var _ Lifetime[any] = (*lifetime[any])(nil)

// NewLifetime returns a new instance of Lifetime with init state and isHealthy logic.
func NewLifetime[T any](initState T) Lifetime[T] {
    return &lifetime[T]{
        state: initState,
    }
}

// lifetime implementation of Lifetime.
// users shall not care about the internal fields of this struct.
type lifetime[T any] struct {
    // wg is used for keeping record each running task.
    wg sync.WaitGroup
    // state is the "atomic" value to store component state.
    state T
    // mut is the rwmutex to control each task and state change event.
    mut sync.RWMutex
    // isHealthy is the method to check whether is legal to add a task.
    isHealthy func(int32) bool
}

// SetState is the method to change lifetime state.
func (l *lifetime[T]) SetState(state T) {
    l.mut.Lock()
    defer l.mut.Unlock()

    l.state = state
}

// GetState returns current state.
func (l *lifetime[T]) GetState() T {
    l.mut.RLock()
    defer l.mut.RUnlock()

    return l.state
}

// Add records a task is running, returns false if the lifetime is not healthy.
func (l *lifetime[T]) Add(isHealthy IsHealthy[T]) bool {
    l.mut.RLock()
    defer l.mut.RUnlock()

    // check lifetime healthy
    if !isHealthy(l.state) {
        return false
    }

    l.wg.Add(1)
    return true
}

// Done records a task is done.
func (l *lifetime[T]) Done() {
    l.wg.Done()
}

// Wait waits until all tasks are done.
func (l *lifetime[T]) Wait() {
    l.wg.Wait()
}
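The struct above guards the state with a sync.RWMutex rather than an atomic: Add and GetState take the read lock while SetState takes the write lock, so an admission check cannot interleave with a state change, and once SetState returns every later Add observes the new state and is rejected. A small sketch of that ordering guarantee follows; it is an assumed example, not part of the commit, and it presumes compilation inside the milvus module since the package path is internal.

package main

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/util/lifetime"
)

func main() {
	lt := lifetime.NewLifetime[int32](0)
	healthy := func(s int32) bool { return s == 0 }

	fmt.Println(lt.Add(healthy)) // true: the state is still 0

	// SetState waits for any Add holding the read lock to finish, so an
	// admitted task has already incremented the wait group; afterwards,
	// every new Add sees the changed state and fails.
	lt.SetState(1)
	fmt.Println(lt.Add(healthy)) // false

	lt.Done() // release the one admitted task
	lt.Wait() // returns immediately: nothing is in flight any more
}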
@@ -0,0 +1,64 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package lifetime

import (
    "testing"
    "time"

    "github.com/stretchr/testify/suite"
)

type LifetimeSuite struct {
    suite.Suite
}

func (s *LifetimeSuite) TestNormal() {
    l := NewLifetime[int32](0)
    isHealthy := func(state int32) bool { return state == 0 }

    state := l.GetState()
    s.EqualValues(0, state)

    s.True(l.Add(isHealthy))

    l.SetState(1)
    s.False(l.Add(isHealthy))

    signal := make(chan struct{})
    go func() {
        l.Wait()
        close(signal)
    }()

    select {
    case <-signal:
        s.FailNow("signal closed before all tasks done")
    default:
    }

    l.Done()
    select {
    case <-signal:
    case <-time.After(time.Second):
        s.FailNow("signal not closed after all tasks done")
    }
}

func TestLifetime(t *testing.T) {
    suite.Run(t, new(LifetimeSuite))
}