Refine scheduler & executor of QueryCoord (#22761)

Signed-off-by: yah01 <yang.cen@zilliz.com>
pull/22806/head
yah01 2023-03-16 17:43:55 +08:00 committed by GitHub
parent 6e47312138
commit 3d8f0156c7
12 changed files with 186 additions and 361 deletions

View File

@ -124,7 +124,7 @@ func (controller *CheckerController) check(ctx context.Context) {
for _, task := range tasks {
err := controller.scheduler.Add(task)
if err != nil {
task.Cancel()
task.Cancel(err)
continue
}
}

View File

@ -161,7 +161,7 @@ func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRe
}
err = s.taskScheduler.Add(task)
if err != nil {
task.Cancel()
task.Cancel(err)
return err
}
tasks = append(tasks, task)
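This change threads the scheduler's rejection error into the task itself: every call site that previously called task.Cancel() after a failed Add now calls task.Cancel(err). A minimal, self-contained sketch of that call-site pattern; the task and scheduler interfaces below are illustrative stand-ins, not the real types from querycoordv2/task:

package example

// Minimal stand-ins for the real Task and scheduler types, just to show the
// call-site pattern: a rejected Add carries its reason into Cancel(err), so
// anything waiting on the task can see why it never ran.
type task interface {
	Cancel(err error)
}

type scheduler interface {
	Add(t task) error
}

func submit(s scheduler, t task) error {
	if err := s.Add(t); err != nil {
		t.Cancel(err) // Wait()/Err() on the task now report the rejection reason
		return err
	}
	return nil
}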

View File

@ -22,6 +22,7 @@ import (
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/kv"
@ -1186,7 +1187,7 @@ func (suite *ServiceSuite) TestLoadBalance() {
growAction, reduceAction := actions[0], actions[1]
suite.Equal(dstNode, growAction.Node())
suite.Equal(srcNode, reduceAction.Node())
task.Cancel()
task.Cancel(nil)
}).Return(nil)
resp, err := server.LoadBalance(ctx, req)
suite.NoError(err)
@ -1262,7 +1263,7 @@ func (suite *ServiceSuite) TestLoadBalanceWithEmptySegmentList() {
suite.True(lo.Contains(segmentOnCollection[collection], reduceAction.SegmentID()))
suite.Equal(dstNode, growAction.Node())
suite.Equal(srcNode, reduceAction.Node())
t.Cancel()
t.Cancel(nil)
}).Return(nil)
resp, err := server.LoadBalance(ctx, req)
suite.NoError(err)
@ -1352,14 +1353,13 @@ func (suite *ServiceSuite) TestLoadBalanceFailed() {
SealedSegmentIDs: segments,
}
suite.taskScheduler.EXPECT().Add(mock.Anything).Run(func(balanceTask task.Task) {
balanceTask.SetErr(task.ErrTaskCanceled)
balanceTask.Cancel()
balanceTask.Cancel(errors.New("mock error"))
}).Return(nil)
resp, err := server.LoadBalance(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.ErrorCode)
suite.Contains(resp.Reason, "failed to balance segments")
suite.Contains(resp.Reason, task.ErrTaskCanceled.Error())
suite.Contains(resp.Reason, "mock error")
suite.meta.ReplicaManager.AddNode(replicas[0].ID, 10)
req.SourceNodeIDs = []int64{10}

View File

@ -21,8 +21,6 @@ import (
"sync"
"time"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb"
@ -147,13 +145,14 @@ func (ex *Executor) processMergeTask(mergeTask *LoadSegmentsTask) {
task := mergeTask.tasks[0]
action := task.Actions()[mergeTask.steps[0]]
var err error
defer func() {
canceled := task.canceled.Load()
for i := range mergeTask.tasks {
mergeTask.tasks[i].SetErr(task.Err())
if canceled {
mergeTask.tasks[i].Cancel()
if err != nil {
for i := range mergeTask.tasks {
mergeTask.tasks[i].Cancel(err)
}
}
for i := range mergeTask.tasks {
ex.removeTask(mergeTask.tasks[i], mergeTask.steps[i])
}
}()
@ -178,31 +177,25 @@ func (ex *Executor) processMergeTask(mergeTask *LoadSegmentsTask) {
channel := mergeTask.req.GetInfos()[0].GetInsertChannel()
leader, ok := getShardLeader(ex.meta.ReplicaManager, ex.dist, task.CollectionID(), action.Node(), channel)
if !ok {
msg := "no shard leader for the segment to execute loading"
task.SetErr(utils.WrapError(msg, ErrTaskStale))
log.Warn(msg, zap.String("shard", channel))
err = merr.WrapErrChannelNotFound(channel, "shard delegator not found")
log.Warn("no shard leader for the segment to execute loading", zap.Error(task.Err()))
return
}
log.Info("load segments...")
status, err := ex.cluster.LoadSegments(task.Context(), leader, mergeTask.req)
if err != nil {
log.Warn("failed to load segment, it may be a false failure", zap.Error(err))
log.Warn("failed to load segment", zap.Error(err))
return
}
err = merr.Error(status)
if errors.Is(err, merr.ErrServiceMemoryLimitExceeded) {
log.Warn("insufficient memory to load segment", zap.Error(err))
task.SetErr(err)
task.Cancel()
return
}
if status.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("failed to load segment", zap.String("reason", status.GetReason()))
if !merr.Ok(status) {
err = merr.Error(status)
log.Warn("failed to load segment", zap.Error(err))
return
}
elapsed := time.Since(startTs)
log.Info("load segments done", zap.Int64("taskID", task.ID()), zap.Duration("timeTaken", elapsed))
log.Info("load segments done", zap.Duration("elapsed", elapsed))
}
func (ex *Executor) removeTask(task Task, step int) {
@ -243,8 +236,7 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error {
var err error
defer func() {
if err != nil {
task.SetErr(err)
task.Cancel()
task.Cancel(err)
ex.removeTask(task, step)
}
}()
@ -253,7 +245,6 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error {
schema, err := ex.broker.GetCollectionSchema(ctx, task.CollectionID())
if err != nil {
log.Warn("failed to get schema of collection", zap.Error(err))
task.SetErr(err)
return err
}
partitions, err := utils.GetPartitions(ex.meta.CollectionManager, ex.broker, task.CollectionID())
@ -283,8 +274,8 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error {
leader, ok := getShardLeader(ex.meta.ReplicaManager, ex.dist, task.CollectionID(), action.Node(), segment.GetInsertChannel())
if !ok {
msg := "no shard leader for the segment to execute loading"
err = utils.WrapError(msg, ErrTaskStale)
log.Warn(msg, zap.String("shard", segment.GetInsertChannel()))
err = merr.WrapErrChannelNotFound(segment.GetInsertChannel(), "shard delegator not found")
log.Warn(msg, zap.Error(err))
return err
}
log = log.With(zap.Int64("shardLeader", leader))
@ -386,8 +377,7 @@ func (ex *Executor) subDmChannel(task *ChannelTask, step int) error {
var err error
defer func() {
if err != nil {
task.SetErr(err)
task.Cancel()
task.Cancel(err)
}
}()
@ -413,12 +403,12 @@ func (ex *Executor) subDmChannel(task *ChannelTask, step int) error {
if dmChannel == nil {
msg := "channel does not exist in next target, skip it"
log.Warn(msg, zap.String("channelName", action.ChannelName()))
return errors.New(msg)
return merr.WrapErrChannelReduplicate(action.ChannelName())
}
req := packSubDmChannelRequest(task, action, schema, loadMeta, dmChannel)
err = fillSubDmChannelRequest(ctx, req, ex.broker)
req := packSubChannelRequest(task, action, schema, loadMeta, dmChannel)
err = fillSubChannelRequest(ctx, req, ex.broker)
if err != nil {
log.Warn("failed to subscribe DmChannel, failed to fill the request with segments",
log.Warn("failed to subscribe channel, failed to fill the request with segments",
zap.Error(err))
return err
}
@ -430,16 +420,16 @@ func (ex *Executor) subDmChannel(task *ChannelTask, step int) error {
)
status, err := ex.cluster.WatchDmChannels(ctx, action.Node(), req)
if err != nil {
log.Warn("failed to subscribe DmChannel, it may be a false failure", zap.Error(err))
log.Warn("failed to subscribe channel, it may be a false failure", zap.Error(err))
return err
}
if status.ErrorCode != commonpb.ErrorCode_Success {
err = utils.WrapError("failed to subscribe DmChannel", ErrFailedResponse)
log.Warn("failed to subscribe DmChannel", zap.String("reason", status.GetReason()))
if !merr.Ok(status) {
err = merr.Error(status)
log.Warn("failed to subscribe channel", zap.Error(err))
return err
}
elapsed := time.Since(startTs)
log.Info("subscribe DmChannel done", zap.Int64("taskID", task.ID()), zap.Duration("time taken", elapsed))
log.Info("subscribe channel done", zap.Int64("taskID", task.ID()), zap.Duration("time taken", elapsed))
return nil
}
@ -459,8 +449,7 @@ func (ex *Executor) unsubDmChannel(task *ChannelTask, step int) error {
var err error
defer func() {
if err != nil {
task.SetErr(err)
task.Cancel()
task.Cancel(err)
}
}()
@ -470,16 +459,16 @@ func (ex *Executor) unsubDmChannel(task *ChannelTask, step int) error {
log.Info("unsubscribe channel...")
status, err := ex.cluster.UnsubDmChannel(ctx, action.Node(), req)
if err != nil {
log.Warn("failed to unsubscribe DmChannel, it may be a false failure", zap.Error(err))
log.Warn("failed to unsubscribe channel, it may be a false failure", zap.Error(err))
return err
}
if status.ErrorCode != commonpb.ErrorCode_Success {
err = utils.WrapError("failed to unsubscribe DmChannel", ErrFailedResponse)
log.Warn("failed to unsubscribe DmChannel", zap.String("reason", status.GetReason()))
if !merr.Ok(status) {
err = merr.Error(status)
log.Warn("failed to unsubscribe channel", zap.Error(err))
return err
}
elapsed := time.Since(startTs)
log.Info("unsubscribe DmChannel done", zap.Int64("taskID", task.ID()), zap.Duration("time taken", elapsed))
log.Info("unsubscribe channel done", zap.Int64("taskID", task.ID()), zap.Duration("time taken", elapsed))
return nil
}
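The executor hunks above converge on a single result-handling pattern for cluster RPCs: the transport error is checked first, then the returned commonpb.Status is converted back into a Go error with the merr helpers instead of being compared against commonpb.ErrorCode_Success by hand. A sketch of that pattern, assuming the merr.Ok/merr.Error helpers shown in this change; the function name statusToErr is illustrative:

package example

import (
	"github.com/milvus-io/milvus-proto/go-api/commonpb"

	"github.com/milvus-io/milvus/internal/util/merr"
)

// statusToErr sketches the handling the executor now applies to LoadSegments,
// WatchDmChannels and UnsubDmChannel responses: transport failures first,
// then non-success statuses converted into errors via merr.Error.
func statusToErr(status *commonpb.Status, rpcErr error) error {
	if rpcErr != nil { // transport-level failure, may be a false failure
		return rpcErr
	}
	if !merr.Ok(status) { // server replied with a non-success status
		return merr.Error(status)
	}
	return nil
}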

View File

@ -91,15 +91,6 @@ func (merger *Merger[K, R]) schedule(ctx context.Context) {
}()
}
func (merger *Merger[K, R]) isStopped() bool {
select {
case <-merger.stopCh:
return true
default:
return false
}
}
func (merger *Merger[K, R]) Add(task MergeableTask[K, R]) {
merger.waitQueue <- task
}

View File

@ -21,8 +21,7 @@ import (
"fmt"
"runtime"
"sync"
"github.com/cockroachdb/errors"
"time"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
@ -30,8 +29,8 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/merr"
. "github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/atomic"
"go.uber.org/zap"
@ -43,24 +42,6 @@ const (
TaskTypeMove
)
var (
ErrConflictTaskExisted = errors.New("ConflictTaskExisted")
// The task is canceled or timeout
ErrTaskCanceled = errors.New("TaskCanceled")
// The target node is offline,
// or the target segment is not in TargetManager,
// or the target channel is not in TargetManager
ErrTaskStale = errors.New("TaskStale")
// ErrInsufficientMemory returns insufficient memory error.
ErrInsufficientMemory = errors.New("InsufficientMemoryToLoad")
ErrFailedResponse = errors.New("RpcFailed")
ErrTaskAlreadyDone = errors.New("TaskAlreadyDone")
)
type Type = int32
type replicaSegmentIndex struct {
@ -167,7 +148,7 @@ func NewScheduler(ctx context.Context,
broker meta.Broker,
cluster session.Cluster,
nodeMgr *session.NodeManager) *taskScheduler {
id := int64(0)
id := time.Now().UnixMilli()
return &taskScheduler{
ctx: ctx,
executors: make(map[int64]*Executor),
@ -276,20 +257,12 @@ func (scheduler *taskScheduler) preAdd(task Task) error {
zap.Int64("newID", task.ID()),
zap.Int32("newPriority", task.Priority()),
)
old.SetStatus(TaskStatusCanceled)
old.SetErr(utils.WrapError("replaced with the other one with higher priority", ErrTaskCanceled))
old.Cancel(merr.WrapErrServiceInternal("replaced with the other one with higher priority"))
scheduler.remove(old)
return nil
}
return ErrConflictTaskExisted
}
if GetTaskType(task) == TaskTypeGrow {
nodesWithSegment := scheduler.distMgr.LeaderViewManager.GetSealedSegmentDist(task.SegmentID())
replicaNodeMap := utils.GroupNodesByReplica(scheduler.meta.ReplicaManager, task.CollectionID(), nodesWithSegment)
if _, ok := replicaNodeMap[task.ReplicaID()]; ok {
return ErrTaskAlreadyDone
}
return merr.WrapErrServiceInternal("task with the same segment exists")
}
case *ChannelTask:
@ -302,21 +275,12 @@ func (scheduler *taskScheduler) preAdd(task Task) error {
zap.Int64("newID", task.ID()),
zap.Int32("newPriority", task.Priority()),
)
old.SetStatus(TaskStatusCanceled)
old.SetErr(utils.WrapError("replaced with the other one with higher priority", ErrTaskCanceled))
old.Cancel(merr.WrapErrServiceInternal("replaced with the other one with higher priority"))
scheduler.remove(old)
return nil
}
return ErrConflictTaskExisted
}
if GetTaskType(task) == TaskTypeGrow {
nodesWithChannel := scheduler.distMgr.LeaderViewManager.GetChannelDist(task.Channel())
replicaNodeMap := utils.GroupNodesByReplica(scheduler.meta.ReplicaManager, task.CollectionID(), nodesWithChannel)
if _, ok := replicaNodeMap[task.ReplicaID()]; ok {
return ErrTaskAlreadyDone
}
return merr.WrapErrServiceInternal("task with the same channel exists")
}
default:
@ -326,43 +290,19 @@ func (scheduler *taskScheduler) preAdd(task Task) error {
return nil
}
func (scheduler *taskScheduler) promote(task Task) error {
log := log.With(
zap.Int64("taskID", task.ID()),
zap.Int64("collectionID", task.CollectionID()),
zap.Int64("replicaID", task.ReplicaID()),
zap.Int64("source", task.SourceID()),
)
err := scheduler.prePromote(task)
if err != nil {
log.Info("failed to promote task", zap.Error(err))
return err
}
scheduler.processQueue.Add(task)
task.SetStatus(TaskStatusStarted)
return nil
}
func (scheduler *taskScheduler) tryPromoteAll() {
// Promote waiting tasks
toPromote := make([]Task, 0, scheduler.waitQueue.Len())
toRemove := make([]Task, 0)
scheduler.waitQueue.Range(func(task Task) bool {
err := scheduler.promote(task)
if err != nil {
task.SetStatus(TaskStatusCanceled)
if errors.Is(err, ErrTaskStale) { // Task canceled or stale
task.SetStatus(TaskStatusStale)
}
task.Cancel(err)
toRemove = append(toRemove, task)
log.Warn("failed to promote task",
zap.Int64("taskID", task.ID()),
zap.Error(err),
)
task.SetErr(err)
toRemove = append(toRemove, task)
} else {
toPromote = append(toPromote, task)
}
@ -384,13 +324,21 @@ func (scheduler *taskScheduler) tryPromoteAll() {
}
}
func (scheduler *taskScheduler) prePromote(task Task) error {
if scheduler.checkCanceled(task) {
return ErrTaskCanceled
} else if scheduler.checkStale(task) {
return ErrTaskStale
func (scheduler *taskScheduler) promote(task Task) error {
log := log.With(
zap.Int64("taskID", task.ID()),
zap.Int64("collectionID", task.CollectionID()),
zap.Int64("replicaID", task.ReplicaID()),
zap.Int64("source", task.SourceID()),
)
if err := scheduler.check(task); err != nil {
log.Info("failed to promote task", zap.Error(err))
return err
}
scheduler.processQueue.Add(task)
task.SetStatus(TaskStatusStarted)
return nil
}
@ -400,6 +348,8 @@ func (scheduler *taskScheduler) Dispatch(node int64) {
log.Info("scheduler stopped")
default:
scheduler.rwmutex.Lock()
defer scheduler.rwmutex.Unlock()
scheduler.schedule(node)
}
}
@ -458,13 +408,10 @@ func (scheduler *taskScheduler) GetNodeSegmentCntDelta(nodeID int64) int {
}
// schedule selects some tasks to execute, following these steps for each selected started task:
// 1. check whether this task is stale, set status to failed if stale
// 1. check whether this task is stale, set status to canceled if stale
// 2. step up the task's actions, set status to succeeded if all actions finished
// 3. execute the current action of task
func (scheduler *taskScheduler) schedule(node int64) {
scheduler.rwmutex.Lock()
defer scheduler.rwmutex.Unlock()
if scheduler.tasks.Len() == 0 {
return
}
@ -486,7 +433,7 @@ func (scheduler *taskScheduler) schedule(node int64) {
toProcess := make([]Task, 0)
toRemove := make([]Task, 0)
scheduler.processQueue.Range(func(task Task) bool {
if scheduler.isRelated(task, node) && scheduler.preProcess(task) {
if scheduler.preProcess(task) && scheduler.isRelated(task, node) {
toProcess = append(toProcess, task)
}
if task.Status() != TaskStatusStarted {
@ -561,13 +508,9 @@ func (scheduler *taskScheduler) isRelated(task Task, node int64) bool {
// return true if the task should be executed,
// false otherwise
func (scheduler *taskScheduler) preProcess(task Task) bool {
log := log.With(
zap.Int64("taskID", task.ID()),
zap.Int32("type", GetTaskType(task)),
zap.Int64("collectionID", task.CollectionID()),
zap.Int64("replicaID", task.ReplicaID()),
zap.Int64("source", task.SourceID()),
)
if task.Status() != TaskStatusStarted {
return false
}
actions, step := task.Actions(), task.Step()
for step < len(actions) && actions[step].IsFinished(scheduler.distMgr) {
@ -575,31 +518,12 @@ func (scheduler *taskScheduler) preProcess(task Task) bool {
step++
}
if step == len(actions) {
step--
}
executor, ok := scheduler.executors[actions[step].Node()]
if !ok {
log.Warn("no executor for QueryNode",
zap.Int("step", step),
zap.Int64("nodeID", actions[step].Node()))
return false
}
if task.IsFinished(scheduler.distMgr) {
if !executor.Exist(task.ID()) {
task.SetStatus(TaskStatusSucceeded)
task.SetStatus(TaskStatusSucceeded)
} else {
if err := scheduler.check(task); err != nil {
task.Cancel(err)
}
return false
} else if scheduler.checkCanceled(task) {
task.SetStatus(TaskStatusCanceled)
if task.Err() == nil {
task.SetErr(ErrTaskCanceled)
}
} else if scheduler.checkStale(task) {
task.SetStatus(TaskStatusStale)
task.SetErr(ErrTaskStale)
}
return task.Status() == TaskStatusStarted
@ -633,7 +557,7 @@ func (scheduler *taskScheduler) process(task Task) bool {
case TaskStatusSucceeded:
log.Info("task succeeded")
case TaskStatusCanceled, TaskStatusStale:
case TaskStatusCanceled:
log.Warn("failed to execute task", zap.Error(task.Err()))
default:
@ -643,6 +567,15 @@ func (scheduler *taskScheduler) process(task Task) bool {
return false
}
func (scheduler *taskScheduler) check(task Task) error {
err := task.Context().Err()
if err == nil {
err = scheduler.checkStale(task)
}
return err
}
func (scheduler *taskScheduler) RemoveByNode(node int64) {
scheduler.rwmutex.Lock()
defer scheduler.rwmutex.Unlock()
@ -670,7 +603,7 @@ func (scheduler *taskScheduler) remove(task Task) {
zap.Int64("replicaID", task.ReplicaID()),
zap.Int32("taskStatus", task.Status()),
)
task.Cancel()
task.Cancel(nil)
scheduler.tasks.Remove(task.ID())
scheduler.waitQueue.Remove(task)
scheduler.processQueue.Remove(task)
@ -692,28 +625,10 @@ func (scheduler *taskScheduler) remove(task Task) {
}
metrics.QueryCoordTaskNum.WithLabelValues().Set(float64(scheduler.tasks.Len()))
log.Debug("task removed")
log.Debug("task removed", zap.Stack("stack"))
}
func (scheduler *taskScheduler) checkCanceled(task Task) bool {
log := log.With(
zap.Int64("taskID", task.ID()),
zap.Int64("collectionID", task.CollectionID()),
zap.Int64("replicaID", task.ReplicaID()),
zap.Int64("source", task.SourceID()),
)
select {
case <-task.Context().Done():
log.Warn("the task is timeout or canceled")
return true
default:
return false
}
}
func (scheduler *taskScheduler) checkStale(task Task) bool {
func (scheduler *taskScheduler) checkStale(task Task) error {
log := log.With(
zap.Int64("taskID", task.ID()),
zap.Int64("collectionID", task.CollectionID()),
@ -723,13 +638,13 @@ func (scheduler *taskScheduler) checkStale(task Task) bool {
switch task := task.(type) {
case *SegmentTask:
if scheduler.checkSegmentTaskStale(task) {
return true
if err := scheduler.checkSegmentTaskStale(task); err != nil {
return err
}
case *ChannelTask:
if scheduler.checkChannelTaskStale(task) {
return true
if err := scheduler.checkChannelTaskStale(task); err != nil {
return err
}
default:
@ -743,14 +658,14 @@ func (scheduler *taskScheduler) checkStale(task Task) bool {
if scheduler.nodeMgr.Get(action.Node()) == nil {
log.Warn("the task is stale, the target node is offline")
return true
return merr.WrapErrNodeNotFound(action.Node())
}
}
return false
return nil
}
func (scheduler *taskScheduler) checkSegmentTaskStale(task *SegmentTask) bool {
func (scheduler *taskScheduler) checkSegmentTaskStale(task *SegmentTask) error {
log := log.With(
zap.Int64("taskID", task.ID()),
zap.Int64("collectionID", task.CollectionID()),
@ -773,18 +688,18 @@ func (scheduler *taskScheduler) checkSegmentTaskStale(task *SegmentTask) bool {
zap.Int64("segment", task.segmentID),
zap.Int32("taskType", taskType),
)
return true
return merr.WrapErrSegmentReduplicate(task.SegmentID(), "target doesn't contain this segment")
}
replica := scheduler.meta.ReplicaManager.GetByCollectionAndNode(task.CollectionID(), action.Node())
if replica == nil {
log.Warn("task stale due to replica not found")
return true
return merr.WrapErrReplicaNotFound(task.CollectionID(), "by collectionID")
}
_, ok := scheduler.distMgr.GetShardLeader(replica, segment.GetInsertChannel())
if !ok {
log.Warn("task stale due to leader not found")
return true
return merr.WrapErrChannelNotFound(segment.GetInsertChannel(), "failed to get shard delegator")
}
case ActionTypeReduce:
@ -792,10 +707,10 @@ func (scheduler *taskScheduler) checkSegmentTaskStale(task *SegmentTask) bool {
// the task should succeed if the segment does not exist
}
}
return false
return nil
}
func (scheduler *taskScheduler) checkChannelTaskStale(task *ChannelTask) bool {
func (scheduler *taskScheduler) checkChannelTaskStale(task *ChannelTask) error {
log := log.With(
zap.Int64("taskID", task.ID()),
zap.Int64("collectionID", task.CollectionID()),
@ -809,7 +724,7 @@ func (scheduler *taskScheduler) checkChannelTaskStale(task *ChannelTask) bool {
if scheduler.targetMgr.GetDmChannel(task.collectionID, task.Channel(), meta.NextTarget) == nil {
log.Warn("the task is stale, the channel to subscribe not exists in targets",
zap.String("channel", task.Channel()))
return true
return merr.WrapErrChannelReduplicate(task.Channel(), "target doesn't contain this channel")
}
case ActionTypeReduce:
@ -817,5 +732,5 @@ func (scheduler *taskScheduler) checkChannelTaskStale(task *ChannelTask) bool {
// the task should succeed if the channel does not exist
}
}
return false
return nil
}
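Across this file the bool-returning checkCanceled/checkStale pair becomes a single check that returns an error (context cancellation first, then the staleness checks), and callers hand that error straight to Task.Cancel. A self-contained sketch of the flow; the task interface below is a pared-down, illustrative slice of the real one:

package example

import "context"

// A minimal slice of the scheduler's Task interface, just enough to show the
// new check-then-Cancel(err) flow.
type task interface {
	Context() context.Context
	Cancel(err error)
}

// check mirrors taskScheduler.check: a canceled or timed-out task context wins,
// otherwise the staleness check (node offline, segment/channel gone from the
// target) decides.
func check(t task, checkStale func(task) error) error {
	if err := t.Context().Err(); err != nil {
		return err
	}
	return checkStale(t)
}

// preProcess-style caller: any failure cancels the task with that error, so
// the reason is preserved and later surfaced by Err() and the process() log.
func ensureRunnable(t task, checkStale func(task) error) {
	if err := check(t, checkStale); err != nil {
		t.Cancel(err)
	}
}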

View File

@ -37,7 +37,6 @@ const (
TaskStatusStarted
TaskStatusSucceeded
TaskStatusCanceled
TaskStatusStale
)
const (
@ -61,11 +60,10 @@ type Task interface {
Status() Status
SetStatus(status Status)
Err() error
SetErr(err error)
Priority() Priority
SetPriority(priority Priority)
Cancel()
Cancel(err error)
Wait() error
Actions() []Action
Step() int
@ -159,16 +157,21 @@ func (task *baseTask) SetPriority(priority Priority) {
}
func (task *baseTask) Err() error {
return task.err
select {
case <-task.doneCh:
return task.err
default:
return nil
}
}
func (task *baseTask) SetErr(err error) {
task.err = err
}
func (task *baseTask) Cancel() {
if task.canceled.CAS(false, true) {
func (task *baseTask) Cancel(err error) {
if task.canceled.CompareAndSwap(false, true) {
task.cancel()
if task.Status() != TaskStatusSucceeded {
task.SetStatus(TaskStatusCanceled)
}
task.err = err
close(task.doneCh)
}
}
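The hunk above also changes when the error becomes visible: Err() returns nil until doneCh is closed, and Cancel writes the error exactly once before closing the channel, which removes the need for a separate SetErr. A self-contained sketch of that pattern; miniTask and its constructor are illustrative names:

package example

import "go.uber.org/atomic"

// miniTask sketches the doneCh + CompareAndSwap pattern baseTask uses after
// this change: the error is written once, before doneCh is closed, so any
// goroutine that observes the closed channel also observes the error.
type miniTask struct {
	canceled atomic.Bool
	doneCh   chan struct{}
	err      error
}

func newMiniTask() *miniTask {
	return &miniTask{doneCh: make(chan struct{})}
}

// Cancel records the reason and finishes the task; only the first call wins.
func (t *miniTask) Cancel(err error) {
	if t.canceled.CompareAndSwap(false, true) {
		t.err = err
		close(t.doneCh)
	}
}

// Err returns nil while the task is still running, then the recorded error.
func (t *miniTask) Err() error {
	select {
	case <-t.doneCh:
		return t.err
	default:
		return nil
	}
}

// Wait blocks until the task is done and returns its error.
func (t *miniTask) Wait() error {
	<-t.doneCh
	return t.err
}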

View File

@ -174,48 +174,6 @@ func (suite *TaskSuite) BeforeTest(suiteName, testName string) {
}
}
func (suite *TaskSuite) TestSubmitDuplicateSubscribeChannelTask() {
ctx := context.Background()
timeout := 10 * time.Second
targetNode := int64(3)
tasks := []Task{}
dmChannels := make([]*datapb.VchannelInfo, 0)
for _, channel := range suite.subChannels {
dmChannels = append(dmChannels, &datapb.VchannelInfo{
CollectionID: suite.collection,
ChannelName: channel,
UnflushedSegmentIds: []int64{suite.growingSegments[channel]},
})
task, err := NewChannelTask(
ctx,
timeout,
0,
suite.collection,
suite.replica,
NewChannelAction(targetNode, ActionTypeGrow, channel),
)
suite.NoError(err)
tasks = append(tasks, task)
}
views := make([]*meta.LeaderView, 0)
for _, channel := range suite.subChannels {
views = append(views, &meta.LeaderView{
ID: targetNode,
CollectionID: suite.collection,
Channel: channel,
})
}
suite.dist.LeaderViewManager.Update(targetNode, views...)
for _, task := range tasks {
err := suite.scheduler.Add(task)
suite.Error(err)
suite.ErrorIs(err, ErrTaskAlreadyDone)
}
}
func (suite *TaskSuite) TestSubscribeChannelTask() {
ctx := context.Background()
timeout := 10 * time.Second
@ -284,10 +242,6 @@ func (suite *TaskSuite) TestSubscribeChannelTask() {
suite.dispatchAndWait(targetNode)
suite.AssertTaskNum(len(suite.subChannels), 0, len(suite.subChannels), 0)
// Other nodes' HB can't trigger the procedure of tasks
suite.dispatchAndWait(targetNode + 1)
suite.AssertTaskNum(len(suite.subChannels), 0, len(suite.subChannels), 0)
// Process tasks done
// Dist contains channels
views := make([]*meta.LeaderView, 0)
@ -354,10 +308,6 @@ func (suite *TaskSuite) TestUnsubscribeChannelTask() {
suite.dispatchAndWait(targetNode)
suite.AssertTaskNum(1, 0, 1, 0)
// Other nodes' HB can't trigger the procedure of tasks
suite.dispatchAndWait(targetNode + 1)
suite.AssertTaskNum(1, 0, 1, 0)
// Update dist
suite.dist.LeaderViewManager.Update(targetNode)
suite.dispatchAndWait(targetNode)
@ -433,10 +383,6 @@ func (suite *TaskSuite) TestLoadSegmentTask() {
suite.dispatchAndWait(targetNode)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
// Other nodes' HB can't trigger the procedure of tasks
suite.dispatchAndWait(targetNode + 1)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
// Process tasks done
// Dist contains channels
view := &meta.LeaderView{
@ -457,47 +403,6 @@ func (suite *TaskSuite) TestLoadSegmentTask() {
}
}
func (suite *TaskSuite) TestSubmitDuplicateLoadSegmentTask() {
ctx := context.Background()
timeout := 10 * time.Second
targetNode := int64(3)
channel := &datapb.VchannelInfo{
CollectionID: suite.collection,
ChannelName: Params.CommonCfg.RootCoordDml.GetValue() + "-test",
}
tasks := []Task{}
for _, segment := range suite.loadSegments {
task, err := NewSegmentTask(
ctx,
timeout,
0,
suite.collection,
suite.replica,
NewSegmentAction(targetNode, ActionTypeGrow, channel.GetChannelName(), segment),
)
suite.NoError(err)
tasks = append(tasks, task)
}
// Process tasks done
// Dist contains channels
view := &meta.LeaderView{
ID: targetNode,
CollectionID: suite.collection,
Segments: map[int64]*querypb.SegmentDist{},
}
for _, segment := range suite.loadSegments {
view.Segments[segment] = &querypb.SegmentDist{NodeID: targetNode, Version: 0}
}
suite.dist.LeaderViewManager.Update(targetNode, view)
for _, task := range tasks {
err := suite.scheduler.Add(task)
suite.Error(err)
suite.ErrorIs(err, ErrTaskAlreadyDone)
}
}
func (suite *TaskSuite) TestLoadSegmentTaskFailed() {
ctx := context.Background()
timeout := 10 * time.Second
@ -559,10 +464,6 @@ func (suite *TaskSuite) TestLoadSegmentTaskFailed() {
suite.dispatchAndWait(targetNode)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
// Other nodes' HB can't trigger the procedure of tasks
suite.dispatchAndWait(targetNode + 1)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
// Process tasks done
// Dist contains channels
time.Sleep(timeout)
@ -629,10 +530,6 @@ func (suite *TaskSuite) TestReleaseSegmentTask() {
suite.dispatchAndWait(targetNode)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
// Other nodes' HB can't trigger the procedure of tasks
suite.dispatchAndWait(targetNode + 1)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
// Process tasks done
suite.dist.LeaderViewManager.Update(targetNode)
suite.dispatchAndWait(targetNode)
@ -684,10 +581,6 @@ func (suite *TaskSuite) TestReleaseGrowingSegmentTask() {
suite.dispatchAndWait(targetNode)
suite.AssertTaskNum(segmentsNum-1, 0, 0, segmentsNum-1)
// Other nodes' HB can't trigger the procedure of tasks
suite.dispatchAndWait(targetNode + 1)
suite.AssertTaskNum(segmentsNum-1, 0, 0, segmentsNum-1)
// Release done
suite.dist.LeaderViewManager.Update(targetNode)
@ -784,10 +677,6 @@ func (suite *TaskSuite) TestMoveSegmentTask() {
suite.dispatchAndWait(leader)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
// Other nodes' HB can't trigger the procedure of tasks
suite.dispatchAndWait(-1)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
// Process tasks, target node contains the segment
view = view.Clone()
for _, segment := range suite.moveSegments {
@ -870,13 +759,9 @@ func (suite *TaskSuite) TestTaskCanceled() {
suite.dispatchAndWait(targetNode)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
// Other nodes' HB can't trigger the procedure of tasks
suite.dispatchAndWait(targetNode + 1)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
// Cancel all tasks
for _, task := range tasks {
task.Cancel()
task.Cancel(errors.New("mock error"))
}
suite.dispatchAndWait(targetNode)
@ -953,10 +838,6 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
suite.dispatchAndWait(targetNode)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
// Other nodes' HB can't trigger the procedure of tasks
suite.dispatchAndWait(targetNode + 1)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
// Process tasks done
// Dist contains channels, first task stale
view := &meta.LeaderView{
@ -982,8 +863,8 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
for i, task := range tasks {
if i == 0 {
suite.Equal(TaskStatusStale, task.Status())
suite.ErrorIs(ErrTaskStale, task.Err())
suite.Equal(TaskStatusCanceled, task.Status())
suite.Error(task.Err())
} else {
suite.Equal(TaskStatusSucceeded, task.Status())
suite.NoError(task.Err())
@ -1025,10 +906,10 @@ func (suite *TaskSuite) TestChannelTaskReplace() {
suite.NoError(err)
task.SetPriority(TaskPriorityNormal)
err = suite.scheduler.Add(task)
suite.ErrorIs(err, ErrConflictTaskExisted)
suite.Error(err)
task.SetPriority(TaskPriorityLow)
err = suite.scheduler.Add(task)
suite.ErrorIs(err, ErrConflictTaskExisted)
suite.Error(err)
}
// Replace the task with one with higher priority
@ -1117,10 +998,10 @@ func (suite *TaskSuite) TestSegmentTaskReplace() {
suite.NoError(err)
task.SetPriority(TaskPriorityNormal)
err = suite.scheduler.Add(task)
suite.ErrorIs(err, ErrConflictTaskExisted)
suite.Error(err)
task.SetPriority(TaskPriorityLow)
err = suite.scheduler.Add(task)
suite.ErrorIs(err, ErrConflictTaskExisted)
suite.Error(err)
}
// Replace the task with one with higher priority
@ -1188,10 +1069,6 @@ func (suite *TaskSuite) TestNoExecutor() {
suite.dispatchAndWait(targetNode)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
// Other nodes' HB can't trigger the procedure of tasks
suite.dispatchAndWait(targetNode + 1)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
// Process tasks done
// Dist contains channels
view := &meta.LeaderView{
@ -1204,12 +1081,7 @@ func (suite *TaskSuite) TestNoExecutor() {
}
suite.dist.LeaderViewManager.Update(targetNode, view)
suite.dispatchAndWait(targetNode)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)
for _, task := range tasks {
suite.Equal(TaskStatusStarted, task.Status())
suite.NoError(task.Err())
}
suite.AssertTaskNum(0, 0, 0, 0)
}
func (suite *TaskSuite) AssertTaskNum(process, wait, channel, segment int) {

View File

@ -167,7 +167,7 @@ func packLoadMeta(loadType querypb.LoadType, collectionID int64, partitions ...i
}
}
func packSubDmChannelRequest(
func packSubChannelRequest(
task *ChannelTask,
action Action,
schema *schemapb.CollectionSchema,
@ -189,7 +189,7 @@ func packSubDmChannelRequest(
}
}
func fillSubDmChannelRequest(
func fillSubChannelRequest(
ctx context.Context,
req *querypb.WatchDmChannelsRequest,
broker meta.Broker,

View File

@ -53,6 +53,7 @@ var (
ErrServiceUnavailable = newMilvusError("service unavailable", 2, true)
ErrServiceMemoryLimitExceeded = newMilvusError("memory limit exceeded", 3, false)
ErrServiceRequestLimitExceeded = newMilvusError("request limit exceeded", 4, true)
ErrServiceInternal = newMilvusError("service internal error", 5, false) // Never return this error out of Milvus
// Collection related
ErrCollectionNotFound = newMilvusError("collection not found", 100, false)
@ -69,12 +70,15 @@ var (
ErrReplicaNotFound = newMilvusError("replica not found", 400, false)
// Channel related
ErrChannelNotFound = newMilvusError("channel not found", 500, false)
ErrChannelNotFound = newMilvusError("channel not found", 500, false)
ErrChannelLack = newMilvusError("channel lacks", 501, false)
ErrChannelReduplicate = newMilvusError("channel reduplicates", 502, false)
// Segment related
ErrSegmentNotFound = newMilvusError("segment not found", 600, false)
ErrSegmentNotLoaded = newMilvusError("segment not loaded", 601, false)
ErrSegmentLack = newMilvusError("segment lacks", 602, false)
ErrSegmentNotFound = newMilvusError("segment not found", 600, false)
ErrSegmentNotLoaded = newMilvusError("segment not loaded", 601, false)
ErrSegmentLack = newMilvusError("segment lacks", 602, false)
ErrSegmentReduplicate = newMilvusError("segment reduplicates", 603, false)
// Index related
ErrIndexNotFound = newMilvusError("index not found", 700, false)

View File

@ -63,6 +63,7 @@ func (s *ErrSuite) TestWrap() {
s.ErrorIs(WrapErrServiceUnavailable("test", "test init"), ErrServiceUnavailable)
s.ErrorIs(WrapErrServiceMemoryLimitExceeded(110, 100, "MLE"), ErrServiceMemoryLimitExceeded)
s.ErrorIs(WrapErrServiceRequestLimitExceeded(100, "too many requests"), ErrServiceRequestLimitExceeded)
s.ErrorIs(WrapErrServiceInternal("never throw out"), ErrServiceInternal)
// Collection related
s.ErrorIs(WrapErrCollectionNotFound("test_collection", "failed to get collection"), ErrCollectionNotFound)
@ -80,11 +81,14 @@ func (s *ErrSuite) TestWrap() {
// Channel related
s.ErrorIs(WrapErrChannelNotFound("test_Channel", "failed to get Channel"), ErrChannelNotFound)
s.ErrorIs(WrapErrChannelLack("test_Channel", "failed to get Channel"), ErrChannelLack)
s.ErrorIs(WrapErrChannelReduplicate("test_Channel", "failed to get Channel"), ErrChannelReduplicate)
// Segment related
s.ErrorIs(WrapErrSegmentNotFound(1, "failed to get Segment"), ErrSegmentNotFound)
s.ErrorIs(WrapErrSegmentNotLoaded(1, "failed to query"), ErrSegmentNotLoaded)
s.ErrorIs(WrapErrSegmentLack(1, "lack of segment"), ErrSegmentLack)
s.ErrorIs(WrapErrSegmentReduplicate(1, "redundancy of segment"), ErrSegmentReduplicate)
// Index related
s.ErrorIs(WrapErrIndexNotFound("failed to get Index"), ErrIndexNotFound)

View File

@ -68,18 +68,34 @@ func Status(err error) *commonpb.Status {
return &commonpb.Status{}
}
code := Code(err)
return &commonpb.Status{
Code: Code(err),
Code: code,
Reason: err.Error(),
// Deprecated, for compatibility, set it to UnexpectedError
ErrorCode: commonpb.ErrorCode_UnexpectedError,
// Deprecated, for compatibility
ErrorCode: oldCode(code),
}
}
func oldCode(code int32) commonpb.ErrorCode {
switch code {
case ErrServiceNotReady.code():
return commonpb.ErrorCode_NotReadyServe
case ErrCollectionNotFound.code():
return commonpb.ErrorCode_CollectionNotExists
default:
return commonpb.ErrorCode_UnexpectedError
}
}
func Ok(status *commonpb.Status) bool {
return status.ErrorCode == commonpb.ErrorCode_Success && status.Code == 0
}
// Error returns an error according to the given status,
// returns nil if the status is a success status
func Error(status *commonpb.Status) error {
if status.GetCode() == 0 && status.GetErrorCode() == commonpb.ErrorCode_Success {
if Ok(status) {
return nil
}
@ -125,6 +141,13 @@ func WrapErrServiceRequestLimitExceeded(limit int32, msg ...string) error {
return err
}
func WrapErrServiceInternal(msg string, others ...string) error {
msg = strings.Join(append([]string{msg}, others...), "; ")
err := errors.Wrap(ErrServiceInternal, msg)
return err
}
// Collection related
func WrapErrCollectionNotFound(collection any, msg ...string) error {
err := wrapWithField(ErrCollectionNotFound, "collection", collection)
@ -186,6 +209,22 @@ func WrapErrChannelNotFound(name string, msg ...string) error {
return err
}
func WrapErrChannelLack(name string, msg ...string) error {
err := wrapWithField(ErrChannelLack, "channel", name)
if len(msg) > 0 {
err = errors.Wrap(err, strings.Join(msg, "; "))
}
return err
}
func WrapErrChannelReduplicate(name string, msg ...string) error {
err := wrapWithField(ErrChannelReduplicate, "channel", name)
if len(msg) > 0 {
err = errors.Wrap(err, strings.Join(msg, "; "))
}
return err
}
// Segment related
func WrapErrSegmentNotFound(id int64, msg ...string) error {
err := wrapWithField(ErrSegmentNotFound, "segment", id)
@ -211,6 +250,14 @@ func WrapErrSegmentLack(id int64, msg ...string) error {
return err
}
func WrapErrSegmentReduplicate(id int64, msg ...string) error {
err := wrapWithField(ErrSegmentReduplicate, "segment", id)
if len(msg) > 0 {
err = errors.Wrap(err, strings.Join(msg, "; "))
}
return err
}
// Index related
func WrapErrIndexNotFound(msg ...string) error {
err := error(ErrIndexNotFound)
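The new wrappers above all follow the same shape: wrapWithField attaches the key field (channel, segment, ...) to a package-level sentinel, optional messages are joined on, and the result still matches its sentinel via errors.Is, as the updated tests assert. A sketch of the full round trip through a commonpb.Status, using the helpers shown in this change; the channel name is illustrative:

package example

import (
	"github.com/cockroachdb/errors"

	"github.com/milvus-io/milvus/internal/util/merr"
)

// roundTrip wraps a sentinel, packs it into a status for the RPC response,
// then recovers an error on the receiving side with merr.Error.
func roundTrip() (matchesSentinel bool, recovered error) {
	err := merr.WrapErrChannelNotFound("by-dev-rootcoord-dml_0", "shard delegator not found")

	// errors.Is still sees through the wrapping on the sending side
	matchesSentinel = errors.Is(err, merr.ErrChannelNotFound)

	// Status carries the new error code plus a legacy ErrorCode (via oldCode),
	// so older clients keep working.
	status := merr.Status(err)
	if merr.Ok(status) {
		return matchesSentinel, nil
	}
	recovered = merr.Error(status) // non-nil, since the status is not a success
	return matchesSentinel, recovered
}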