Implement detailed lifetime control for querynode (#21851)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/21863/head
congqixia 2023-01-29 17:45:49 +08:00 committed by GitHub
parent 43d633cfed
commit 66027790a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 209 additions and 75 deletions

View File

@ -47,13 +47,11 @@ import (
)
// isHealthy checks if QueryNode is healthy
func (node *QueryNode) isHealthy() bool {
code := node.stateCode.Load().(commonpb.StateCode)
func (node *QueryNode) isHealthy(code commonpb.StateCode) bool {
return code == commonpb.StateCode_Healthy
}
func (node *QueryNode) isHealthyOrStopping() bool {
code := node.stateCode.Load().(commonpb.StateCode)
func (node *QueryNode) isHealthyOrStopping(code commonpb.StateCode) bool {
return code == commonpb.StateCode_Healthy || code == commonpb.StateCode_Stopping
}
@ -64,15 +62,7 @@ func (node *QueryNode) GetComponentStates(ctx context.Context) (*milvuspb.Compon
ErrorCode: commonpb.ErrorCode_Success,
},
}
code, ok := node.stateCode.Load().(commonpb.StateCode)
if !ok {
errMsg := "unexpected error in type assertion"
stats.Status = &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: errMsg,
}
return stats, nil
}
code := node.lifetime.GetState()
nodeID := common.NotRegisteredID
if node.session != nil && node.session.Registered() {
nodeID = node.GetSession().ServerID
@ -83,7 +73,7 @@ func (node *QueryNode) GetComponentStates(ctx context.Context) (*milvuspb.Compon
StateCode: code,
}
stats.State = info
log.Debug("Get QueryNode component state done", zap.Any("stateCode", info.StateCode))
log.Debug("Get QueryNode component state done", zap.String("stateCode", info.StateCode.String()))
return stats, nil
}
@ -171,12 +161,11 @@ func (node *QueryNode) getStatisticsWithDmlChannel(ctx context.Context, req *que
},
}
if !node.isHealthyOrStopping() {
if !node.lifetime.Add(node.isHealthyOrStopping) {
failRet.Status.Reason = msgQueryNodeIsUnhealthy(node.GetSession().ServerID)
return failRet, nil
}
node.wg.Add(1)
defer node.wg.Done()
defer node.lifetime.Done()
traceID := trace.SpanFromContext(ctx).SpanContext().TraceID()
log.Ctx(ctx).Debug("received GetStatisticRequest",
@ -301,7 +290,7 @@ func (node *QueryNode) getStatisticsWithDmlChannel(ctx context.Context, req *que
func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
nodeID := node.GetSession().ServerID
// check node healthy
if !node.isHealthy() {
if !node.lifetime.Add(node.isHealthy) {
err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -309,8 +298,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
}
return status, nil
}
node.wg.Add(1)
defer node.wg.Done()
defer node.lifetime.Done()
// check target matches
if in.GetBase().GetTargetID() != nodeID {
@ -392,7 +380,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) {
// check node healthy
nodeID := node.GetSession().ServerID
if !node.isHealthyOrStopping() {
if !node.lifetime.Add(node.isHealthyOrStopping) {
err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -400,8 +388,7 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
}
return status, nil
}
node.wg.Add(1)
defer node.wg.Done()
defer node.lifetime.Done()
// check target matches
if req.GetBase().GetTargetID() != nodeID {
@ -452,7 +439,7 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
nodeID := node.GetSession().ServerID
// check node healthy
if !node.isHealthy() {
if !node.lifetime.Add(node.isHealthy) {
err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -460,8 +447,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
}
return status, nil
}
node.wg.Add(1)
defer node.wg.Done()
defer node.lifetime.Done()
// check target matches
if in.GetBase().GetTargetID() != nodeID {
@ -538,7 +524,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
// ReleaseCollection clears all data related to this collection on the querynode
func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
if !node.isHealthyOrStopping() {
if !node.lifetime.Add(node.isHealthyOrStopping) {
err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID)
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -546,8 +532,7 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.Releas
}
return status, nil
}
node.wg.Add(1)
defer node.wg.Done()
defer node.lifetime.Done()
dct := &releaseCollectionTask{
baseTask: baseTask{
@ -586,7 +571,7 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.Releas
// ReleasePartitions clears all data related to this partition on the querynode
func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
if !node.isHealthyOrStopping() {
if !node.lifetime.Add(node.isHealthyOrStopping) {
err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID)
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -594,8 +579,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.Releas
}
return status, nil
}
node.wg.Add(1)
defer node.wg.Done()
defer node.lifetime.Done()
dct := &releasePartitionsTask{
baseTask: baseTask{
@ -635,7 +619,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.Releas
// ReleaseSegments remove the specified segments from query node according segmentIDs, partitionIDs, and collectionID
func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
nodeID := node.GetSession().ServerID
if !node.isHealthyOrStopping() {
if !node.lifetime.Add(node.isHealthyOrStopping) {
err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -643,8 +627,7 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseS
}
return status, nil
}
node.wg.Add(1)
defer node.wg.Done()
defer node.lifetime.Done()
// check target matches
if in.GetBase().GetTargetID() != nodeID {
@ -684,7 +667,7 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseS
// GetSegmentInfo returns segment information of the collection on the queryNode, and the information includes memSize, numRow, indexName, indexID ...
func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
if !node.isHealthyOrStopping() {
if !node.lifetime.Add(node.isHealthyOrStopping) {
err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID)
res := &querypb.GetSegmentInfoResponse{
Status: &commonpb.Status{
@ -694,8 +677,7 @@ func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmen
}
return res, nil
}
node.wg.Add(1)
defer node.wg.Done()
defer node.lifetime.Done()
var segmentInfos []*querypb.SegmentInfo
@ -828,12 +810,11 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.FailLabel).Inc()
}
}()
if !node.isHealthyOrStopping() {
if !node.lifetime.Add(node.isHealthyOrStopping) {
failRet.Status.Reason = msgQueryNodeIsUnhealthy(nodeID)
return failRet, nil
}
node.wg.Add(1)
defer node.wg.Done()
defer node.lifetime.Done()
if node.queryShardService == nil {
failRet.Status.Reason = "queryShardService is nil"
@ -979,12 +960,11 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.FailLabel).Inc()
}
}()
if !node.isHealthyOrStopping() {
if !node.lifetime.Add(node.isHealthyOrStopping) {
failRet.Status.Reason = msgQueryNodeIsUnhealthy(nodeID)
return failRet, nil
}
node.wg.Add(1)
defer node.wg.Done()
defer node.lifetime.Done()
log.Ctx(ctx).Debug("queryWithDmlChannel receives query request",
zap.Bool("fromShardLeader", req.GetFromShardLeader()),
@ -1197,14 +1177,13 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
// SyncReplicaSegments syncs replica node & segments states
func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error) {
if !node.isHealthy() {
if !node.lifetime.Add(node.isHealthy) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: msgQueryNodeIsUnhealthy(node.GetSession().ServerID),
}, nil
}
node.wg.Add(1)
defer node.wg.Done()
defer node.lifetime.Done()
log.Info("Received SyncReplicaSegments request", zap.String("vchannelName", req.GetVchannelName()))
@ -1225,7 +1204,7 @@ func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.Syn
// ShowConfigurations returns the configurations of queryNode matching req.Pattern
func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
nodeID := node.GetSession().ServerID
if !node.isHealthyOrStopping() {
if !node.lifetime.Add(node.isHealthyOrStopping) {
log.Warn("QueryNode.ShowConfigurations failed",
zap.Int64("nodeId", nodeID),
zap.String("req", req.Pattern),
@ -1239,8 +1218,7 @@ func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.S
Configuations: nil,
}, nil
}
node.wg.Add(1)
defer node.wg.Done()
defer node.lifetime.Done()
configList := make([]*commonpb.KeyValuePair, 0)
for key, value := range Params.GetComponentConfigurations("querynode", req.Pattern) {
@ -1263,7 +1241,7 @@ func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.S
// GetMetrics returns system information of the query node, such as total memory, memory usage, cpu usage ...
func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
nodeID := node.GetSession().ServerID
if !node.isHealthyOrStopping() {
if !node.lifetime.Add(node.isHealthyOrStopping) {
log.Ctx(ctx).Warn("QueryNode.GetMetrics failed",
zap.Int64("nodeId", nodeID),
zap.String("req", req.Request),
@ -1277,8 +1255,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR
Response: "",
}, nil
}
node.wg.Add(1)
defer node.wg.Done()
defer node.lifetime.Done()
metricType, err := metricsinfo.ParseMetricType(req.Request)
if err != nil {
@ -1333,7 +1310,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
zap.Int64("msg-id", req.GetBase().GetMsgID()),
zap.Int64("node-id", nodeID),
)
if !node.isHealthyOrStopping() {
if !node.lifetime.Add(node.isHealthyOrStopping) {
log.Warn("QueryNode.GetMetrics failed",
zap.Error(errQueryNodeIsUnhealthy(nodeID)))
@ -1344,8 +1321,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
},
}, nil
}
node.wg.Add(1)
defer node.wg.Done()
defer node.lifetime.Done()
// check target matches
if req.GetBase().GetTargetID() != nodeID {
@ -1426,7 +1402,7 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi
log := log.Ctx(ctx).With(zap.Int64("collectionID", req.GetCollectionID()), zap.String("channel", req.GetChannel()))
nodeID := node.GetSession().ServerID
// check node healthy
if !node.isHealthyOrStopping() {
if !node.lifetime.Add(node.isHealthyOrStopping) {
err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -1434,8 +1410,7 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi
}
return status, nil
}
node.wg.Add(1)
defer node.wg.Done()
defer node.lifetime.Done()
// check target matches
if req.GetBase().GetTargetID() != nodeID {

View File

@ -22,7 +22,6 @@ import (
"math/rand"
"runtime"
"sync"
"sync/atomic"
"testing"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
@ -49,21 +48,18 @@ func TestImpl_GetComponentStates(t *testing.T) {
assert.NoError(t, err)
node.session.UpdateRegistered(true)
node.UpdateStateCode(commonpb.StateCode_Healthy)
rsp, err := node.GetComponentStates(ctx)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
assert.Equal(t, commonpb.StateCode_Healthy, rsp.GetState().GetStateCode())
node.UpdateStateCode(commonpb.StateCode_Abnormal)
rsp, err = node.GetComponentStates(ctx)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
node.stateCode = atomic.Value{}
node.stateCode.Store("invalid")
rsp, err = node.GetComponentStates(ctx)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode)
assert.Equal(t, commonpb.StateCode_Abnormal, rsp.GetState().GetStateCode())
}
func TestImpl_GetTimeTickChannel(t *testing.T) {
@ -519,8 +515,7 @@ func TestImpl_isHealthy(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
isHealthy := node.isHealthy()
assert.True(t, isHealthy)
assert.True(t, node.isHealthy(node.lifetime.GetState()))
}
func TestImpl_ShowConfigurations(t *testing.T) {

View File

@ -35,7 +35,6 @@ import (
"runtime"
"runtime/debug"
"sync"
"sync/atomic"
"syscall"
"time"
"unsafe"
@ -50,6 +49,7 @@ import (
"github.com/milvus-io/milvus/internal/util/gc"
"github.com/milvus-io/milvus/internal/util/hardware"
"github.com/milvus-io/milvus/internal/util/initcore"
"github.com/milvus-io/milvus/internal/util/lifetime"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/sessionutil"
@ -83,10 +83,9 @@ type QueryNode struct {
queryNodeLoopCtx context.Context
queryNodeLoopCancel context.CancelFunc
wg sync.WaitGroup
lifetime lifetime.Lifetime[commonpb.StateCode]
stateCode atomic.Value
stopOnce sync.Once
stopOnce sync.Once
//call once
initOnce sync.Once
@ -143,11 +142,11 @@ func NewQueryNode(ctx context.Context, factory dependency.Factory) *QueryNode {
queryNodeLoopCancel: cancel,
factory: factory,
IsStandAlone: os.Getenv(metricsinfo.DeployModeEnvKey) == metricsinfo.StandaloneDeployMode,
lifetime: lifetime.NewLifetime(commonpb.StateCode_Abnormal),
}
queryNode.tSafeReplica = newTSafeReplica()
queryNode.scheduler = newTaskScheduler(ctx1, queryNode.tSafeReplica)
queryNode.UpdateStateCode(commonpb.StateCode_Abnormal)
return queryNode
}
@ -355,7 +354,7 @@ func (node *QueryNode) Stop() error {
}
node.UpdateStateCode(commonpb.StateCode_Abnormal)
node.wg.Wait()
node.lifetime.Wait()
node.queryNodeLoopCancel()
// close services
@ -383,7 +382,7 @@ func (node *QueryNode) Stop() error {
// UpdateStateCode updates the state of query node, which can be initializing, healthy, and abnormal
func (node *QueryNode) UpdateStateCode(code commonpb.StateCode) {
node.stateCode.Store(code)
node.lifetime.SetState(code)
}
// SetEtcdClient assigns parameter client to its member etcdCli

View File

@ -0,0 +1,101 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package lifetime provides common component lifetime control logic.
package lifetime
import (
"sync"
)
// Lifetime is the interface for component lifetime control.
// It couples a component state of type T with a counter of running tasks,
// so that a task is only registered while the state passes a caller-supplied
// health check.
type Lifetime[T any] interface {
	// SetState changes the lifetime state.
	SetState(state T)
	// GetState returns the current state.
	GetState() T
	// Add records that a task is running.
	// It returns false, and records nothing, if isHealthy rejects the current state.
	Add(isHealthy IsHealthy[T]) bool
	// Done records that a task previously accepted by Add has finished.
	Done()
	// Wait blocks until all recorded tasks are done.
	Wait()
}

// IsHealthy is the predicate type used by Lifetime.Add to decide whether
// a task may start under the given state.
type IsHealthy[T any] func(T) bool

// Compile-time check that *lifetime implements Lifetime.
var _ Lifetime[any] = (*lifetime[any])(nil)

// NewLifetime returns a new Lifetime whose state is initState.
// The health check is not stored here; it is supplied on every Add call.
func NewLifetime[T any](initState T) Lifetime[T] {
	return &lifetime[T]{
		state: initState,
	}
}

// lifetime is the default implementation of Lifetime.
// Users shall not care about the internal fields of this struct.
type lifetime[T any] struct {
	// wg keeps a record of each running task.
	wg sync.WaitGroup
	// state stores the component state; it is only accessed while holding mut.
	state T
	// mut orders state changes (write lock) against task registration and
	// state reads (read lock).
	mut sync.RWMutex
}

// SetState changes the lifetime state.
// It takes the write lock, so it cannot interleave with the check-and-register
// step of a concurrent Add.
func (l *lifetime[T]) SetState(state T) {
	l.mut.Lock()
	defer l.mut.Unlock()
	l.state = state
}

// GetState returns the current state.
func (l *lifetime[T]) GetState() T {
	l.mut.RLock()
	defer l.mut.RUnlock()
	return l.state
}

// Add records that a task is running and returns true, or returns false
// without recording anything if isHealthy rejects the current state.
// isHealthy must not be nil.
func (l *lifetime[T]) Add(isHealthy IsHealthy[T]) bool {
	// The read lock is held across both the health check and wg.Add so that
	// a SetState call cannot slip in between them.
	l.mut.RLock()
	defer l.mut.RUnlock()

	// check lifetime healthy
	if !isHealthy(l.state) {
		return false
	}

	l.wg.Add(1)
	return true
}

// Done records that a task is done.
// It must be called exactly once for every Add that returned true.
func (l *lifetime[T]) Done() {
	l.wg.Done()
}

// Wait blocks until every task accepted by Add has called Done.
// Wait itself does not stop new tasks from being added.
func (l *lifetime[T]) Wait() {
	l.wg.Wait()
}

View File

@ -0,0 +1,64 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package lifetime
import (
"testing"
"time"
"github.com/stretchr/testify/suite"
)
// LifetimeSuite is the testify suite holding the Lifetime test cases.
type LifetimeSuite struct {
suite.Suite
}
// TestNormal drives a Lifetime through its basic flow: a task is accepted
// while the state is healthy, rejected after the state changes, and Wait
// only returns once the accepted task has called Done.
func (s *LifetimeSuite) TestNormal() {
	healthyWhenZero := func(state int32) bool { return state == 0 }
	lt := NewLifetime[int32](0)

	// The initial state is the one passed to the constructor.
	s.EqualValues(0, lt.GetState())

	// A task is accepted under the healthy state ...
	accepted := lt.Add(healthyWhenZero)
	s.True(accepted)

	// ... and rejected once the state no longer passes the check.
	lt.SetState(1)
	rejected := !lt.Add(healthyWhenZero)
	s.True(rejected)

	// Wait in the background and report completion by closing the channel.
	waited := make(chan struct{})
	go func() {
		defer close(waited)
		lt.Wait()
	}()

	// One task is still running, so Wait must not have returned yet.
	select {
	case <-waited:
		s.FailNow("signal closed before all tasks done")
	default:
	}

	lt.Done()

	// With the last task done, Wait has to return within the timeout.
	timeout := time.After(time.Second)
	select {
	case <-timeout:
		s.FailNow("signal not closed after all tasks done")
	case <-waited:
	}
}
// TestLifetime is the go test entry point that runs LifetimeSuite.
func TestLifetime(t *testing.T) {
suite.Run(t, new(LifetimeSuite))
}