// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package task

import (
	"context"
	"sync"
	"time"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus/internal/datacoord/session"
	"github.com/milvus-io/milvus/pkg/v2/log"
	"github.com/milvus-io/milvus/pkg/v2/metrics"
	taskcommon "github.com/milvus-io/milvus/pkg/v2/taskcommon"
	"github.com/milvus-io/milvus/pkg/v2/util/conc"
	"github.com/milvus-io/milvus/pkg/v2/util/lock"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

// NullNodeID marks that no worker node has been assigned.
const NullNodeID = -1

// GlobalScheduler dispatches tasks to worker nodes and tracks their lifecycle.
type GlobalScheduler interface {
	Enqueue(task Task)
	AbortAndRemoveTask(taskID int64)
	Start()
	Stop()
}

var _ GlobalScheduler = (*globalTaskScheduler)(nil)

type globalTaskScheduler struct {
	ctx    context.Context
	cancel context.CancelFunc
	wg     sync.WaitGroup

	mu           *lock.KeyLock[int64]
	pendingTasks PriorityQueue
	runningTasks *typeutil.ConcurrentMap[int64, Task]

	execPool  *conc.Pool[struct{}]
	checkPool *conc.Pool[struct{}]

	cluster session.Cluster
}

// Enqueue adds a task to the scheduler. Tasks in Init state wait in the
// pending queue; tasks already InProgress or in Retry are tracked as running.
// Tasks already known to the scheduler are ignored.
func (s *globalTaskScheduler) Enqueue(task Task) {
	if s.pendingTasks.Get(task.GetTaskID()) != nil {
		return
	}
	if s.runningTasks.Contain(task.GetTaskID()) {
		return
	}
	switch task.GetTaskState() {
	case taskcommon.Init:
		task.SetTaskTime(taskcommon.TimeQueue, time.Now())
		s.pendingTasks.Push(task)
	case taskcommon.InProgress, taskcommon.Retry:
		task.SetTaskTime(taskcommon.TimeStart, time.Now())
		s.runningTasks.Insert(task.GetTaskID(), task)
	}
	log.Ctx(s.ctx).Info("task enqueued", WrapTaskLog(task)...)
}

// AbortAndRemoveTask drops the task on its worker (if any) and removes it
// from both the running set and the pending queue.
func (s *globalTaskScheduler) AbortAndRemoveTask(taskID int64) {
	s.mu.Lock(taskID)
	defer s.mu.Unlock(taskID)
	if task, ok := s.runningTasks.GetAndRemove(taskID); ok {
		task.DropTaskOnWorker(s.cluster)
	}
	if task := s.pendingTasks.Get(taskID); task != nil {
		task.DropTaskOnWorker(s.cluster)
		s.pendingTasks.Remove(taskID)
	}
}

// Start launches the background loops that schedule pending tasks, check
// running tasks, and report task time metrics.
func (s *globalTaskScheduler) Start() {
	dur := paramtable.Get().DataCoordCfg.TaskScheduleInterval.GetAsDuration(time.Millisecond)
	s.wg.Add(3)
	go func() {
		defer s.wg.Done()
		t := time.NewTicker(dur)
		defer t.Stop()
		for {
			select {
			case <-s.ctx.Done():
				return
			case <-t.C:
				s.schedule()
			}
		}
	}()
	go func() {
		defer s.wg.Done()
		t := time.NewTicker(dur)
		defer t.Stop()
		for {
			select {
			case <-s.ctx.Done():
				return
			case <-t.C:
				s.check()
			}
		}
	}()
	go func() {
		defer s.wg.Done()
		t := time.NewTicker(time.Minute)
		defer t.Stop()
		for {
			select {
			case <-s.ctx.Done():
				return
			case <-t.C:
				s.updateTaskTimeMetrics()
			}
		}
	}()
}

// Stop cancels the scheduler context and waits for all background loops to exit.
func (s *globalTaskScheduler) Stop() {
	s.cancel()
	s.wg.Wait()
}

// pickNode selects the worker with the most available slots and marks it as
// used for this scheduling round. It returns NullNodeID when no worker has
// free slots.
func (s *globalTaskScheduler) pickNode(workerSlots map[int64]*session.WorkerSlots, taskSlot int64) int64 {
	var maxAvailable int64 = -1
	var nodeID int64 = NullNodeID
	for id, ws := range workerSlots {
		if ws.AvailableSlots > maxAvailable && ws.AvailableSlots > 0 {
			maxAvailable = ws.AvailableSlots
			nodeID = id
		}
	}
	if nodeID != NullNodeID {
		workerSlots[nodeID].AvailableSlots = 0
		return nodeID
	}
	return NullNodeID
}

// schedule pops pending tasks, assigns each to a worker node, and submits the
// task creation to the exec pool. Tasks that cannot be placed stay pending.
func (s *globalTaskScheduler) schedule() {
	pendingNum := len(s.pendingTasks.TaskIDs())
	if pendingNum == 0 {
		return
	}
	nodeSlots := s.cluster.QuerySlot()
	log.Ctx(s.ctx).Info("scheduling pending tasks...", zap.Int("num", pendingNum), zap.Any("nodeSlots", nodeSlots))
	futures := make([]*conc.Future[struct{}], 0)
	for {
		task := s.pendingTasks.Pop()
		if task == nil {
			break
		}
		taskSlot := task.GetTaskSlot()
		nodeID := s.pickNode(nodeSlots, taskSlot)
		if nodeID == NullNodeID {
			// No worker can take the task right now; put it back and retry
			// in the next scheduling round.
			s.pendingTasks.Push(task)
			break
		}
		future := s.execPool.Submit(func() (struct{}, error) {
			s.mu.RLock(task.GetTaskID())
			defer s.mu.RUnlock(task.GetTaskID())
			log.Ctx(s.ctx).Info("processing task...", WrapTaskLog(task)...)
			if task.GetTaskState() == taskcommon.Init {
				task.CreateTaskOnWorker(nodeID, s.cluster)
				switch task.GetTaskState() {
				case taskcommon.Init, taskcommon.Retry:
					s.pendingTasks.Push(task)
				case taskcommon.InProgress:
					task.SetTaskTime(taskcommon.TimeStart, time.Now())
					s.runningTasks.Insert(task.GetTaskID(), task)
				}
			}
			return struct{}{}, nil
		})
		futures = append(futures, future)
	}
	_ = conc.AwaitAll(futures...)
}

// check queries the state of every running task on its worker and either moves
// it back to the pending queue (Init/Retry) or finalizes it (Finished/Failed).
func (s *globalTaskScheduler) check() {
	if s.runningTasks.Len() <= 0 {
		return
	}
	log.Ctx(s.ctx).Info("check running tasks", zap.Int("num", s.runningTasks.Len()))
	tasks := s.runningTasks.Values()
	futures := make([]*conc.Future[struct{}], 0, len(tasks))
	for _, task := range tasks {
		future := s.checkPool.Submit(func() (struct{}, error) {
			s.mu.RLock(task.GetTaskID())
			defer s.mu.RUnlock(task.GetTaskID())
			task.QueryTaskOnWorker(s.cluster)
			switch task.GetTaskState() {
			case taskcommon.Init, taskcommon.Retry:
				s.runningTasks.Remove(task.GetTaskID())
				s.pendingTasks.Push(task)
			case taskcommon.Finished, taskcommon.Failed:
				task.SetTaskTime(taskcommon.TimeEnd, time.Now())
				task.DropTaskOnWorker(s.cluster)
				s.runningTasks.Remove(task.GetTaskID())
			}
			return struct{}{}, nil
		})
		futures = append(futures, future)
	}
	_ = conc.AwaitAll(futures...)
}

// updateTaskTimeMetrics reports queueing/running latency and task count
// metrics, and warns about tasks exceeding the slow-task threshold.
func (s *globalTaskScheduler) updateTaskTimeMetrics() {
	var (
		taskNumByTypeAndState = make(map[string]map[string]int64) // taskType => [taskState => taskNum]
		maxTaskQueueingTime   = make(map[string]int64)
		maxTaskRunningTime    = make(map[string]int64)
	)
	for _, taskType := range taskcommon.TypeList {
		taskNumByTypeAndState[taskType] = make(map[string]int64)
	}

	collectPendingMetricsFunc := func(taskID int64) {
		task := s.pendingTasks.Get(taskID)
		if task == nil {
			return
		}
		s.mu.Lock(taskID)
		defer s.mu.Unlock(taskID)
		taskType := task.GetTaskType()
		queueingTime := time.Since(task.GetTaskTime(taskcommon.TimeQueue))
		if queueingTime > paramtable.Get().DataCoordCfg.TaskSlowThreshold.GetAsDuration(time.Second) {
			log.Ctx(s.ctx).Warn("task queueing time is too long", zap.Int64("taskID", taskID),
				zap.Int64("queueing time(ms)", queueingTime.Milliseconds()))
		}
		maxQueueingTime, ok := maxTaskQueueingTime[taskType]
		if !ok || maxQueueingTime < queueingTime.Milliseconds() {
			maxTaskQueueingTime[taskType] = queueingTime.Milliseconds()
		}
		taskNumByTypeAndState[taskType][task.GetTaskState().String()]++
		metrics.TaskVersion.WithLabelValues(taskType).Observe(float64(task.GetTaskVersion()))
	}

	collectRunningMetricsFunc := func(task Task) {
		s.mu.Lock(task.GetTaskID())
		defer s.mu.Unlock(task.GetTaskID())
		taskType := task.GetTaskType()
		runningTime := time.Since(task.GetTaskTime(taskcommon.TimeStart))
		if runningTime > paramtable.Get().DataCoordCfg.TaskSlowThreshold.GetAsDuration(time.Second) {
			log.Ctx(s.ctx).Warn("task running time is too long", zap.Int64("taskID", task.GetTaskID()),
				zap.Int64("running time(ms)", runningTime.Milliseconds()))
		}
		maxRunningTime, ok := maxTaskRunningTime[taskType]
		if !ok || maxRunningTime < runningTime.Milliseconds() {
			maxTaskRunningTime[taskType] = runningTime.Milliseconds()
		}
		taskNumByTypeAndState[taskType][task.GetTaskState().String()]++
	}

	taskIDs := s.pendingTasks.TaskIDs()
	for _, taskID := range taskIDs {
		collectPendingMetricsFunc(taskID)
	}

	allRunningTasks := s.runningTasks.Values()
	for _, task := range allRunningTasks {
		collectRunningMetricsFunc(task)
	}

	for taskType, queueingTime := range maxTaskQueueingTime {
		metrics.DataCoordTaskExecuteLatency.
			WithLabelValues(taskType, metrics.Pending).Observe(float64(queueingTime))
	}
	for taskType, runningTime := range maxTaskRunningTime {
		metrics.DataCoordTaskExecuteLatency.
			WithLabelValues(taskType, metrics.Executing).Observe(float64(runningTime))
	}

	metrics.TaskNumInGlobalScheduler.Reset()
	for taskType, taskNumByState := range taskNumByTypeAndState {
		for taskState, taskNum := range taskNumByState {
			metrics.TaskNumInGlobalScheduler.WithLabelValues(taskType, taskState).Set(float64(taskNum))
		}
	}
}

// NewGlobalTaskScheduler creates a GlobalScheduler backed by the given cluster.
func NewGlobalTaskScheduler(ctx context.Context, cluster session.Cluster) GlobalScheduler {
	execPool := conc.NewPool[struct{}](128)
	checkPool := conc.NewPool[struct{}](128)
	ctx1, cancel := context.WithCancel(ctx)
	return &globalTaskScheduler{
		ctx:          ctx1,
		cancel:       cancel,
		wg:           sync.WaitGroup{},
		mu:           lock.NewKeyLock[int64](),
		pendingTasks: NewPriorityQueuePolicy(),
		runningTasks: typeutil.NewConcurrentMap[int64, Task](),
		execPool:     execPool,
		checkPool:    checkPool,
		cluster:      cluster,
	}
}
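
// Example usage (illustrative sketch only; how a concrete Task and
// session.Cluster are obtained is outside the scope of this file):
//
//	scheduler := NewGlobalTaskScheduler(ctx, cluster)
//	scheduler.Start()
//	defer scheduler.Stop()
//
//	// Init tasks wait in the pending queue until schedule() places them on a
//	// worker; InProgress/Retry tasks are tracked by check() right away.
//	scheduler.Enqueue(task)
//
//	// Abort a task and drop it from its worker if it was already dispatched.
//	scheduler.AbortAndRemoveTask(task.GetTaskID())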