Refine querynode scheduler lifetime (#26915)

This PR refines scheduler lifetime control:
- Move private tri-state into lifetime package
- Make scheduler block incoming "Add" tasks
- Make scheduler Stop wait until all previously accepted tasks are done (see the sketch below)
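
A minimal sketch of the resulting usage pattern, for reviewers. `toyScheduler` and its `run` callback are hypothetical; only the `pkg/util/lifetime` helpers used here (`NewLifetime`, `Add`, `Done`, `SetState`, `Wait`, `IsWorking` and the `Initializing`/`Working`/`Stopped` states) come from this change:

```go
package main

import (
	"fmt"

	"github.com/milvus-io/milvus/pkg/util/lifetime"
)

// toyScheduler is a hypothetical stand-in for the querynode task scheduler,
// reduced to the lifetime guard introduced by this PR.
type toyScheduler struct {
	lt lifetime.Lifetime[lifetime.State]
}

// Add is admitted only while the lifetime reports Working.
func (s *toyScheduler) Add(run func()) error {
	if !s.lt.Add(lifetime.IsWorking) {
		return fmt.Errorf("scheduler closed")
	}
	defer s.lt.Done()
	run()
	return nil
}

// Start moves the scheduler into the Working state.
func (s *toyScheduler) Start() { s.lt.SetState(lifetime.Working) }

// Stop rejects new Add calls first, then waits for the accepted ones to finish.
func (s *toyScheduler) Stop() {
	s.lt.SetState(lifetime.Stopped)
	s.lt.Wait()
}

func main() {
	s := &toyScheduler{lt: lifetime.NewLifetime(lifetime.Initializing)}
	s.Start()
	_ = s.Add(func() { fmt.Println("task executed") })
	s.Stop()
	fmt.Println(s.Add(func() {})) // scheduler closed
}
```

The real scheduler additionally closes its receive channel and joins its worker goroutines in `Stop`; see the scheduler diff below.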

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/27421/head
congqixia 2023-09-28 10:21:26 +08:00 committed by GitHub
parent 8c59dba329
commit 258e1ccd66
7 changed files with 168 additions and 78 deletions

View File

@@ -78,20 +78,6 @@ type ShardDelegator interface {
var _ ShardDelegator = (*shardDelegator)(nil)
const (
initializing int32 = iota
working
stopped
)
func notStopped(state int32) bool {
return state != stopped
}
func isWorking(state int32) bool {
return state == working
}
// shardDelegator maintains the shard distribution and streaming part of the data.
type shardDelegator struct {
// shard information attributes
@@ -104,7 +90,7 @@ type shardDelegator struct {
workerManager cluster.Manager
lifetime lifetime.Lifetime[int32]
lifetime lifetime.Lifetime[lifetime.State]
distribution *distribution
segmentManager segments.SegmentManager
@@ -133,16 +119,16 @@ func (sd *shardDelegator) getLogger(ctx context.Context) *log.MLogger {
// Serviceable returns whether delegator is serviceable now.
func (sd *shardDelegator) Serviceable() bool {
return sd.lifetime.GetState() == working
return lifetime.IsWorking(sd.lifetime.GetState())
}
func (sd *shardDelegator) Stopped() bool {
return sd.lifetime.GetState() == stopped
return !lifetime.NotStopped(sd.lifetime.GetState())
}
// Start sets delegator to working state.
func (sd *shardDelegator) Start() {
sd.lifetime.SetState(working)
sd.lifetime.SetState(lifetime.Working)
}
// Collection returns delegator collection id.
@@ -192,7 +178,7 @@ func (sd *shardDelegator) modifyQueryRequest(req *querypb.QueryRequest, scope qu
// Search performs search operation on shard.
func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
log := sd.getLogger(ctx)
if !sd.lifetime.Add(isWorking) {
if !sd.lifetime.Add(lifetime.IsWorking) {
return nil, errors.New("delegator is not serviceable")
}
defer sd.lifetime.Done()
@@ -320,7 +306,7 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq
// Query performs query operation on shard.
func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error) {
log := sd.getLogger(ctx)
if !sd.lifetime.Add(isWorking) {
if !sd.lifetime.Add(lifetime.IsWorking) {
return nil, errors.New("delegator is not serviceable")
}
defer sd.lifetime.Done()
@@ -385,7 +371,7 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
// GetStatistics returns statistics aggregated by delegator.
func (sd *shardDelegator) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error) {
log := sd.getLogger(ctx)
if !sd.lifetime.Add(isWorking) {
if !sd.lifetime.Add(lifetime.IsWorking) {
return nil, errors.New("delegator is not serviceable")
}
defer sd.lifetime.Done()
@@ -624,7 +610,7 @@ func (sd *shardDelegator) updateTSafe() {
// Close closes the delegator.
func (sd *shardDelegator) Close() {
sd.lifetime.SetState(stopped)
sd.lifetime.SetState(lifetime.Stopped)
sd.lifetime.Close()
// broadcast to all waitTsafe goroutines to quit
sd.tsCond.Broadcast()
@@ -659,7 +645,7 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string
collection: collection,
segmentManager: manager.Segment,
workerManager: workerManager,
lifetime: lifetime.NewLifetime(initializing),
lifetime: lifetime.NewLifetime(lifetime.Initializing),
distribution: NewDistribution(),
deleteBuffer: deletebuffer.NewDoubleCacheDeleteBuffer[*deletebuffer.Item](startTs, maxSegmentDeleteBuffer),
pkOracle: pkoracle.NewPkOracle(),
@@ -670,7 +656,7 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string
}
m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
if sd.lifetime.Add(notStopped) {
if sd.lifetime.Add(lifetime.NotStopped) {
go sd.watchTSafe()
}
log.Info("finish build new shardDelegator")

View File

@@ -1050,14 +1050,14 @@ func TestDelegatorWatchTsafe(t *testing.T) {
sd := &shardDelegator{
tsafeManager: tsafeManager,
vchannelName: channelName,
lifetime: lifetime.NewLifetime(initializing),
lifetime: lifetime.NewLifetime(lifetime.Initializing),
latestTsafe: atomic.NewUint64(0),
}
defer sd.Close()
m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
if sd.lifetime.Add(notStopped) {
if sd.lifetime.Add(lifetime.NotStopped) {
go sd.watchTSafe()
}
@@ -1077,7 +1077,7 @@ func TestDelegatorTSafeListenerClosed(t *testing.T) {
sd := &shardDelegator{
tsafeManager: tsafeManager,
vchannelName: channelName,
lifetime: lifetime.NewLifetime(initializing),
lifetime: lifetime.NewLifetime(lifetime.Initializing),
latestTsafe: atomic.NewUint64(0),
}
defer sd.Close()
@@ -1085,7 +1085,7 @@ func TestDelegatorTSafeListenerClosed(t *testing.T) {
m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
signal := make(chan struct{})
if sd.lifetime.Add(notStopped) {
if sd.lifetime.Add(lifetime.NotStopped) {
go func() {
sd.watchTSafe()
close(signal)

View File

@@ -389,7 +389,7 @@ func (node *QueryNode) Init() error {
// Start mainly start QueryNode's query service.
func (node *QueryNode) Start() error {
node.startOnce.Do(func() {
node.scheduler.Start(node.ctx)
node.scheduler.Start()
paramtable.SetCreateTime(time.Now())
paramtable.SetUpdateTime(time.Now())
@@ -453,6 +453,9 @@ func (node *QueryNode) Stop() error {
node.UpdateStateCode(commonpb.StateCode_Abnormal)
node.lifetime.Wait()
node.cancel()
if node.scheduler != nil {
node.scheduler.Stop()
}
if node.pipelineManager != nil {
node.pipelineManager.Close()
}

View File

@@ -1,8 +1,8 @@
package tasks
import (
"context"
"fmt"
"sync"
"go.uber.org/atomic"
"go.uber.org/zap"
@@ -11,6 +11,8 @@ import (
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/conc"
"github.com/milvus-io/milvus/pkg/util/lifetime"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
@@ -30,6 +32,7 @@ func newScheduler(policy schedulePolicy) Scheduler {
execChan: make(chan Task),
pool: conc.NewPool[any](maxReadConcurrency, conc.WithPreAlloc(true)),
schedulerCounter: schedulerCounter{},
lifetime: lifetime.NewLifetime(lifetime.Initializing),
}
}
@@ -44,12 +47,23 @@ type scheduler struct {
receiveChan chan addTaskReq
execChan chan Task
pool *conc.Pool[any]
// wg is the WaitGroup for the internal worker goroutines
wg sync.WaitGroup
// lifetime controls the scheduler state & makes sure all accepted requests will be processed
lifetime lifetime.Lifetime[lifetime.State]
schedulerCounter
}
// Add a new task into the scheduler;
// an error will be returned if the scheduler reaches some limit.
func (s *scheduler) Add(task Task) (err error) {
if !s.lifetime.Add(lifetime.IsWorking) {
return merr.WrapErrServiceUnavailable("scheduler closed")
}
defer s.lifetime.Done()
errCh := make(chan error, 1)
// TODO: add operation should be fast, is UnsolveLen metric unnecessary?
@@ -68,16 +82,31 @@ func (s *scheduler) Add(task Task) (err error) {
// Start schedules the owned tasks asynchronously and continuously.
// Start should only be called once.
func (s *scheduler) Start(ctx context.Context) {
func (s *scheduler) Start() {
s.wg.Add(2)
// Start a background task executing loop.
go s.exec(ctx)
go s.exec()
// Begin to schedule tasks.
go s.schedule(ctx)
go s.schedule()
s.lifetime.SetState(lifetime.Working)
}
func (s *scheduler) Stop() {
s.lifetime.SetState(lifetime.Stopped)
// wait until all accepted Add calls are done
s.lifetime.Wait()
// close receiveChan to start the stopping process for `schedule`
close(s.receiveChan)
// wait for the workers to quit
s.wg.Wait()
}
// schedule the owned task asynchronously and continuously.
func (s *scheduler) schedule(ctx context.Context) {
func (s *scheduler) schedule() {
defer s.wg.Done()
var task Task
for {
s.setupReadyLenMetric()
@@ -87,10 +116,19 @@ func (s *scheduler) schedule(ctx context.Context) {
task, nq, execChan = s.setupExecListener(task)
select {
case <-ctx.Done():
log.Warn("unexpected quit of schedule loop")
return
case req := <-s.receiveChan:
case req, ok := <-s.receiveChan:
if !ok {
log.Info("receiveChan closed, processing remaining request")
// drain the tasks still maintained by the policy
for task != nil {
execChan <- task
s.updateWaitingTaskCounter(-1, -nq)
task = s.produceExecChan()
}
log.Info("all task put into exeChan, schedule worker exit")
close(s.execChan)
return
}
// Receive add operation request and return the process result.
// And consume recv chan as much as possible.
s.consumeRecvChan(req, maxReceiveChanBatchConsumeNum)
@@ -166,42 +204,42 @@ func (s *scheduler) produceExecChan() Task {
}
// exec executes the ready tasks in the background continuously.
func (s *scheduler) exec(ctx context.Context) {
func (s *scheduler) exec() {
defer s.wg.Done()
log.Info("start execute loop")
for {
select {
case <-ctx.Done():
log.Warn("unexpected quit of exec loop")
t, ok := <-s.execChan
if !ok {
log.Info("scheduler execChan closed, worker exit")
return
case t := <-s.execChan:
// Skip this task if task is canceled.
if err := t.Canceled(); err != nil {
log.Warn("task canceled before executing", zap.Error(err))
t.Done(err)
continue
}
if err := t.PreExecute(); err != nil {
log.Warn("failed to pre-execute task", zap.Error(err))
t.Done(err)
continue
}
s.pool.Submit(func() (any, error) {
// Update concurrency metric and notify task done.
metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
collector.Counter.Inc(metricsinfo.ExecuteQueueType, 1)
err := t.Execute()
// Update all metric after task finished.
metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
collector.Counter.Dec(metricsinfo.ExecuteQueueType, -1)
// Notify task done.
t.Done(err)
return nil, err
})
}
// Skip this task if task is canceled.
if err := t.Canceled(); err != nil {
log.Warn("task canceled before executing", zap.Error(err))
t.Done(err)
continue
}
if err := t.PreExecute(); err != nil {
log.Warn("failed to pre-execute task", zap.Error(err))
t.Done(err)
continue
}
s.pool.Submit(func() (any, error) {
// Update concurrency metric and notify task done.
metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
collector.Counter.Inc(metricsinfo.ExecuteQueueType, 1)
err := t.Execute()
// Update all metric after task finished.
metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
collector.Counter.Dec(metricsinfo.ExecuteQueueType, -1)
// Notify task done.
t.Done(err)
return nil, err
})
}
}
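
With the context parameters removed, shutdown of the scheduler now propagates through channel closes: `Stop` closes `receiveChan`, `schedule` drains whatever the policy still holds into `execChan` and then closes it, and `exec` exits once `execChan` is closed. A generic sketch of that close-and-drain chain, with invented stage names rather than the scheduler's real types:

```go
package main

import (
	"fmt"
	"sync"
)

// A two-stage pipeline shut down the same way as the scheduler: closing the
// inbound channel lets the first stage drain its backlog into the second
// channel before closing it, so no accepted work is dropped.
func main() {
	recvCh := make(chan int) // plays the role of receiveChan
	execCh := make(chan int) // plays the role of execChan
	var wg sync.WaitGroup
	wg.Add(2)

	// scheduling stage: forward work until recvCh is closed, then close execCh.
	go func() {
		defer wg.Done()
		for task := range recvCh { // exits once recvCh is closed and drained
			execCh <- task
		}
		close(execCh) // tell the execution stage no more work will arrive
	}()

	// execution stage: run work until execCh is closed.
	go func() {
		defer wg.Done()
		for task := range execCh {
			fmt.Println("executed task", task)
		}
	}()

	for i := 0; i < 3; i++ {
		recvCh <- i
	}
	close(recvCh) // begin shutdown once all producers are done
	wg.Wait()     // both stages exited; every accepted task was executed
}
```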

View File

@@ -21,12 +21,31 @@ func TestScheduler(t *testing.T) {
t.Run("fifo", func(t *testing.T) {
testScheduler(t, newFIFOPolicy())
})
t.Run("scheduler_not_working", func(t *testing.T) {
scheduler := newScheduler(newFIFOPolicy())
task := newMockTask(mockTaskConfig{
nq: 1,
executeCost: 10 * time.Millisecond,
execution: func(ctx context.Context) error {
return nil
},
})
err := scheduler.Add(task)
assert.Error(t, err)
scheduler.Stop()
err = scheduler.Add(task)
assert.Error(t, err)
})
}
func testScheduler(t *testing.T, policy schedulePolicy) {
// start a new scheduler
scheduler := newScheduler(policy)
go scheduler.Start(context.Background())
scheduler.Start()
var cnt atomic.Int32
n := 100

View File

@@ -1,9 +1,5 @@
package tasks
import (
"context"
)
const (
schedulePolicyNameFIFO = "fifo"
schedulePolicyNameUserTaskPolling = "user-task-polling"
@@ -44,9 +40,12 @@ type Scheduler interface {
Add(task Task) error
// Start schedules the owned tasks asynchronously and continuously.
// 1. Stop processing until ctx.Cancel() is called.
// 2. Only call once.
Start(ctx context.Context)
// Shall be called only once
Start()
// Stop makes the scheduler deny all incoming tasks
// and cleans up all related resources
Stop()
// GetWaitingTaskTotalNQ
GetWaitingTaskTotalNQ() int64

View File

@@ -0,0 +1,45 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package lifetime
// Signal is an alias for chan struct{}.
type Signal chan struct{}
// BiState provides pre-defined simple binary state - normal or closed.
type BiState int32
const (
Normal BiState = 0
Closed BiState = 1
)
// State provides a pre-defined three-stage state.
type State int32
const (
Initializing State = iota
Working
Stopped
)
func NotStopped(state State) bool {
return state != Stopped
}
func IsWorking(state State) bool {
return state == Working
}
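
A short, assumed usage sequence for the new helpers; it exercises only the identifiers defined in this file plus the `Lifetime` methods already used elsewhere in this PR (`NewLifetime`, `GetState`, `SetState`):

```go
package main

import (
	"fmt"

	"github.com/milvus-io/milvus/pkg/util/lifetime"
)

func main() {
	lt := lifetime.NewLifetime(lifetime.Initializing)

	fmt.Println(lifetime.IsWorking(lt.GetState()))  // false: still Initializing
	fmt.Println(lifetime.NotStopped(lt.GetState())) // true: not Stopped yet

	lt.SetState(lifetime.Working)
	fmt.Println(lifetime.IsWorking(lt.GetState())) // true

	lt.SetState(lifetime.Stopped)
	fmt.Println(lifetime.NotStopped(lt.GetState())) // false
}
```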