diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go index 317d9d7cc5..fc42b7e3d3 100644 --- a/internal/querycoord/task.go +++ b/internal/querycoord/task.go @@ -99,6 +99,7 @@ type task interface { getResultInfo() *commonpb.Status updateTaskProcess() elapseSpan() time.Duration + finishContext() } type baseTask struct { @@ -298,6 +299,13 @@ func (bt *baseTask) elapseSpan() time.Duration { return bt.timeRecorder.ElapseSpan() } +// finishContext calls the cancel function for the trace ctx. +func (bt *baseTask) finishContext() { + if bt.cancel != nil { + bt.cancel() + } +} + type loadCollectionTask struct { *baseTask *querypb.LoadCollectionRequest diff --git a/internal/querycoord/task_scheduler.go b/internal/querycoord/task_scheduler.go index fce8538e27..1350bc64ff 100644 --- a/internal/querycoord/task_scheduler.go +++ b/internal/querycoord/task_scheduler.go @@ -521,6 +521,10 @@ func (scheduler *TaskScheduler) processTask(t task) error { t.postExecute(ctx) }() + // bind schedular context with task context + // cancel cannot be deferred here, since child task may rely on the parent context + ctx, _ = scheduler.BindContext(ctx) + // task preExecute span.LogFields(oplog.Int64("processTask: scheduler process PreExecute", t.getTaskID())) err = t.preExecute(ctx) @@ -710,6 +714,8 @@ func (scheduler *TaskScheduler) scheduleLoop() { triggerTask.notify(nil) } } + // calling context cancel so that bind context goroutine may quit + triggerTask.finishContext() } } } @@ -887,6 +893,26 @@ func (scheduler *TaskScheduler) Close() { scheduler.wg.Wait() } +// BindContext binds input context with shceduler context. +// the result context will be canceled when either context is done. +func (scheduler *TaskScheduler) BindContext(ctx context.Context) (context.Context, context.CancelFunc) { + // use input context as parent + nCtx, cancel := context.WithCancel(ctx) + go func() { + select { + case <-scheduler.ctx.Done(): + // scheduler done + cancel() + case <-ctx.Done(): + // input ctx done + cancel() + case <-nCtx.Done(): + // result context done, cancel called to make goroutine quit + } + }() + return nCtx, cancel +} + func updateSegmentInfoFromTask(ctx context.Context, triggerTask task, meta Meta) error { segmentInfosToSave := make(map[UniqueID][]*querypb.SegmentInfo) segmentInfosToRemove := make(map[UniqueID][]*querypb.SegmentInfo) diff --git a/internal/querycoord/task_scheduler_test.go b/internal/querycoord/task_scheduler_test.go index a7a3dcfe14..97ba5f5e18 100644 --- a/internal/querycoord/task_scheduler_test.go +++ b/internal/querycoord/task_scheduler_test.go @@ -21,6 +21,7 @@ import ( "fmt" "strconv" "testing" + "time" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/log" @@ -553,3 +554,50 @@ func Test_generateDerivedInternalTasks(t *testing.T) { err = removeAllSession() assert.Nil(t, err) } + +func TestTaskScheduler_BindContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + s := &TaskScheduler{ + ctx: ctx, + cancel: cancel, + } + + t.Run("normal finish", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ctx, cancel = s.BindContext(ctx) + + cancel() // normal finish + assert.Eventually(t, func() bool { + return ctx.Err() == context.Canceled + }, time.Second, time.Millisecond*10) + }) + + t.Run("input context canceled", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + nctx, ncancel := s.BindContext(ctx) + defer ncancel() + + cancel() // input context cancel + + assert.Eventually(t, func() bool { + return nctx.Err() == context.Canceled + }, time.Second, time.Millisecond*10) + + }) + + t.Run("scheduler context cancel", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + nctx, ncancel := s.BindContext(ctx) + defer ncancel() + + s.cancel() // scheduler cancel + + assert.Eventually(t, func() bool { + return nctx.Err() == context.Canceled + }, time.Second, time.Millisecond*10) + }) +}