603 lines
17 KiB
603 lines
17 KiB
package executor
import (
icontext "github.com/influxdata/influxdb/v2/context"
tracetest "github.com/influxdata/influxdb/v2/kit/tracing/testing"
// TestMain runs the package's tests with in-memory tracing set up under the
// name "task_backend_tests". Presumably this installs the global tracer that
// testQuerySuccess relies on (it casts span contexts to jaeger.SpanContext) —
// confirm against tracetest.SetupInMemoryTracing.
func TestMain(m *testing.M) {
var code int
// The tests run inside a closure so the function returned by
// SetupInMemoryTracing is invoked (via defer) as soon as the closure returns,
// before `code` is consumed.
func() {
defer tracetest.SetupInMemoryTracing("task_backend_tests")()
code = m.Run()
// NOTE(review): the closure's closing lines and the use of `code`
// (presumably os.Exit(code)) are not visible in this extract — confirm.
// tes bundles the collaborators that taskExecutorSystem wires together for a
// single test.
type tes struct {
// svc is the fake async query service; tests use it to wait for and fail queries.
svc *fakeQueryService
// ex is the Executor under test.
ex *Executor
// metrics is the ExecutorMetrics value returned by NewExecutor alongside ex.
metrics *ExecutorMetrics
// i is the kv-backed service used to create tasks and look up runs.
i *kv.Service
// tcs wraps the task control service handed to the executor and records
// the run returned from FinishRun.
tcs *taskControlService
// tc holds the credentials built by createCreds (tests read OrgID and Auth).
tc testCreds
// taskExecutorSystem builds an Executor on top of an in-memory KV store, a fake
// async query service and a mock permission service, and returns it together
// with the services the tests inspect. Setup errors are reported through t
// (see require.NoError below).
func taskExecutorSystem(t *testing.T) tes {
var (
// aqs is the fake query service the tests drive; qs bridges it into the
// query-service value passed to NewExecutor.
aqs = newFakeQueryService()
qs = query.QueryServiceBridge{
AsyncQueryService: aqs,
ctx = context.Background()
logger = zaptest.NewLogger(t)
store = inmem.NewKVStore()
// Run all.Up against the empty store (presumably the full migration set)
// before any service is built on top of it.
// NOTE(review): the error-handling body of this if is not visible here.
if err := all.Up(ctx, logger, store); err != nil {
// The mock permission service returns an empty PermissionSet and no error
// for any arguments, any number of times.
ctrl := gomock.NewController(t)
ps := mock.NewMockPermissionService(ctrl)
ps.EXPECT().FindPermissionForUser(gomock.Any(), gomock.Any()).Return(influxdb.PermissionSet{}, nil).AnyTimes()
tenantStore := tenant.NewStore(store)
tenantSvc := tenant.NewService(tenantStore)
authStore, err := authorization.NewStore(store)
require.NoError(t, err)
authSvc := authorization.NewService(authStore, tenantSvc)
var (
svc = kv.NewService(logger, store, tenantSvc, kv.ServiceConfig{
FluxLanguageService: fluxlang.DefaultService,
// Wrap svc so tests can inspect the run returned by FinishRun.
tcs = &taskControlService{TaskControlService: svc}
ex, metrics = NewExecutor(zaptest.NewLogger(t), qs, ps, svc, tcs)
return tes{
svc: aqs,
ex: ex,
metrics: metrics,
i: svc,
tcs: tcs,
tc: createCreds(t, tenantSvc, tenantSvc, authSvc),
// TestTaskExecutor runs the Executor subtests. Each subtest constructs its own
// isolated system via taskExecutorSystem.
func TestTaskExecutor(t *testing.T) {
t.Run("QuerySuccess", testQuerySuccess)
t.Run("QueryFailure", testQueryFailure)
t.Run("ManualRun", testManualRun)
t.Run("ResumeRun", testResumingRun)
t.Run("WorkerLimit", testWorkerLimit)
t.Run("LimitFunc", testLimitFunc)
t.Run("Metrics", testMetrics)
t.Run("IteratorFailure", testIteratorFailure)
t.Run("ErrorHandling", testErrorHandling)
// testQuerySuccess checks the happy path of PromisedExecute:
//   - the run is persisted with the promise's ID and RunAt;
//   - the promise completes without error;
//   - the run is removed from the store afterwards;
//   - the run passed through FinishRun carries a log line with the trace ID
//     of the caller's span.
func testQuerySuccess(t *testing.T) {
tes := taskExecutorSystem(t)
var (
script = fmt.Sprintf(fmtTestScript, t.Name())
ctx = icontext.SetAuthorizer(context.Background(), tes.tc.Auth)
// Attach a span to ctx so the executor's run log can be checked for its trace ID.
// NOTE(review): no span.Finish() is visible in this extract.
span = opentracing.GlobalTracer().StartSpan("test-span")
ctx = opentracing.ContextWithSpan(ctx, span)
task, err := tes.i.CreateTask(ctx, taskmodel.TaskCreate{OrganizationID: tes.tc.OrgID, OwnerID: tes.tc.Auth.GetUserID(), Flux: script})
if err != nil {
// Scheduled for Unix 123, run at Unix 126.
promise, err := tes.ex.PromisedExecute(ctx, scheduler.ID(task.ID), time.Unix(123, 0), time.Unix(126, 0))
if err != nil {
promiseID := platform.ID(promise.ID())
run, err := tes.i.FindRunByID(context.Background(), task.ID, promiseID)
if err != nil {
if run.ID != promiseID {
t.Fatal("promise and run dont match")
if run.RunAt != time.Unix(126, 0).UTC() {
t.Fatalf("did not correctly set RunAt value, got: %v", run.RunAt)
tes.svc.WaitForQueryLive(t, script)
// NOTE(review): no call that completes the fake query, and no wait on the
// promise, is visible between WaitForQueryLive and promise.Error() — this
// extract appears to have dropped lines; confirm against the full file.
if got := promise.Error(); got != nil {
// confirm run is removed from in-mem store
run, err = tes.i.FindRunByID(context.Background(), task.ID, run.ID)
if run != nil || err == nil || !strings.Contains(err.Error(), "run not found") {
t.Fatal("run was returned when it should have been removed from kv")
// ensure the run returned by TaskControlService.FinishRun(...)
// has run logs formatted as expected
if run = tes.tcs.run; run == nil {
t.Fatal("expected run returned by FinishRun to not be nil")
// NOTE(review): the check is "at least 3" while the message says "expected 3".
if len(run.Log) < 3 {
t.Fatalf("expected 3 run logs, found %d", len(run.Log))
// The second log entry must record the caller span's trace ID and sampled flag.
sctx := span.Context().(jaeger.SpanContext)
expectedMessage := fmt.Sprintf("trace_id=%s is_sampled=true", sctx.TraceID())
if expectedMessage != run.Log[1].Message {
t.Errorf("expected %q, found %q", expectedMessage, run.Log[1].Message)
// testQueryFailure checks that when the underlying query fails, the promise
// returned by PromisedExecute reports an error.
func testQueryFailure(t *testing.T) {
tes := taskExecutorSystem(t)
script := fmt.Sprintf(fmtTestScript, t.Name())
ctx := icontext.SetAuthorizer(context.Background(), tes.tc.Auth)
task, err := tes.i.CreateTask(ctx, taskmodel.TaskCreate{OrganizationID: tes.tc.OrgID, OwnerID: tes.tc.Auth.GetUserID(), Flux: script})
if err != nil {
promise, err := tes.ex.PromisedExecute(ctx, scheduler.ID(task.ID), time.Unix(123, 0), time.Unix(126, 0))
if err != nil {
promiseID := platform.ID(promise.ID())
run, err := tes.i.FindRunByID(context.Background(), task.ID, promiseID)
if err != nil {
if run.ID != promiseID {
t.Fatal("promise and run dont match")
// Wait until the fake service sees the query, then force it to fail.
tes.svc.WaitForQueryLive(t, script)
tes.svc.FailQuery(script, errors.New("blargyblargblarg"))
// NOTE(review): any wait on promise completion is not visible in this extract.
if got := promise.Error(); got == nil {
t.Fatal("got no error when I should have")
// testManualRun checks that a run created by ForceRun is listed by ManualRuns
// and that Executor.ManualRun executes it under the same run ID.
func testManualRun(t *testing.T) {
tes := taskExecutorSystem(t)
script := fmt.Sprintf(fmtTestScript, t.Name())
ctx := icontext.SetAuthorizer(context.Background(), tes.tc.Auth)
task, err := tes.i.CreateTask(ctx, taskmodel.TaskCreate{OrganizationID: tes.tc.OrgID, OwnerID: tes.tc.Auth.GetUserID(), Flux: script})
if err != nil {
// Force a run scheduled at 123; it must show up as the only manual run.
manualRun, err := tes.i.ForceRun(ctx, task.ID, 123)
if err != nil {
mrs, err := tes.i.ManualRuns(ctx, task.ID)
if err != nil {
if len(mrs) != 1 {
t.Fatal("manual run not created by force run")
promise, err := tes.ex.ManualRun(ctx, task.ID, manualRun.ID)
if err != nil {
run, err := tes.i.FindRunByID(context.Background(), task.ID, promise.ID())
if err != nil {
// The stored run, the promise and the forced manual run share one ID.
if run.ID != promise.ID() || manualRun.ID != promise.ID() {
t.Fatal("promise and run and manual run dont match")
tes.svc.WaitForQueryLive(t, script)
// NOTE(review): completion of the fake query is not visible before this check.
if got := promise.Error(); got != nil {
// testResumingRun checks that ResumeCurrentRun picks up an existing (stalled)
// run under its original ID, and that resuming the same run twice fails with
// taskmodel.ErrRunNotFound.
func testResumingRun(t *testing.T) {
tes := taskExecutorSystem(t)
script := fmt.Sprintf(fmtTestScript, t.Name())
ctx := icontext.SetAuthorizer(context.Background(), tes.tc.Auth)
task, err := tes.i.CreateTask(ctx, taskmodel.TaskCreate{OrganizationID: tes.tc.OrgID, OwnerID: tes.tc.Auth.GetUserID(), Flux: script})
if err != nil {
// Create a run directly in the store to simulate one left behind by a previous executor.
stalledRun, err := tes.i.CreateRun(ctx, task.ID, time.Unix(123, 0), time.Unix(126, 0))
if err != nil {
promise, err := tes.ex.ResumeCurrentRun(ctx, task.ID, stalledRun.ID)
if err != nil {
// ensure that it doesn't recreate a promise
// NOTE(review): compares the sentinel with != rather than errors.Is, so a
// wrapped ErrRunNotFound would fail this check.
if _, err := tes.ex.ResumeCurrentRun(ctx, task.ID, stalledRun.ID); err != taskmodel.ErrRunNotFound {
t.Fatal("failed to error when run has already been resumed")
run, err := tes.i.FindRunByID(context.Background(), task.ID, promise.ID())
if err != nil {
// NOTE(review): message text is shared with testManualRun; here it is the
// stalled run, not a manual run, being compared.
if run.ID != promise.ID() || stalledRun.ID != promise.ID() {
t.Fatal("promise and run and manual run dont match")
tes.svc.WaitForQueryLive(t, script)
if got := promise.Error(); got != nil {
// testWorkerLimit checks that starting one execution occupies exactly one
// slot in the executor's workerLimit while its query is in flight.
func testWorkerLimit(t *testing.T) {
tes := taskExecutorSystem(t)
script := fmt.Sprintf(fmtTestScript, t.Name())
ctx := icontext.SetAuthorizer(context.Background(), tes.tc.Auth)
task, err := tes.i.CreateTask(ctx, taskmodel.TaskCreate{OrganizationID: tes.tc.OrgID, OwnerID: tes.tc.Auth.GetUserID(), Flux: script})
if err != nil {
promise, err := tes.ex.PromisedExecute(ctx, scheduler.ID(task.ID), time.Unix(123, 0), time.Unix(126, 0))
if err != nil {
// Presumably workerLimit is a buffered channel or slice holding one entry
// per active worker — confirm against Executor.
if len(tes.ex.workerLimit) != 1 {
t.Fatal("expected a worker to be started")
// Fail the query so the worker finishes and the promise reports an error.
tes.svc.WaitForQueryLive(t, script)
tes.svc.FailQuery(script, errors.New("blargyblargblarg"))
if got := promise.Error(); got == nil {
t.Fatal("got no error when I should have")
// testLimitFunc checks that the executor retries the limit function while it
// returns an error. Once it passes, the query runs, and its failure surfaces
// on the promise as taskmodel.ErrQueryError(forcedErr).
func testLimitFunc(t *testing.T) {
tes := taskExecutorSystem(t)
script := fmt.Sprintf(fmtTestScript, t.Name())
ctx := icontext.SetAuthorizer(context.Background(), tes.tc.Auth)
task, err := tes.i.CreateTask(ctx, taskmodel.TaskCreate{OrganizationID: tes.tc.OrgID, OwnerID: tes.tc.Auth.GetUserID(), Flux: script})
if err != nil {
forcedErr := errors.New("forced")
forcedQueryErr := taskmodel.ErrQueryError(forcedErr)
// NOTE(review): no line that makes the fake query fail with forcedErr is
// visible in this extract — confirm against the full file.
// count tracks how many times the limit function has been consulted; the
// limit blocks the run while count < 2.
// NOTE(review): the increment of count is not visible in this extract.
count := 0
tes.ex.SetLimitFunc(func(*taskmodel.Task, *taskmodel.Run) error {
if count < 2 {
return errors.New("not there yet")
return nil
promise, err := tes.ex.PromisedExecute(ctx, scheduler.ID(task.ID), time.Unix(123, 0), time.Unix(126, 0))
if err != nil {
// NOTE(review): if promise.Error() returns nil, got.Error() panics instead of
// producing a clear test failure; check got == nil first.
if got := promise.Error(); got.Error() != forcedQueryErr.Error() {
t.Fatal("failed to get failure from forced error")
if count != 2 {
t.Fatalf("failed to call limitFunc enough times: %d", count)
// testMetrics checks the executor's Prometheus metrics across one scheduled
// run and one manual run:
//   - task_executor_total_runs_active goes 0 -> 1 -> 0;
//   - task_executor_total_runs_complete{status="success"} reaches 1;
//   - task_executor_manual_runs_counter counts the manual run;
//   - task_executor_run_latency_seconds records at least one sample.
func testMetrics(t *testing.T) {
tes := taskExecutorSystem(t)
// NOTE(review): `metrics` is not used in the visible lines; presumably its
// collectors are registered on reg (a dropped line) — confirm.
metrics := tes.metrics
reg := prom.NewRegistry(zaptest.NewLogger(t))
mg := promtest.MustGather(t, reg)
m := promtest.MustFindMetric(t, mg, "task_executor_total_runs_active", nil)
assert.EqualValues(t, 0, *m.Gauge.Value, "unexpected number of active runs")
script := fmt.Sprintf(fmtTestScript, t.Name())
ctx := icontext.SetAuthorizer(context.Background(), tes.tc.Auth)
// NOTE(review): assert (non-fatal) is used for setup errors; on failure the
// test continues and will dereference nil task/promise. require would stop here.
task, err := tes.i.CreateTask(ctx, taskmodel.TaskCreate{OrganizationID: tes.tc.OrgID, OwnerID: tes.tc.Auth.GetUserID(), Flux: script})
assert.NoError(t, err)
promise, err := tes.ex.PromisedExecute(ctx, scheduler.ID(task.ID), time.Unix(123, 0), time.Unix(126, 0))
assert.NoError(t, err)
promiseID := promise.ID()
run, err := tes.i.FindRunByID(context.Background(), task.ID, promiseID)
assert.NoError(t, err)
assert.EqualValues(t, promiseID, run.ID, "promise and run dont match")
// While the query is live the run counts as active.
tes.svc.WaitForQueryLive(t, script)
mg = promtest.MustGather(t, reg)
m = promtest.MustFindMetric(t, mg, "task_executor_total_runs_active", nil)
assert.EqualValues(t, 1, *m.Gauge.Value, "unexpected number of active runs")
// NOTE(review): the step that completes the fake query is not visible here.
// N.B. You might think the _runs_complete and _runs_active metrics are updated atomically,
// but that's not the case. As a task run completes and is being cleaned up, there's a small
// window where it can be counted under both metrics.
// Our CI is very good at hitting this window, causing failures when we assert on the metric
// values below. We sleep a small amount before gathering metrics to avoid flaky errors.
time.Sleep(500 * time.Millisecond)
mg = promtest.MustGather(t, reg)
m = promtest.MustFindMetric(t, mg, "task_executor_total_runs_complete", map[string]string{"task_type": "", "status": "success"})
assert.EqualValues(t, 1, *m.Counter.Value, "unexpected number of successful runs")
m = promtest.MustFindMetric(t, mg, "task_executor_total_runs_active", nil)
assert.EqualValues(t, 0, *m.Gauge.Value, "unexpected number of active runs")
assert.NoError(t, promise.Error())
// manual runs metrics
mt, err := tes.i.CreateTask(ctx, taskmodel.TaskCreate{OrganizationID: tes.tc.OrgID, OwnerID: tes.tc.Auth.GetUserID(), Flux: script})
assert.NoError(t, err)
scheduledFor := int64(123)
r, err := tes.i.ForceRun(ctx, mt.ID, scheduledFor)
assert.NoError(t, err)
_, err = tes.ex.ManualRun(ctx, mt.ID, r.ID)
assert.NoError(t, err)
mg = promtest.MustGather(t, reg)
m = promtest.MustFindMetric(t, mg, "task_executor_manual_runs_counter", map[string]string{"taskID": mt.ID.String()})
assert.EqualValues(t, 1, *m.Counter.Value, "unexpected number of manual runs")
// Presumably latency is measured against wall-clock time from the Unix-126
// schedule, which is why the sum is expected to exceed 100s — confirm.
m = promtest.MustFindMetric(t, mg, "task_executor_run_latency_seconds", map[string]string{"task_type": ""})
assert.GreaterOrEqual(t, *m.Histogram.SampleCount, uint64(1), "run latency metric not found")
assert.Greater(t, *m.Histogram.SampleSum, float64(100), "run latency metric unexpectedly small")
// testIteratorFailure checks that an error while exhausting query result
// iterators is reported on the promise. The executor's worker pool is replaced
// so every new worker uses an exhaust function that always fails.
func testIteratorFailure(t *testing.T) {
tes := taskExecutorSystem(t)
// replace iterator exhaust function with one which errors
// Both compiler builders are set to NewASTCompiler so the replacement
// workers can still compile the task script.
tes.ex.workerPool = sync.Pool{New: func() interface{} {
return &worker{
e: tes.ex,
exhaustResultIterators: func(flux.Result) error {
return errors.New("something went wrong exhausting iterator")
systemBuildCompiler: NewASTCompiler,
nonSystemBuildCompiler: NewASTCompiler,
script := fmt.Sprintf(fmtTestScript, t.Name())
ctx := icontext.SetAuthorizer(context.Background(), tes.tc.Auth)
task, err := tes.i.CreateTask(ctx, taskmodel.TaskCreate{OrganizationID: tes.tc.OrgID, OwnerID: tes.tc.Auth.GetUserID(), Flux: script})
if err != nil {
promise, err := tes.ex.PromisedExecute(ctx, scheduler.ID(task.ID), time.Unix(123, 0), time.Unix(126, 0))
if err != nil {
promiseID := platform.ID(promise.ID())
run, err := tes.i.FindRunByID(context.Background(), task.ID, promiseID)
if err != nil {
if run.ID != promiseID {
t.Fatal("promise and run dont match")
tes.svc.WaitForQueryLive(t, script)
// NOTE(review): completion of the fake query is not visible before this check.
if got := promise.Error(); got == nil {
t.Fatal("got no error when I should have")
// testErrorHandling checks that a "could not find bucket" failure is counted
// in task_executor_unrecoverable_counter with errorType "internal error" for
// the task.
func testErrorHandling(t *testing.T) {
tes := taskExecutorSystem(t)
// NOTE(review): `metrics` is not used in the visible lines; presumably its
// collectors are registered on reg (a dropped line) — confirm.
metrics := tes.metrics
reg := prom.NewRegistry(zaptest.NewLogger(t))
script := fmt.Sprintf(fmtTestScript, t.Name())
ctx := icontext.SetAuthorizer(context.Background(), tes.tc.Auth)
task, err := tes.i.CreateTask(ctx, taskmodel.TaskCreate{OrganizationID: tes.tc.OrgID, OwnerID: tes.tc.Auth.GetUserID(), Flux: script, Status: "active"})
if err != nil {
// encountering a bucket not found error should log an unrecoverable error in the metrics
// NOTE(review): the line that makes the fake query fail with forcedErr, and the
// wait on `promise`, are not visible in this extract — confirm.
forcedErr := errors.New("could not find bucket")
promise, err := tes.ex.PromisedExecute(ctx, scheduler.ID(task.ID), time.Unix(123, 0), time.Unix(126, 0))
if err != nil {
mg := promtest.MustGather(t, reg)
m := promtest.MustFindMetric(t, mg, "task_executor_unrecoverable_counter", map[string]string{"taskID": task.ID.String(), "errorType": "internal error"})
if got := *m.Counter.Value; got != 1 {
t.Fatalf("expected 1 unrecoverable error, got %v", got)
// TODO (al): once user notification system is put in place, this code should be uncommented
// encountering a bucket not found error should deactivate the task
// NOTE(review): per the TODO above, the deactivation check below is meant to be
// commented out for now, but it appears as live code in this extract — confirm
// whether the full file wraps it in a comment.
inactive, err := tes.i.FindTaskByID(context.Background(), task.ID)
if err != nil {
if inactive.Status != "inactive" {
t.Fatal("expected task to be deactivated after permanent error")
// TestPromiseFailure checks that PromisedExecute on a deleted task returns an
// error and no promise, and that exactly one run is recorded for the task with
// status "failed".
func TestPromiseFailure(t *testing.T) {
tes := taskExecutorSystem(t)
var (
script = fmt.Sprintf(fmtTestScript, t.Name())
ctx = icontext.SetAuthorizer(context.Background(), tes.tc.Auth)
// NOTE(review): no span.Finish() is visible in this extract.
span = opentracing.GlobalTracer().StartSpan("test-span")
ctx = opentracing.ContextWithSpan(ctx, span)
task, err := tes.i.CreateTask(ctx, taskmodel.TaskCreate{OrganizationID: tes.tc.OrgID, OwnerID: tes.tc.Auth.GetUserID(), Flux: script})
if err != nil {
// Delete the task so that executing it must fail.
if err := tes.i.DeleteTask(ctx, task.ID); err != nil {
promise, err := tes.ex.PromisedExecute(ctx, scheduler.ID(task.ID), time.Unix(123, 0), time.Unix(126, 0))
if err == nil {
t.Fatal("failed to error on promise create")
if promise != nil {
t.Fatalf("expected no promise but received one: %+v", promise)
runs, _, err := tes.i.FindRuns(context.Background(), taskmodel.RunFilter{Task: task.ID})
if err != nil {
// NOTE(review): when len(runs) == 0 the runs[0] argument below panics with
// an index-out-of-range instead of reporting the failure; print `runs` instead.
if len(runs) != 1 {
t.Fatalf("expected 1 runs on failed promise: got: %d, %#v", len(runs), runs[0])
if runs[0].Status != "failed" {
t.Fatal("failed to set failed state")
// taskControlService wraps the task control service given to the Executor and
// records the run returned by the most recent FinishRun call, so tests can
// inspect it. It is constructed with a TaskControlService field (see
// taskExecutorSystem); that field's declaration is not visible in this extract.
type taskControlService struct {
// run is the run most recently returned by FinishRun (nil until it is called).
run *taskmodel.Run
// FinishRun requires an authorizer on ctx, delegates to the wrapped
// TaskControlService.FinishRun, stores the returned run in t.run and returns
// the delegate's result unchanged.
func (t *taskControlService) FinishRun(ctx context.Context, taskID platform.ID, runID platform.ID) (*taskmodel.Run, error) {
// ensure auth set on context
// NOTE(review): the body of the error branch below is not visible in this extract.
_, err := icontext.GetAuthorizer(ctx)
if err != nil {
t.run, err = t.TaskControlService.FinishRun(ctx, taskID, runID)
return t.run, err