From 3d8f0156c739c374428bd132e760e26cb6b3ed28 Mon Sep 17 00:00:00 2001
From: yah01
Date: Thu, 16 Mar 2023 17:43:55 +0800
Subject: [PATCH] Refine scheduler & executor of QueryCoord (#22761)

Signed-off-by: yah01
---
 internal/querycoordv2/checkers/controller.go | 2 +-
 internal/querycoordv2/handlers.go | 2 +-
 internal/querycoordv2/services_test.go | 10 +-
 internal/querycoordv2/task/executor.go | 77 +++----
 internal/querycoordv2/task/merger.go | 9 -
 internal/querycoordv2/task/scheduler.go | 205 ++++++-------------
 internal/querycoordv2/task/task.go | 23 ++-
 internal/querycoordv2/task/task_test.go | 144 +-------------
 internal/querycoordv2/task/utils.go | 4 +-
 internal/util/merr/errors.go | 12 +-
 internal/util/merr/errors_test.go | 4 +
 internal/util/merr/utils.go | 55 ++++-
 12 files changed, 186 insertions(+), 361 deletions(-)

diff --git a/internal/querycoordv2/checkers/controller.go b/internal/querycoordv2/checkers/controller.go
index bd4cff3dde..935067b40b 100644
--- a/internal/querycoordv2/checkers/controller.go
+++ b/internal/querycoordv2/checkers/controller.go
@@ -124,7 +124,7 @@ func (controller *CheckerController) check(ctx context.Context) {
     for _, task := range tasks {
         err := controller.scheduler.Add(task)
         if err != nil {
-            task.Cancel()
+            task.Cancel(err)
             continue
         }
     }
diff --git a/internal/querycoordv2/handlers.go b/internal/querycoordv2/handlers.go
index 56da270ae0..8200c7bf6b 100644
--- a/internal/querycoordv2/handlers.go
+++ b/internal/querycoordv2/handlers.go
@@ -161,7 +161,7 @@ func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRe
         }
         err = s.taskScheduler.Add(task)
         if err != nil {
-            task.Cancel()
+            task.Cancel(err)
             return err
         }
         tasks = append(tasks, task)
diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go
index d2154bd2cc..b76b05d830 100644
--- a/internal/querycoordv2/services_test.go
+++ b/internal/querycoordv2/services_test.go
@@ -22,6 +22,7 @@ import (
     "testing"
     "time"
 
+    "github.com/cockroachdb/errors"
     "github.com/milvus-io/milvus-proto/go-api/commonpb"
     "github.com/milvus-io/milvus-proto/go-api/milvuspb"
     "github.com/milvus-io/milvus/internal/kv"
@@ -1186,7 +1187,7 @@ func (suite *ServiceSuite) TestLoadBalance() {
         growAction, reduceAction := actions[0], actions[1]
         suite.Equal(dstNode, growAction.Node())
         suite.Equal(srcNode, reduceAction.Node())
-        task.Cancel()
+        task.Cancel(nil)
     }).Return(nil)
     resp, err := server.LoadBalance(ctx, req)
     suite.NoError(err)
@@ -1262,7 +1263,7 @@ func (suite *ServiceSuite) TestLoadBalanceWithEmptySegmentList() {
         suite.True(lo.Contains(segmentOnCollection[collection], reduceAction.SegmentID()))
         suite.Equal(dstNode, growAction.Node())
         suite.Equal(srcNode, reduceAction.Node())
-        t.Cancel()
+        t.Cancel(nil)
     }).Return(nil)
     resp, err := server.LoadBalance(ctx, req)
     suite.NoError(err)
@@ -1352,14 +1353,13 @@ func (suite *ServiceSuite) TestLoadBalanceFailed() {
         SealedSegmentIDs: segments,
     }
     suite.taskScheduler.EXPECT().Add(mock.Anything).Run(func(balanceTask task.Task) {
-        balanceTask.SetErr(task.ErrTaskCanceled)
-        balanceTask.Cancel()
+        balanceTask.Cancel(errors.New("mock error"))
     }).Return(nil)
     resp, err := server.LoadBalance(ctx, req)
     suite.NoError(err)
     suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.ErrorCode)
     suite.Contains(resp.Reason, "failed to balance segments")
-    suite.Contains(resp.Reason, task.ErrTaskCanceled.Error())
+    suite.Contains(resp.Reason, "mock error")
 
     suite.meta.ReplicaManager.AddNode(replicas[0].ID, 10)
     req.SourceNodeIDs = []int64{10}
diff --git a/internal/querycoordv2/task/executor.go b/internal/querycoordv2/task/executor.go
index f754b02935..b5204f6bca 100644
--- a/internal/querycoordv2/task/executor.go
+++ b/internal/querycoordv2/task/executor.go
@@ -21,8 +21,6 @@ import (
     "sync"
     "time"
 
-    "github.com/cockroachdb/errors"
-
     "github.com/milvus-io/milvus-proto/go-api/commonpb"
     "github.com/milvus-io/milvus/internal/log"
     "github.com/milvus-io/milvus/internal/proto/querypb"
@@ -147,13 +145,14 @@ func (ex *Executor) processMergeTask(mergeTask *LoadSegmentsTask) {
     task := mergeTask.tasks[0]
     action := task.Actions()[mergeTask.steps[0]]
 
+    var err error
     defer func() {
-        canceled := task.canceled.Load()
-        for i := range mergeTask.tasks {
-            mergeTask.tasks[i].SetErr(task.Err())
-            if canceled {
-                mergeTask.tasks[i].Cancel()
+        if err != nil {
+            for i := range mergeTask.tasks {
+                mergeTask.tasks[i].Cancel(err)
             }
+        }
+        for i := range mergeTask.tasks {
             ex.removeTask(mergeTask.tasks[i], mergeTask.steps[i])
         }
     }()
@@ -178,31 +177,25 @@ func (ex *Executor) processMergeTask(mergeTask *LoadSegmentsTask) {
     channel := mergeTask.req.GetInfos()[0].GetInsertChannel()
     leader, ok := getShardLeader(ex.meta.ReplicaManager, ex.dist, task.CollectionID(), action.Node(), channel)
     if !ok {
-        msg := "no shard leader for the segment to execute loading"
-        task.SetErr(utils.WrapError(msg, ErrTaskStale))
-        log.Warn(msg, zap.String("shard", channel))
+        err = merr.WrapErrChannelNotFound(channel, "shard delegator not found")
+        log.Warn("no shard leader for the segment to execute loading", zap.Error(task.Err()))
         return
     }
 
     log.Info("load segments...")
     status, err := ex.cluster.LoadSegments(task.Context(), leader, mergeTask.req)
     if err != nil {
-        log.Warn("failed to load segment, it may be a false failure", zap.Error(err))
+        log.Warn("failed to load segment", zap.Error(err))
         return
     }
-    err = merr.Error(status)
-    if errors.Is(err, merr.ErrServiceMemoryLimitExceeded) {
-        log.Warn("insufficient memory to load segment", zap.Error(err))
-        task.SetErr(err)
-        task.Cancel()
-        return
-    }
-    if status.ErrorCode != commonpb.ErrorCode_Success {
-        log.Warn("failed to load segment", zap.String("reason", status.GetReason()))
+    if !merr.Ok(status) {
+        err = merr.Error(status)
+        log.Warn("failed to load segment", zap.Error(err))
         return
     }
+
     elapsed := time.Since(startTs)
-    log.Info("load segments done", zap.Int64("taskID", task.ID()), zap.Duration("timeTaken", elapsed))
+    log.Info("load segments done", zap.Duration("elapsed", elapsed))
 }
 
 func (ex *Executor) removeTask(task Task, step int) {
@@ -243,8 +236,7 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error {
     var err error
     defer func() {
         if err != nil {
-            task.SetErr(err)
-            task.Cancel()
+            task.Cancel(err)
             ex.removeTask(task, step)
         }
     }()
@@ -253,7 +245,6 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error {
     schema, err := ex.broker.GetCollectionSchema(ctx, task.CollectionID())
     if err != nil {
         log.Warn("failed to get schema of collection", zap.Error(err))
-        task.SetErr(err)
         return err
     }
     partitions, err := utils.GetPartitions(ex.meta.CollectionManager, ex.broker, task.CollectionID())
@@ -283,8 +274,8 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error {
     leader, ok := getShardLeader(ex.meta.ReplicaManager, ex.dist, task.CollectionID(), action.Node(), segment.GetInsertChannel())
     if !ok {
         msg := "no shard leader for the segment to execute loading"
-        err = utils.WrapError(msg, ErrTaskStale)
-        log.Warn(msg, zap.String("shard", segment.GetInsertChannel()))
+        err = merr.WrapErrChannelNotFound(segment.GetInsertChannel(), "shard delegator not found")
+        log.Warn(msg, zap.Error(err))
         return err
     }
     log = log.With(zap.Int64("shardLeader", leader))
@@ -386,8 +377,7 @@ func (ex *Executor) subDmChannel(task *ChannelTask, step int) error {
     var err error
     defer func() {
         if err != nil {
-            task.SetErr(err)
-            task.Cancel()
+            task.Cancel(err)
         }
     }()
 
@@ -413,12 +403,12 @@ func (ex *Executor) subDmChannel(task *ChannelTask, step int) error {
     if dmChannel == nil {
         msg := "channel does not exist in next target, skip it"
         log.Warn(msg, zap.String("channelName", action.ChannelName()))
-        return errors.New(msg)
+        return merr.WrapErrChannelReduplicate(action.ChannelName())
     }
-    req := packSubDmChannelRequest(task, action, schema, loadMeta, dmChannel)
-    err = fillSubDmChannelRequest(ctx, req, ex.broker)
+    req := packSubChannelRequest(task, action, schema, loadMeta, dmChannel)
+    err = fillSubChannelRequest(ctx, req, ex.broker)
     if err != nil {
-        log.Warn("failed to subscribe DmChannel, failed to fill the request with segments",
+        log.Warn("failed to subscribe channel, failed to fill the request with segments",
             zap.Error(err))
         return err
     }
@@ -430,16 +420,16 @@ func (ex *Executor) subDmChannel(task *ChannelTask, step int) error {
     )
     status, err := ex.cluster.WatchDmChannels(ctx, action.Node(), req)
     if err != nil {
-        log.Warn("failed to subscribe DmChannel, it may be a false failure", zap.Error(err))
+        log.Warn("failed to subscribe channel, it may be a false failure", zap.Error(err))
         return err
     }
 
-    if status.ErrorCode != commonpb.ErrorCode_Success {
-        err = utils.WrapError("failed to subscribe DmChannel", ErrFailedResponse)
-        log.Warn("failed to subscribe DmChannel", zap.String("reason", status.GetReason()))
+    if !merr.Ok(status) {
+        err = merr.Error(status)
+        log.Warn("failed to subscribe channel", zap.Error(err))
         return err
     }
     elapsed := time.Since(startTs)
-    log.Info("subscribe DmChannel done", zap.Int64("taskID", task.ID()), zap.Duration("time taken", elapsed))
+    log.Info("subscribe channel done", zap.Int64("taskID", task.ID()), zap.Duration("time taken", elapsed))
     return nil
 }
@@ -459,8 +449,7 @@ func (ex *Executor) unsubDmChannel(task *ChannelTask, step int) error {
     var err error
     defer func() {
         if err != nil {
-            task.SetErr(err)
-            task.Cancel()
+            task.Cancel(err)
         }
     }()
 
@@ -470,16 +459,16 @@ func (ex *Executor) unsubDmChannel(task *ChannelTask, step int) error {
     log.Info("unsubscribe channel...")
     status, err := ex.cluster.UnsubDmChannel(ctx, action.Node(), req)
     if err != nil {
-        log.Warn("failed to unsubscribe DmChannel, it may be a false failure", zap.Error(err))
+        log.Warn("failed to unsubscribe channel, it may be a false failure", zap.Error(err))
         return err
     }
 
-    if status.ErrorCode != commonpb.ErrorCode_Success {
-        err = utils.WrapError("failed to unsubscribe DmChannel", ErrFailedResponse)
-        log.Warn("failed to unsubscribe DmChannel", zap.String("reason", status.GetReason()))
+    if !merr.Ok(status) {
+        err = merr.Error(status)
+        log.Warn("failed to unsubscribe channel", zap.Error(err))
         return err
     }
     elapsed := time.Since(startTs)
-    log.Info("unsubscribe DmChannel done", zap.Int64("taskID", task.ID()), zap.Duration("time taken", elapsed))
+    log.Info("unsubscribe channel done", zap.Int64("taskID", task.ID()), zap.Duration("time taken", elapsed))
     return nil
 }
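Note on the executor changes above: every RPC response is now funneled through a single Ok/Error style check instead of ad-hoc ErrorCode comparisons. The standalone Go sketch below only mirrors the semantics of merr.Ok and merr.Error as they are defined later in this patch; the status type and helper names here are local stand-ins for illustration, not the real milvus packages.

    package main

    import (
        "errors"
        "fmt"
    )

    // status is a local stand-in for commonpb.Status.
    type status struct {
        ErrorCode int32 // 0 plays the role of commonpb.ErrorCode_Success
        Code      int32 // new fine-grained code
        Reason    string
    }

    // ok mirrors merr.Ok: both the legacy ErrorCode and the new Code must be clear.
    func ok(s *status) bool { return s.ErrorCode == 0 && s.Code == 0 }

    // asError mirrors merr.Error: a failed status becomes a plain error.
    func asError(s *status) error {
        if ok(s) {
            return nil
        }
        return errors.New(s.Reason)
    }

    func main() {
        resp := &status{Code: 5, Reason: "service internal error: node busy"}
        if !ok(resp) {
            err := asError(resp)
            fmt.Println("load segments failed:", err) // one branch handles every failure
        }
    }

With this shape, memory-limit failures, stale shards, and generic RPC errors all travel as wrapped errors rather than separate status branches.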
diff --git a/internal/querycoordv2/task/merger.go b/internal/querycoordv2/task/merger.go
index cd59ef8916..8c14731b2b 100644
--- a/internal/querycoordv2/task/merger.go
+++ b/internal/querycoordv2/task/merger.go
@@ -91,15 +91,6 @@ func (merger *Merger[K, R]) schedule(ctx context.Context) {
     }()
 }
 
-func (merger *Merger[K, R]) isStopped() bool {
-    select {
-    case <-merger.stopCh:
-        return true
-    default:
-        return false
-    }
-}
-
 func (merger *Merger[K, R]) Add(task MergeableTask[K, R]) {
     merger.waitQueue <- task
 }
diff --git a/internal/querycoordv2/task/scheduler.go b/internal/querycoordv2/task/scheduler.go
index b25d617cfc..64ba330755 100644
--- a/internal/querycoordv2/task/scheduler.go
+++ b/internal/querycoordv2/task/scheduler.go
@@ -21,8 +21,7 @@ import (
     "fmt"
     "runtime"
     "sync"
-
-    "github.com/cockroachdb/errors"
+    "time"
 
     "github.com/milvus-io/milvus/internal/log"
     "github.com/milvus-io/milvus/internal/metrics"
@@ -30,8 +29,8 @@ import (
     "github.com/milvus-io/milvus/internal/proto/querypb"
     "github.com/milvus-io/milvus/internal/querycoordv2/meta"
     "github.com/milvus-io/milvus/internal/querycoordv2/session"
-    "github.com/milvus-io/milvus/internal/querycoordv2/utils"
     "github.com/milvus-io/milvus/internal/util/funcutil"
+    "github.com/milvus-io/milvus/internal/util/merr"
     . "github.com/milvus-io/milvus/internal/util/typeutil"
     "go.uber.org/atomic"
     "go.uber.org/zap"
 )
@@ -43,24 +42,6 @@ const (
     TaskTypeMove
 )
 
-var (
-    ErrConflictTaskExisted = errors.New("ConflictTaskExisted")
-
-    // The task is canceled or timeout
-    ErrTaskCanceled = errors.New("TaskCanceled")
-
-    // The target node is offline,
-    // or the target segment is not in TargetManager,
-    // or the target channel is not in TargetManager
-    ErrTaskStale = errors.New("TaskStale")
-
-    // ErrInsufficientMemory returns insufficient memory error.
-    ErrInsufficientMemory = errors.New("InsufficientMemoryToLoad")
-
-    ErrFailedResponse  = errors.New("RpcFailed")
-    ErrTaskAlreadyDone = errors.New("TaskAlreadyDone")
-)
-
 type Type = int32
 
 type replicaSegmentIndex struct {
@@ -167,7 +148,7 @@ func NewScheduler(ctx context.Context,
     broker meta.Broker,
     cluster session.Cluster,
     nodeMgr *session.NodeManager) *taskScheduler {
-    id := int64(0)
+    id := time.Now().UnixMilli()
     return &taskScheduler{
         ctx:       ctx,
         executors: make(map[int64]*Executor),
@@ -276,20 +257,12 @@ func (scheduler *taskScheduler) preAdd(task Task) error {
                 zap.Int64("newID", task.ID()),
                 zap.Int32("newPriority", task.Priority()),
             )
-            old.SetStatus(TaskStatusCanceled)
-            old.SetErr(utils.WrapError("replaced with the other one with higher priority", ErrTaskCanceled))
+            old.Cancel(merr.WrapErrServiceInternal("replaced with the other one with higher priority"))
             scheduler.remove(old)
             return nil
         }
-        return ErrConflictTaskExisted
-    }
 
-    if GetTaskType(task) == TaskTypeGrow {
-        nodesWithSegment := scheduler.distMgr.LeaderViewManager.GetSealedSegmentDist(task.SegmentID())
-        replicaNodeMap := utils.GroupNodesByReplica(scheduler.meta.ReplicaManager, task.CollectionID(), nodesWithSegment)
-        if _, ok := replicaNodeMap[task.ReplicaID()]; ok {
-            return ErrTaskAlreadyDone
-        }
+        return merr.WrapErrServiceInternal("task with the same segment exists")
     }
 
 case *ChannelTask:
@@ -302,21 +275,12 @@ func (scheduler *taskScheduler) preAdd(task Task) error {
                 zap.Int64("newID", task.ID()),
                 zap.Int32("newPriority", task.Priority()),
             )
-            old.SetStatus(TaskStatusCanceled)
-            old.SetErr(utils.WrapError("replaced with the other one with higher priority", ErrTaskCanceled))
+            old.Cancel(merr.WrapErrServiceInternal("replaced with the other one with higher priority"))
             scheduler.remove(old)
             return nil
         }
-        return ErrConflictTaskExisted
-    }
-
-    if GetTaskType(task) == TaskTypeGrow {
-        nodesWithChannel := scheduler.distMgr.LeaderViewManager.GetChannelDist(task.Channel())
-        replicaNodeMap := utils.GroupNodesByReplica(scheduler.meta.ReplicaManager, task.CollectionID(), nodesWithChannel)
-        if _, ok := replicaNodeMap[task.ReplicaID()]; ok {
-            return ErrTaskAlreadyDone
-        }
+        return merr.WrapErrServiceInternal("task with the same channel exists")
     }
 
 default:
@@ -326,43 +290,19 @@ func (scheduler *taskScheduler) preAdd(task Task) error {
     return nil
 }
 
-func (scheduler *taskScheduler) promote(task Task) error {
-    log := log.With(
-        zap.Int64("taskID", task.ID()),
-        zap.Int64("collectionID", task.CollectionID()),
-        zap.Int64("replicaID", task.ReplicaID()),
-        zap.Int64("source", task.SourceID()),
-    )
-    err := scheduler.prePromote(task)
-    if err != nil {
-        log.Info("failed to promote task", zap.Error(err))
-        return err
-    }
-
-    scheduler.processQueue.Add(task)
-    task.SetStatus(TaskStatusStarted)
-    return nil
-}
-
 func (scheduler *taskScheduler) tryPromoteAll() {
     // Promote waiting tasks
     toPromote := make([]Task, 0, scheduler.waitQueue.Len())
     toRemove := make([]Task, 0)
     scheduler.waitQueue.Range(func(task Task) bool {
         err := scheduler.promote(task)
-
         if err != nil {
-            task.SetStatus(TaskStatusCanceled)
-            if errors.Is(err, ErrTaskStale) { // Task canceled or stale
-                task.SetStatus(TaskStatusStale)
-            }
-
+            task.Cancel(err)
+            toRemove = append(toRemove, task)
             log.Warn("failed to promote task",
                 zap.Int64("taskID", task.ID()),
                 zap.Error(err),
             )
-            task.SetErr(err)
-            toRemove = append(toRemove, task)
         } else {
             toPromote = append(toPromote, task)
         }
@@ -384,13 +324,21 @@ func (scheduler *taskScheduler) tryPromoteAll() {
     }
 }
 
-func (scheduler *taskScheduler) prePromote(task Task) error {
-    if scheduler.checkCanceled(task) {
-        return ErrTaskCanceled
-    } else if scheduler.checkStale(task) {
-        return ErrTaskStale
+func (scheduler *taskScheduler) promote(task Task) error {
+    log := log.With(
+        zap.Int64("taskID", task.ID()),
+        zap.Int64("collectionID", task.CollectionID()),
+        zap.Int64("replicaID", task.ReplicaID()),
+        zap.Int64("source", task.SourceID()),
+    )
+
+    if err := scheduler.check(task); err != nil {
+        log.Info("failed to promote task", zap.Error(err))
+        return err
     }
 
+    scheduler.processQueue.Add(task)
+    task.SetStatus(TaskStatusStarted)
     return nil
 }
 
@@ -400,6 +348,8 @@ func (scheduler *taskScheduler) Dispatch(node int64) {
         log.Info("scheduler stopped")
 
     default:
+        scheduler.rwmutex.Lock()
+        defer scheduler.rwmutex.Unlock()
         scheduler.schedule(node)
     }
 }
@@ -458,13 +408,10 @@ func (scheduler *taskScheduler) GetNodeSegmentCntDelta(nodeID int64) int {
 }
 
 // schedule selects some tasks to execute, follow these steps for each started selected tasks:
-// 1. check whether this task is stale, set status to failed if stale
+// 1. check whether this task is stale, set status to canceled if stale
 // 2. step up the task's actions, set status to succeeded if all actions finished
 // 3. execute the current action of task
 func (scheduler *taskScheduler) schedule(node int64) {
-    scheduler.rwmutex.Lock()
-    defer scheduler.rwmutex.Unlock()
-
     if scheduler.tasks.Len() == 0 {
         return
     }
@@ -486,7 +433,7 @@ func (scheduler *taskScheduler) schedule(node int64) {
     toProcess := make([]Task, 0)
     toRemove := make([]Task, 0)
     scheduler.processQueue.Range(func(task Task) bool {
-        if scheduler.isRelated(task, node) && scheduler.preProcess(task) {
+        if scheduler.preProcess(task) && scheduler.isRelated(task, node) {
             toProcess = append(toProcess, task)
         }
         if task.Status() != TaskStatusStarted {
@@ -561,13 +508,9 @@ func (scheduler *taskScheduler) isRelated(task Task, node int64) bool {
 // return true if the task should be executed,
 // false otherwise
 func (scheduler *taskScheduler) preProcess(task Task) bool {
-    log := log.With(
-        zap.Int64("taskID", task.ID()),
-        zap.Int32("type", GetTaskType(task)),
-        zap.Int64("collectionID", task.CollectionID()),
-        zap.Int64("replicaID", task.ReplicaID()),
-        zap.Int64("source", task.SourceID()),
-    )
+    if task.Status() != TaskStatusStarted {
+        return false
+    }
 
     actions, step := task.Actions(), task.Step()
     for step < len(actions) && actions[step].IsFinished(scheduler.distMgr) {
@@ -575,31 +518,12 @@ func (scheduler *taskScheduler) preProcess(task Task) bool {
         step++
     }
 
-    if step == len(actions) {
-        step--
-    }
-
-    executor, ok := scheduler.executors[actions[step].Node()]
-    if !ok {
-        log.Warn("no executor for QueryNode",
-            zap.Int("step", step),
-            zap.Int64("nodeID", actions[step].Node()))
-        return false
-    }
-
     if task.IsFinished(scheduler.distMgr) {
-        if !executor.Exist(task.ID()) {
-            task.SetStatus(TaskStatusSucceeded)
+        task.SetStatus(TaskStatusSucceeded)
+    } else {
+        if err := scheduler.check(task); err != nil {
+            task.Cancel(err)
         }
-        return false
-    } else if scheduler.checkCanceled(task) {
-        task.SetStatus(TaskStatusCanceled)
-        if task.Err() == nil {
-            task.SetErr(ErrTaskCanceled)
-        }
-    } else if scheduler.checkStale(task) {
-        task.SetStatus(TaskStatusStale)
-        task.SetErr(ErrTaskStale)
     }
 
     return task.Status() == TaskStatusStarted
@@ -633,7 +557,7 @@ func (scheduler *taskScheduler) process(task Task) bool {
     case TaskStatusSucceeded:
         log.Info("task succeeded")
 
-    case TaskStatusCanceled, TaskStatusStale:
+    case TaskStatusCanceled:
         log.Warn("failed to execute task", zap.Error(task.Err()))
 
     default:
@@ -643,6 +567,15 @@ func (scheduler *taskScheduler) process(task Task) bool {
     return false
 }
 
+func (scheduler *taskScheduler) check(task Task) error {
+    err := task.Context().Err()
+    if err == nil {
+        err = scheduler.checkStale(task)
+    }
+
+    return err
+}
+
 func (scheduler *taskScheduler) RemoveByNode(node int64) {
     scheduler.rwmutex.Lock()
     defer scheduler.rwmutex.Unlock()
@@ -670,7 +603,7 @@ func (scheduler *taskScheduler) remove(task Task) {
         zap.Int64("replicaID", task.ReplicaID()),
         zap.Int32("taskStatus", task.Status()),
     )
-    task.Cancel()
+    task.Cancel(nil)
     scheduler.tasks.Remove(task.ID())
     scheduler.waitQueue.Remove(task)
     scheduler.processQueue.Remove(task)
@@ -692,28 +625,10 @@ func (scheduler *taskScheduler) remove(task Task) {
     }
 
     metrics.QueryCoordTaskNum.WithLabelValues().Set(float64(scheduler.tasks.Len()))
-    log.Debug("task removed")
+    log.Debug("task removed", zap.Stack("stack"))
 }
 
-func (scheduler *taskScheduler) checkCanceled(task Task) bool {
-    log := log.With(
-        zap.Int64("taskID", task.ID()),
-        zap.Int64("collectionID", task.CollectionID()),
-        zap.Int64("replicaID", task.ReplicaID()),
-        zap.Int64("source", task.SourceID()),
-    )
-
-    select {
-    case <-task.Context().Done():
-        log.Warn("the task is timeout or canceled")
-        return true
-
-    default:
-        return false
-    }
-}
-
-func (scheduler *taskScheduler) checkStale(task Task) bool {
+func (scheduler *taskScheduler) checkStale(task Task) error {
     log := log.With(
         zap.Int64("taskID", task.ID()),
         zap.Int64("collectionID", task.CollectionID()),
         zap.Int64("replicaID", task.ReplicaID()),
@@ -723,13 +638,13 @@ func (scheduler *taskScheduler) checkStale(task Task) bool {
     switch task := task.(type) {
     case *SegmentTask:
-        if scheduler.checkSegmentTaskStale(task) {
-            return true
+        if err := scheduler.checkSegmentTaskStale(task); err != nil {
+            return err
         }
 
     case *ChannelTask:
-        if scheduler.checkChannelTaskStale(task) {
-            return true
+        if err := scheduler.checkChannelTaskStale(task); err != nil {
+            return err
         }
 
     default:
@@ -743,14 +658,14 @@ func (scheduler *taskScheduler) checkStale(task Task) bool {
         if scheduler.nodeMgr.Get(action.Node()) == nil {
             log.Warn("the task is stale, the target node is offline")
-            return true
+            return merr.WrapErrNodeNotFound(action.Node())
         }
     }
 
-    return false
+    return nil
 }
 
-func (scheduler *taskScheduler) checkSegmentTaskStale(task *SegmentTask) bool {
+func (scheduler *taskScheduler) checkSegmentTaskStale(task *SegmentTask) error {
     log := log.With(
         zap.Int64("taskID", task.ID()),
         zap.Int64("collectionID", task.CollectionID()),
@@ -773,18 +688,18 @@ func (scheduler *taskScheduler) checkSegmentTaskStale(task *SegmentTask) bool {
                 zap.Int64("segment", task.segmentID),
                 zap.Int32("taskType", taskType),
             )
-            return true
+            return merr.WrapErrSegmentReduplicate(task.SegmentID(), "target doesn't contain this segment")
         }
 
         replica := scheduler.meta.ReplicaManager.GetByCollectionAndNode(task.CollectionID(), action.Node())
         if replica == nil {
             log.Warn("task stale due to replica not found")
-            return true
+            return merr.WrapErrReplicaNotFound(task.CollectionID(), "by collectionID")
         }
         _, ok := scheduler.distMgr.GetShardLeader(replica, segment.GetInsertChannel())
         if !ok {
             log.Warn("task stale due to leader not found")
-            return true
+            return merr.WrapErrChannelNotFound(segment.GetInsertChannel(), "failed to get shard delegator")
         }
 
     case ActionTypeReduce:
@@ -792,10 +707,10 @@ func (scheduler *taskScheduler) checkSegmentTaskStale(task *SegmentTask) bool {
         // the task should succeeded if the segment not exists
         }
     }
-    return false
+    return nil
 }
 
-func (scheduler *taskScheduler) checkChannelTaskStale(task *ChannelTask) bool {
+func (scheduler *taskScheduler) checkChannelTaskStale(task *ChannelTask) error {
     log := log.With(
         zap.Int64("taskID", task.ID()),
         zap.Int64("collectionID", task.CollectionID()),
@@ -809,7 +724,7 @@ func (scheduler *taskScheduler) checkChannelTaskStale(task *ChannelTask) bool {
         if scheduler.targetMgr.GetDmChannel(task.collectionID, task.Channel(), meta.NextTarget) == nil {
             log.Warn("the task is stale, the channel to subscribe not exists in targets",
                 zap.String("channel", task.Channel()))
-            return true
+            return merr.WrapErrChannelReduplicate(task.Channel(), "target doesn't contain this channel")
         }
 
     case ActionTypeReduce:
@@ -817,5 +732,5 @@ func (scheduler *taskScheduler) checkChannelTaskStale(task *ChannelTask) bool {
         // the task should succeeded if the channel not exists
         }
     }
-    return false
+    return nil
 }
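With the scheduler changes above, a task's failure now comes from one check step: the task's own context (cancelation or timeout) is consulted first, and only a live task goes through the staleness probe. The following standalone sketch only illustrates that ordering; isStale is a hypothetical stand-in for checkStale, not the scheduler's actual method.

    package main

    import (
        "context"
        "errors"
        "fmt"
        "time"
    )

    // check returns the context error (canceled or deadline exceeded) first,
    // and falls back to the staleness probe only for tasks that are still live.
    func check(ctx context.Context, isStale func() error) error {
        if err := ctx.Err(); err != nil {
            return err
        }
        return isStale()
    }

    func main() {
        ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
        defer cancel()
        <-ctx.Done()

        err := check(ctx, func() error { return errors.New("target node is offline") })
        fmt.Println(err) // context deadline exceeded: the timeout wins over staleness
    }

Whatever error comes back is then passed to Cancel(err), so the reason survives into the task's final state instead of being collapsed into a generic "stale" or "canceled" status.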
diff --git a/internal/querycoordv2/task/task.go b/internal/querycoordv2/task/task.go
index 780e06545a..1f8f7a788c 100644
--- a/internal/querycoordv2/task/task.go
+++ b/internal/querycoordv2/task/task.go
@@ -37,7 +37,6 @@ const (
     TaskStatusStarted
     TaskStatusSucceeded
     TaskStatusCanceled
-    TaskStatusStale
 )
 
 const (
@@ -61,11 +60,10 @@ type Task interface {
     Status() Status
     SetStatus(status Status)
     Err() error
-    SetErr(err error)
     Priority() Priority
     SetPriority(priority Priority)
 
-    Cancel()
+    Cancel(err error)
     Wait() error
     Actions() []Action
     Step() int
@@ -159,16 +157,21 @@ func (task *baseTask) SetPriority(priority Priority) {
 }
 
 func (task *baseTask) Err() error {
-    return task.err
+    select {
+    case <-task.doneCh:
+        return task.err
+    default:
+        return nil
+    }
 }
 
-func (task *baseTask) SetErr(err error) {
-    task.err = err
-}
-
-func (task *baseTask) Cancel() {
-    if task.canceled.CAS(false, true) {
+func (task *baseTask) Cancel(err error) {
+    if task.canceled.CompareAndSwap(false, true) {
         task.cancel()
+        if task.Status() != TaskStatusSucceeded {
+            task.SetStatus(TaskStatusCanceled)
+        }
+        task.err = err
         close(task.doneCh)
     }
 }
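The task.go change above replaces the SetErr/Cancel pair with a single Cancel(err) that records the reason, flips the status, and closes the done channel exactly once, while Err stays nil until the task is done. The stripped-down type below mirrors that contract using only the standard library (sync/atomic instead of go.uber.org/atomic); it is an illustration of the semantics, not the patched type itself.

    package main

    import (
        "context"
        "errors"
        "fmt"
        "sync/atomic"
    )

    const (
        statusStarted int32 = iota + 1
        statusSucceeded
        statusCanceled
    )

    type task struct {
        cancel   context.CancelFunc
        canceled atomic.Bool
        status   atomic.Int32
        doneCh   chan struct{}
        err      error
    }

    func newTask() *task {
        _, cancel := context.WithCancel(context.Background())
        t := &task{cancel: cancel, doneCh: make(chan struct{})}
        t.status.Store(statusStarted)
        return t
    }

    // Cancel records the reason, cancels the context, and closes doneCh exactly once.
    // A task that already succeeded keeps its status.
    func (t *task) Cancel(err error) {
        if t.canceled.CompareAndSwap(false, true) {
            t.cancel()
            if t.status.Load() != statusSucceeded {
                t.status.Store(statusCanceled)
            }
            t.err = err
            close(t.doneCh)
        }
    }

    // Err is nil until the task is done, so in-flight tasks never report an error.
    func (t *task) Err() error {
        select {
        case <-t.doneCh:
            return t.err
        default:
            return nil
        }
    }

    func main() {
        t := newTask()
        fmt.Println(t.Err()) // <nil>: still running
        t.Cancel(errors.New("replaced with the other one with higher priority"))
        t.Cancel(nil) // second call is a no-op thanks to CompareAndSwap
        fmt.Println(t.status.Load() == statusCanceled, t.Err())
    }

This is why the tests below can drop the separate TaskStatusStale state and simply assert on TaskStatusCanceled plus a non-nil Err.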
diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go
index 5d5a9ec1a1..2d00eb5224 100644
--- a/internal/querycoordv2/task/task_test.go
+++ b/internal/querycoordv2/task/task_test.go
@@ -174,48 +174,6 @@ func (suite *TaskSuite) BeforeTest(suiteName, testName string) {
     }
 }
 
-func (suite *TaskSuite) TestSubmitDuplicateSubscribeChannelTask() {
-    ctx := context.Background()
-    timeout := 10 * time.Second
-    targetNode := int64(3)
-
-    tasks := []Task{}
-    dmChannels := make([]*datapb.VchannelInfo, 0)
-    for _, channel := range suite.subChannels {
-        dmChannels = append(dmChannels, &datapb.VchannelInfo{
-            CollectionID:        suite.collection,
-            ChannelName:         channel,
-            UnflushedSegmentIds: []int64{suite.growingSegments[channel]},
-        })
-        task, err := NewChannelTask(
-            ctx,
-            timeout,
-            0,
-            suite.collection,
-            suite.replica,
-            NewChannelAction(targetNode, ActionTypeGrow, channel),
-        )
-        suite.NoError(err)
-        tasks = append(tasks, task)
-    }
-
-    views := make([]*meta.LeaderView, 0)
-    for _, channel := range suite.subChannels {
-        views = append(views, &meta.LeaderView{
-            ID:           targetNode,
-            CollectionID: suite.collection,
-            Channel:      channel,
-        })
-    }
-    suite.dist.LeaderViewManager.Update(targetNode, views...)
-
-    for _, task := range tasks {
-        err := suite.scheduler.Add(task)
-        suite.Error(err)
-        suite.ErrorIs(err, ErrTaskAlreadyDone)
-    }
-}
-
 func (suite *TaskSuite) TestSubscribeChannelTask() {
     ctx := context.Background()
     timeout := 10 * time.Second
@@ -284,10 +242,6 @@ func (suite *TaskSuite) TestSubscribeChannelTask() {
     suite.dispatchAndWait(targetNode)
     suite.AssertTaskNum(len(suite.subChannels), 0, len(suite.subChannels), 0)
 
-    // Other nodes' HB can't trigger the procedure of tasks
-    suite.dispatchAndWait(targetNode + 1)
-    suite.AssertTaskNum(len(suite.subChannels), 0, len(suite.subChannels), 0)
-
     // Process tasks done
     // Dist contains channels
     views := make([]*meta.LeaderView, 0)
@@ -354,10 +308,6 @@ func (suite *TaskSuite) TestUnsubscribeChannelTask() {
     suite.dispatchAndWait(targetNode)
     suite.AssertTaskNum(1, 0, 1, 0)
 
-    // Other nodes' HB can't trigger the procedure of tasks
-    suite.dispatchAndWait(targetNode + 1)
-    suite.AssertTaskNum(1, 0, 1, 0)
-
     // Update dist
     suite.dist.LeaderViewManager.Update(targetNode)
     suite.dispatchAndWait(targetNode)
@@ -433,10 +383,6 @@ func (suite *TaskSuite) TestLoadSegmentTask() {
     suite.dispatchAndWait(targetNode)
     suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
 
-    // Other nodes' HB can't trigger the procedure of tasks
-    suite.dispatchAndWait(targetNode + 1)
-    suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
-
     // Process tasks done
     // Dist contains channels
     view := &meta.LeaderView{
@@ -457,47 +403,6 @@ func (suite *TaskSuite) TestLoadSegmentTask() {
     }
 }
 
-func (suite *TaskSuite) TestSubmitDuplicateLoadSegmentTask() {
-    ctx := context.Background()
-    timeout := 10 * time.Second
-    targetNode := int64(3)
-    channel := &datapb.VchannelInfo{
-        CollectionID: suite.collection,
-        ChannelName:  Params.CommonCfg.RootCoordDml.GetValue() + "-test",
-    }
-
-    tasks := []Task{}
-    for _, segment := range suite.loadSegments {
-        task, err := NewSegmentTask(
-            ctx,
-            timeout,
-            0,
-            suite.collection,
-            suite.replica,
-            NewSegmentAction(targetNode, ActionTypeGrow, channel.GetChannelName(), segment),
-        )
-        suite.NoError(err)
-        tasks = append(tasks, task)
-    }
-
-    // Process tasks done
-    // Dist contains channels
-    view := &meta.LeaderView{
-        ID:           targetNode,
-        CollectionID: suite.collection,
-        Segments:     map[int64]*querypb.SegmentDist{},
-    }
-    for _, segment := range suite.loadSegments {
-        view.Segments[segment] = &querypb.SegmentDist{NodeID: targetNode, Version: 0}
-    }
-    suite.dist.LeaderViewManager.Update(targetNode, view)
-
-    for _, task := range tasks {
-        err := suite.scheduler.Add(task)
-        suite.Error(err)
-        suite.ErrorIs(err, ErrTaskAlreadyDone)
-    }
-}
-
 func (suite *TaskSuite) TestLoadSegmentTaskFailed() {
     ctx := context.Background()
     timeout := 10 * time.Second
@@ -559,10 +464,6 @@ func (suite *TaskSuite) TestLoadSegmentTaskFailed() {
     suite.dispatchAndWait(targetNode)
     suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
 
-    // Other nodes' HB can't trigger the procedure of tasks
-    suite.dispatchAndWait(targetNode + 1)
-    suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
-
     // Process tasks done
     // Dist contains channels
     time.Sleep(timeout)
@@ -629,10 +530,6 @@ func (suite *TaskSuite) TestReleaseSegmentTask() {
     suite.dispatchAndWait(targetNode)
     suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
 
-    // Other nodes' HB can't trigger the procedure of tasks
-    suite.dispatchAndWait(targetNode + 1)
-    suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
-
     // Process tasks done
     suite.dist.LeaderViewManager.Update(targetNode)
     suite.dispatchAndWait(targetNode)
@@ -684,10 +581,6 @@ func (suite *TaskSuite) TestReleaseGrowingSegmentTask() {
     suite.dispatchAndWait(targetNode)
     suite.AssertTaskNum(segmentsNum-1, 0, 0, segmentsNum-1)
 
-    // Other nodes' HB can't trigger the procedure of tasks
-    suite.dispatchAndWait(targetNode + 1)
-    suite.AssertTaskNum(segmentsNum-1, 0, 0, segmentsNum-1)
-
     // Release done
     suite.dist.LeaderViewManager.Update(targetNode)
 
@@ -784,10 +677,6 @@ func (suite *TaskSuite) TestMoveSegmentTask() {
     suite.dispatchAndWait(leader)
     suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
 
-    // Other nodes' HB can't trigger the procedure of tasks
-    suite.dispatchAndWait(-1)
-    suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
-
     // Process tasks, target node contains the segment
     view = view.Clone()
     for _, segment := range suite.moveSegments {
@@ -870,13 +759,9 @@ func (suite *TaskSuite) TestTaskCanceled() {
     suite.dispatchAndWait(targetNode)
     suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
 
-    // Other nodes' HB can't trigger the procedure of tasks
-    suite.dispatchAndWait(targetNode + 1)
-    suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
-
     // Cancel all tasks
     for _, task := range tasks {
-        task.Cancel()
+        task.Cancel(errors.New("mock error"))
     }
 
     suite.dispatchAndWait(targetNode)
@@ -953,10 +838,6 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
     suite.dispatchAndWait(targetNode)
     suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
 
-    // Other nodes' HB can't trigger the procedure of tasks
-    suite.dispatchAndWait(targetNode + 1)
-    suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
-
     // Process tasks done
     // Dist contains channels, first task stale
     view := &meta.LeaderView{
@@ -982,8 +863,8 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
     for i, task := range tasks {
         if i == 0 {
-            suite.Equal(TaskStatusStale, task.Status())
-            suite.ErrorIs(ErrTaskStale, task.Err())
+            suite.Equal(TaskStatusCanceled, task.Status())
+            suite.Error(task.Err())
         } else {
             suite.Equal(TaskStatusSucceeded, task.Status())
             suite.NoError(task.Err())
@@ -1025,10 +906,10 @@ func (suite *TaskSuite) TestChannelTaskReplace() {
         suite.NoError(err)
         task.SetPriority(TaskPriorityNormal)
         err = suite.scheduler.Add(task)
-        suite.ErrorIs(err, ErrConflictTaskExisted)
+        suite.Error(err)
         task.SetPriority(TaskPriorityLow)
         err = suite.scheduler.Add(task)
-        suite.ErrorIs(err, ErrConflictTaskExisted)
+        suite.Error(err)
     }
 
     // Replace the task with one with higher priority
@@ -1117,10 +998,10 @@ func (suite *TaskSuite) TestSegmentTaskReplace() {
         suite.NoError(err)
         task.SetPriority(TaskPriorityNormal)
         err = suite.scheduler.Add(task)
-        suite.ErrorIs(err, ErrConflictTaskExisted)
+        suite.Error(err)
         task.SetPriority(TaskPriorityLow)
         err = suite.scheduler.Add(task)
-        suite.ErrorIs(err, ErrConflictTaskExisted)
+        suite.Error(err)
     }
 
     // Replace the task with one with higher priority
@@ -1188,10 +1069,6 @@ func (suite *TaskSuite) TestNoExecutor() {
     suite.dispatchAndWait(targetNode)
     suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
 
-    // Other nodes' HB can't trigger the procedure of tasks
-    suite.dispatchAndWait(targetNode + 1)
-    suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
-
     // Process tasks done
     // Dist contains channels
     view := &meta.LeaderView{
@@ -1204,12 +1081,7 @@ func (suite *TaskSuite) TestNoExecutor() {
     }
     suite.dist.LeaderViewManager.Update(targetNode, view)
     suite.dispatchAndWait(targetNode)
-    suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
-
-    for _, task := range tasks {
-        suite.Equal(TaskStatusStarted, task.Status())
-        suite.NoError(task.Err())
-    }
+    suite.AssertTaskNum(0, 0, 0, 0)
 }
 
 func (suite *TaskSuite) AssertTaskNum(process, wait, channel, segment int) {
diff --git a/internal/querycoordv2/task/utils.go b/internal/querycoordv2/task/utils.go
index bcf0abf87e..237706c19f 100644
--- a/internal/querycoordv2/task/utils.go
+++ b/internal/querycoordv2/task/utils.go
@@ -167,7 +167,7 @@ func packLoadMeta(loadType querypb.LoadType, collectionID int64, partitions ...i
     }
 }
 
-func packSubDmChannelRequest(
+func packSubChannelRequest(
     task *ChannelTask,
     action Action,
     schema *schemapb.CollectionSchema,
@@ -189,7 +189,7 @@ func packSubDmChannelRequest(
     }
 }
 
-func fillSubDmChannelRequest(
+func fillSubChannelRequest(
     ctx context.Context,
     req *querypb.WatchDmChannelsRequest,
     broker meta.Broker,
diff --git a/internal/util/merr/errors.go b/internal/util/merr/errors.go
index 54850e3993..c373c60279 100644
--- a/internal/util/merr/errors.go
+++ b/internal/util/merr/errors.go
@@ -53,6 +53,7 @@ var (
     ErrServiceUnavailable          = newMilvusError("service unavailable", 2, true)
     ErrServiceMemoryLimitExceeded  = newMilvusError("memory limit exceeded", 3, false)
     ErrServiceRequestLimitExceeded = newMilvusError("request limit exceeded", 4, true)
+    ErrServiceInternal             = newMilvusError("service internal error", 5, false) // Never return this error out of Milvus
 
     // Collection related
     ErrCollectionNotFound = newMilvusError("collection not found", 100, false)
@@ -69,12 +70,15 @@ var (
     ErrReplicaNotFound = newMilvusError("replica not found", 400, false)
 
     // Channel related
-    ErrChannelNotFound = newMilvusError("channel not found", 500, false)
+    ErrChannelNotFound    = newMilvusError("channel not found", 500, false)
+    ErrChannelLack        = newMilvusError("channel lacks", 501, false)
+    ErrChannelReduplicate = newMilvusError("channel reduplicates", 502, false)
 
     // Segment related
-    ErrSegmentNotFound  = newMilvusError("segment not found", 600, false)
-    ErrSegmentNotLoaded = newMilvusError("segment not loaded", 601, false)
-    ErrSegmentLack      = newMilvusError("segment lacks", 602, false)
+    ErrSegmentNotFound    = newMilvusError("segment not found", 600, false)
+    ErrSegmentNotLoaded   = newMilvusError("segment not loaded", 601, false)
+    ErrSegmentLack        = newMilvusError("segment lacks", 602, false)
+    ErrSegmentReduplicate = newMilvusError("segment reduplicates", 603, false)
 
     // Index related
     ErrIndexNotFound = newMilvusError("index not found", 700, false)
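The new sentinels above (ErrServiceInternal, ErrChannelLack, ErrChannelReduplicate, ErrSegmentReduplicate) are meant to be wrapped with extra context and still be matched with errors.Is, which is what lets the scheduler return rich messages while callers keep classifying the failure. The standard-library sketch below only demonstrates that property; the sentinel value and message here are placeholders, not the merr definitions.

    package main

    import (
        "errors"
        "fmt"
    )

    var errChannelReduplicate = errors.New("channel reduplicates")

    // wrapChannel mimics the style of the WrapErr* helpers:
    // add context but keep the sentinel in the error chain.
    func wrapChannel(name, msg string) error {
        return fmt.Errorf("%s; channel=%s: %w", msg, name, errChannelReduplicate)
    }

    func main() {
        err := wrapChannel("dml-ch-0", "target doesn't contain this channel")
        fmt.Println(err)
        fmt.Println(errors.Is(err, errChannelReduplicate)) // true: the sentinel still matches
    }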
diff --git a/internal/util/merr/errors_test.go b/internal/util/merr/errors_test.go
index 4b6d04974d..e27f017eb5 100644
--- a/internal/util/merr/errors_test.go
+++ b/internal/util/merr/errors_test.go
@@ -63,6 +63,7 @@ func (s *ErrSuite) TestWrap() {
     s.ErrorIs(WrapErrServiceUnavailable("test", "test init"), ErrServiceUnavailable)
     s.ErrorIs(WrapErrServiceMemoryLimitExceeded(110, 100, "MLE"), ErrServiceMemoryLimitExceeded)
     s.ErrorIs(WrapErrServiceRequestLimitExceeded(100, "too many requests"), ErrServiceRequestLimitExceeded)
+    s.ErrorIs(WrapErrServiceInternal("never throw out"), ErrServiceInternal)
 
     // Collection related
     s.ErrorIs(WrapErrCollectionNotFound("test_collection", "failed to get collection"), ErrCollectionNotFound)
@@ -80,11 +81,14 @@ func (s *ErrSuite) TestWrap() {
 
     // Channel related
     s.ErrorIs(WrapErrChannelNotFound("test_Channel", "failed to get Channel"), ErrChannelNotFound)
+    s.ErrorIs(WrapErrChannelLack("test_Channel", "failed to get Channel"), ErrChannelLack)
+    s.ErrorIs(WrapErrChannelReduplicate("test_Channel", "failed to get Channel"), ErrChannelReduplicate)
 
     // Segment related
     s.ErrorIs(WrapErrSegmentNotFound(1, "failed to get Segment"), ErrSegmentNotFound)
     s.ErrorIs(WrapErrSegmentNotLoaded(1, "failed to query"), ErrSegmentNotLoaded)
     s.ErrorIs(WrapErrSegmentLack(1, "lack of segment"), ErrSegmentLack)
+    s.ErrorIs(WrapErrSegmentReduplicate(1, "redundancy of segment"), ErrSegmentReduplicate)
 
     // Index related
     s.ErrorIs(WrapErrIndexNotFound("failed to get Index"), ErrIndexNotFound)
diff --git a/internal/util/merr/utils.go b/internal/util/merr/utils.go
index 6a4b0d70b6..19fe231b78 100644
--- a/internal/util/merr/utils.go
+++ b/internal/util/merr/utils.go
@@ -68,18 +68,34 @@ func Status(err error) *commonpb.Status {
         return &commonpb.Status{}
     }
 
+    code := Code(err)
     return &commonpb.Status{
-        Code:   Code(err),
+        Code:   code,
         Reason: err.Error(),
-        // Deprecated, for compatibility, set it to UnexpectedError
-        ErrorCode: commonpb.ErrorCode_UnexpectedError,
+        // Deprecated, for compatibility
+        ErrorCode: oldCode(code),
     }
 }
 
+func oldCode(code int32) commonpb.ErrorCode {
+    switch code {
+    case ErrServiceNotReady.code():
+        return commonpb.ErrorCode_NotReadyServe
+    case ErrCollectionNotFound.code():
+        return commonpb.ErrorCode_CollectionNotExists
+    default:
+        return commonpb.ErrorCode_UnexpectedError
+    }
+}
+
+func Ok(status *commonpb.Status) bool {
+    return status.ErrorCode == commonpb.ErrorCode_Success && status.Code == 0
+}
+
 // Error returns a error according to the given status,
 // returns nil if the status is a success status
 func Error(status *commonpb.Status) error {
-    if status.GetCode() == 0 && status.GetErrorCode() == commonpb.ErrorCode_Success {
+    if Ok(status) {
         return nil
     }
 
@@ -125,6 +141,13 @@ func WrapErrServiceRequestLimitExceeded(limit int32, msg ...string) error {
     return err
 }
 
+func WrapErrServiceInternal(msg string, others ...string) error {
+    msg = strings.Join(append([]string{msg}, others...), "; ")
+    err := errors.Wrap(ErrServiceInternal, msg)
+
+    return err
+}
+
 // Collection related
 func WrapErrCollectionNotFound(collection any, msg ...string) error {
     err := wrapWithField(ErrCollectionNotFound, "collection", collection)
@@ -186,6 +209,22 @@ func WrapErrChannelNotFound(name string, msg ...string) error {
     return err
 }
 
+func WrapErrChannelLack(name string, msg ...string) error {
+    err := wrapWithField(ErrChannelLack, "channel", name)
+    if len(msg) > 0 {
+        err = errors.Wrap(err, strings.Join(msg, "; "))
+    }
+    return err
+}
+
+func WrapErrChannelReduplicate(name string, msg ...string) error {
+    err := wrapWithField(ErrChannelReduplicate, "channel", name)
+    if len(msg) > 0 {
+        err = errors.Wrap(err, strings.Join(msg, "; "))
+    }
+    return err
+}
+
 // Segment related
 func WrapErrSegmentNotFound(id int64, msg ...string) error {
     err := wrapWithField(ErrSegmentNotFound, "segment", id)
@@ -211,6 +250,14 @@ func WrapErrSegmentLack(id int64, msg ...string) error {
     return err
 }
 
+func WrapErrSegmentReduplicate(id int64, msg ...string) error {
+    err := wrapWithField(ErrSegmentReduplicate, "segment", id)
+    if len(msg) > 0 {
+        err = errors.Wrap(err, strings.Join(msg, "; "))
+    }
+    return err
+}
+
 // Index related
 func WrapErrIndexNotFound(msg ...string) error {
     err := error(ErrIndexNotFound)
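Taken together, Status, oldCode, Ok and Error now form a round trip: an error is serialized with both the new fine-grained code and a best-effort legacy ErrorCode, and a success status converts back to a nil error. The sketch below re-implements that round trip on local stand-in types to show the flow; it is illustrative only and does not use the real commonpb or merr definitions.

    package main

    import (
        "errors"
        "fmt"
    )

    // Stand-ins for the legacy commonpb error codes used by the patch.
    const (
        legacySuccess         int32 = 0
        legacyUnexpectedError int32 = 1
        legacyNotReadyServe   int32 = 9
    )

    type statusMsg struct {
        ErrorCode int32 // legacy field, kept for compatibility
        Code      int32 // new fine-grained code
        Reason    string
    }

    // oldCode maps a new-style code to the closest legacy ErrorCode (default: unexpected).
    func oldCode(code int32) int32 {
        switch code {
        case 1: // stand-in for ErrServiceNotReady.code()
            return legacyNotReadyServe
        default:
            return legacyUnexpectedError
        }
    }

    // newStatus mirrors merr.Status: fill both the new code and its legacy mapping.
    func newStatus(code int32, err error) *statusMsg {
        if err == nil {
            return &statusMsg{}
        }
        return &statusMsg{Code: code, ErrorCode: oldCode(code), Reason: err.Error()}
    }

    // statusOK mirrors merr.Ok; statusErr mirrors merr.Error.
    func statusOK(s *statusMsg) bool { return s.ErrorCode == legacySuccess && s.Code == 0 }

    func statusErr(s *statusMsg) error {
        if statusOK(s) {
            return nil
        }
        return errors.New(s.Reason)
    }

    func main() {
        s := newStatus(1, errors.New("service not ready"))
        fmt.Println(s.ErrorCode, s.Code, statusErr(s)) // 9 1 service not ready
        fmt.Println(statusErr(newStatus(0, nil)))      // <nil>
    }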