milvus/internal/datacoord/task/global_scheduler.go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package task

import (
	"context"
	"sync"
	"time"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus/internal/datacoord/session"
	"github.com/milvus-io/milvus/pkg/v2/log"
	"github.com/milvus-io/milvus/pkg/v2/metrics"
	"github.com/milvus-io/milvus/pkg/v2/taskcommon"
	"github.com/milvus-io/milvus/pkg/v2/util/conc"
	"github.com/milvus-io/milvus/pkg/v2/util/lock"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
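
// NullNodeID indicates that no suitable worker node is assigned or available.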
const NullNodeID = -1
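
// GlobalScheduler accepts datacoord tasks and drives their lifecycle on
// cluster workers: queueing, dispatching, progress polling, and cleanup.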
type GlobalScheduler interface {
	Enqueue(task Task)
	AbortAndRemoveTask(taskID int64)
	Start()
	Stop()
}

var _ GlobalScheduler = (*globalTaskScheduler)(nil)
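
// globalTaskScheduler keeps pending tasks in a priority queue and running
// tasks in a concurrent map keyed by task ID; a per-task key lock serializes
// state transitions for each task.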
type globalTaskScheduler struct {
	ctx          context.Context
	cancel       context.CancelFunc
	wg           sync.WaitGroup
	mu           *lock.KeyLock[int64]
	pendingTasks PriorityQueue
	runningTasks *typeutil.ConcurrentMap[int64, Task]
	execPool     *conc.Pool[struct{}]
	checkPool    *conc.Pool[struct{}]
	cluster      session.Cluster
}
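
// Enqueue registers a task. Init tasks join the pending queue; tasks already
// InProgress or in Retry state go straight to the running set. Tasks already
// known to the scheduler are ignored.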
func (s *globalTaskScheduler) Enqueue(task Task) {
	if s.pendingTasks.Get(task.GetTaskID()) != nil {
		return
	}
	if s.runningTasks.Contain(task.GetTaskID()) {
		return
	}
	switch task.GetTaskState() {
	case taskcommon.Init:
		task.SetTaskTime(taskcommon.TimeQueue, time.Now())
		s.pendingTasks.Push(task)
	case taskcommon.InProgress, taskcommon.Retry:
		task.SetTaskTime(taskcommon.TimeStart, time.Now())
		s.runningTasks.Insert(task.GetTaskID(), task)
	}
	log.Ctx(s.ctx).Info("task enqueued", WrapTaskLog(task)...)
}
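
// AbortAndRemoveTask drops the task on its worker (if any) and removes it
// from both the running set and the pending queue.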
func (s *globalTaskScheduler) AbortAndRemoveTask(taskID int64) {
	s.mu.Lock(taskID)
	defer s.mu.Unlock(taskID)
	if task, ok := s.runningTasks.GetAndRemove(taskID); ok {
		task.DropTaskOnWorker(s.cluster)
	}
	if task := s.pendingTasks.Get(taskID); task != nil {
		task.DropTaskOnWorker(s.cluster)
		s.pendingTasks.Remove(taskID)
	}
}
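
// Start launches three background loops: one dispatches pending tasks, one
// polls running tasks, and one refreshes task-time metrics every minute.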
func (s *globalTaskScheduler) Start() {
	dur := paramtable.Get().DataCoordCfg.TaskScheduleInterval.GetAsDuration(time.Millisecond)
	s.wg.Add(3)
	go func() {
		defer s.wg.Done()
		t := time.NewTicker(dur)
		defer t.Stop()
		for {
			select {
			case <-s.ctx.Done():
				return
			case <-t.C:
				s.schedule()
			}
		}
	}()
	go func() {
		defer s.wg.Done()
		t := time.NewTicker(dur)
		defer t.Stop()
		for {
			select {
			case <-s.ctx.Done():
				return
			case <-t.C:
				s.check()
			}
		}
	}()
	go func() {
		defer s.wg.Done()
		t := time.NewTicker(time.Minute)
		defer t.Stop()
		for {
			select {
			case <-s.ctx.Done():
				return
			case <-t.C:
				s.updateTaskTimeMetrics()
			}
		}
	}()
}
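
// Stop cancels the scheduler context and waits for the background loops to exit.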
func (s *globalTaskScheduler) Stop() {
	s.cancel()
	s.wg.Wait()
}
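
// pickNode selects the worker with the most available slots for a task that
// needs taskSlot slots, and reserves that capacity in workerSlots.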
func (s *globalTaskScheduler) pickNode(workerSlots map[int64]*session.WorkerSlots, taskSlot int64) int64 {
	var maxAvailable int64 = -1
	var nodeID int64 = NullNodeID
	// Prefer the node with the most free slots among those that can fit the
	// whole task, and reserve exactly the slots it needs.
	for id, ws := range workerSlots {
		if ws.AvailableSlots >= taskSlot && ws.AvailableSlots > maxAvailable {
			maxAvailable = ws.AvailableSlots
			nodeID = id
		}
	}
	if nodeID != NullNodeID {
		workerSlots[nodeID].AvailableSlots -= taskSlot
		return nodeID
	}
	// No node can fit the whole task: fall back to the node with the most
	// free slots and exhaust it, so oversized tasks are not starved.
	for id, ws := range workerSlots {
		if ws.AvailableSlots > 0 && ws.AvailableSlots > maxAvailable {
			maxAvailable = ws.AvailableSlots
			nodeID = id
		}
	}
	if nodeID != NullNodeID {
		workerSlots[nodeID].AvailableSlots = 0
		return nodeID
	}
	return NullNodeID
}
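
// schedule pops pending tasks, assigns each to a worker, and creates them on
// the workers concurrently. It stops early once no worker has capacity left.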
func (s *globalTaskScheduler) schedule() {
	pendingNum := len(s.pendingTasks.TaskIDs())
	if pendingNum == 0 {
		return
	}
	nodeSlots := s.cluster.QuerySlot()
	log.Ctx(s.ctx).Info("scheduling pending tasks...", zap.Int("num", pendingNum), zap.Any("nodeSlots", nodeSlots))
	futures := make([]*conc.Future[struct{}], 0)
	for {
		task := s.pendingTasks.Pop()
		if task == nil {
			break
		}
		taskSlot := task.GetTaskSlot()
		nodeID := s.pickNode(nodeSlots, taskSlot)
		if nodeID == NullNodeID {
			s.pendingTasks.Push(task)
			break
		}
		future := s.execPool.Submit(func() (struct{}, error) {
			s.mu.RLock(task.GetTaskID())
			defer s.mu.RUnlock(task.GetTaskID())
			log.Ctx(s.ctx).Info("processing task...", WrapTaskLog(task)...)
			if task.GetTaskState() == taskcommon.Init {
				task.CreateTaskOnWorker(nodeID, s.cluster)
				switch task.GetTaskState() {
				case taskcommon.Init, taskcommon.Retry:
					s.pendingTasks.Push(task)
				case taskcommon.InProgress:
					task.SetTaskTime(taskcommon.TimeStart, time.Now())
					s.runningTasks.Insert(task.GetTaskID(), task)
				}
			}
			return struct{}{}, nil
		})
		futures = append(futures, future)
	}
	_ = conc.AwaitAll(futures...)
}
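
// check polls every running task on its worker. Tasks that report Init or
// Retry are moved back to the pending queue; Finished or Failed tasks are
// dropped from the worker and removed from the running set.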
func (s *globalTaskScheduler) check() {
	if s.runningTasks.Len() <= 0 {
		return
	}
	log.Ctx(s.ctx).Info("check running tasks", zap.Int("num", s.runningTasks.Len()))
	tasks := s.runningTasks.Values()
	futures := make([]*conc.Future[struct{}], 0, len(tasks))
	for _, task := range tasks {
		future := s.checkPool.Submit(func() (struct{}, error) {
			s.mu.RLock(task.GetTaskID())
			defer s.mu.RUnlock(task.GetTaskID())
			task.QueryTaskOnWorker(s.cluster)
			switch task.GetTaskState() {
			case taskcommon.Init, taskcommon.Retry:
				s.runningTasks.Remove(task.GetTaskID())
				s.pendingTasks.Push(task)
			case taskcommon.Finished, taskcommon.Failed:
				task.SetTaskTime(taskcommon.TimeEnd, time.Now())
				task.DropTaskOnWorker(s.cluster)
				s.runningTasks.Remove(task.GetTaskID())
			}
			return struct{}{}, nil
		})
		futures = append(futures, future)
	}
	_ = conc.AwaitAll(futures...)
}
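
// updateTaskTimeMetrics reports per-type task counts and the longest queueing
// and running times, and warns about tasks exceeding the slow-task threshold.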
func (s *globalTaskScheduler) updateTaskTimeMetrics() {
	var (
		taskNumByTypeAndState = make(map[string]map[string]int64) // taskType => [taskState => taskNum]
		maxTaskQueueingTime   = make(map[string]int64)
		maxTaskRunningTime    = make(map[string]int64)
	)
	for _, taskType := range taskcommon.TypeList {
		taskNumByTypeAndState[taskType] = make(map[string]int64)
	}
	collectPendingMetricsFunc := func(taskID int64) {
		task := s.pendingTasks.Get(taskID)
		if task == nil {
			return
		}
		s.mu.Lock(taskID)
		defer s.mu.Unlock(taskID)
		taskType := task.GetTaskType()
		queueingTime := time.Since(task.GetTaskTime(taskcommon.TimeQueue))
		if queueingTime > paramtable.Get().DataCoordCfg.TaskSlowThreshold.GetAsDuration(time.Second) {
			log.Ctx(s.ctx).Warn("task queueing time is too long", zap.Int64("taskID", taskID),
				zap.Int64("queueing time(ms)", queueingTime.Milliseconds()))
		}
		maxQueueingTime, ok := maxTaskQueueingTime[taskType]
		if !ok || maxQueueingTime < queueingTime.Milliseconds() {
			maxTaskQueueingTime[taskType] = queueingTime.Milliseconds()
		}
		taskNumByTypeAndState[taskType][task.GetTaskState().String()]++
		metrics.TaskVersion.WithLabelValues(taskType).Observe(float64(task.GetTaskVersion()))
	}
	collectRunningMetricsFunc := func(task Task) {
		s.mu.Lock(task.GetTaskID())
		defer s.mu.Unlock(task.GetTaskID())
		taskType := task.GetTaskType()
		runningTime := time.Since(task.GetTaskTime(taskcommon.TimeStart))
		if runningTime > paramtable.Get().DataCoordCfg.TaskSlowThreshold.GetAsDuration(time.Second) {
			log.Ctx(s.ctx).Warn("task running time is too long", zap.Int64("taskID", task.GetTaskID()),
				zap.Int64("running time(ms)", runningTime.Milliseconds()))
		}
		maxRunningTime, ok := maxTaskRunningTime[taskType]
		if !ok || maxRunningTime < runningTime.Milliseconds() {
			maxTaskRunningTime[taskType] = runningTime.Milliseconds()
		}
		taskNumByTypeAndState[taskType][task.GetTaskState().String()]++
	}
	taskIDs := s.pendingTasks.TaskIDs()
	for _, taskID := range taskIDs {
		collectPendingMetricsFunc(taskID)
	}
	allRunningTasks := s.runningTasks.Values()
	for _, task := range allRunningTasks {
		collectRunningMetricsFunc(task)
	}
	for taskType, queueingTime := range maxTaskQueueingTime {
		metrics.DataCoordTaskExecuteLatency.
			WithLabelValues(taskType, metrics.Pending).Observe(float64(queueingTime))
	}
	for taskType, runningTime := range maxTaskRunningTime {
		metrics.DataCoordTaskExecuteLatency.
			WithLabelValues(taskType, metrics.Executing).Observe(float64(runningTime))
	}
	metrics.TaskNumInGlobalScheduler.Reset()
	for taskType, taskNumByState := range taskNumByTypeAndState {
		for taskState, taskNum := range taskNumByState {
			metrics.TaskNumInGlobalScheduler.WithLabelValues(taskType, taskState).Set(float64(taskNum))
		}
	}
}
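
// NewGlobalTaskScheduler builds a scheduler with fixed-size (128) pools for
// task creation and status checks. A minimal usage sketch (the Task and the
// session.Cluster are constructed elsewhere):
//
//	scheduler := NewGlobalTaskScheduler(ctx, cluster)
//	scheduler.Start()
//	defer scheduler.Stop()
//	scheduler.Enqueue(task)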
func NewGlobalTaskScheduler(ctx context.Context, cluster session.Cluster) GlobalScheduler {
	execPool := conc.NewPool[struct{}](128)
	checkPool := conc.NewPool[struct{}](128)
	ctx1, cancel := context.WithCancel(ctx)
	return &globalTaskScheduler{
		ctx:          ctx1,
		cancel:       cancel,
		wg:           sync.WaitGroup{},
		mu:           lock.NewKeyLock[int64](),
		pendingTasks: NewPriorityQueuePolicy(),
		runningTasks: typeutil.NewConcurrentMap[int64, Task](),
		execPool:     execPool,
		checkPool:    checkPool,
		cluster:      cluster,
	}
}