diff --git a/cmd/influxd/launcher/launcher.go b/cmd/influxd/launcher/launcher.go index c0905bee9b..4f8345c5da 100644 --- a/cmd/influxd/launcher/launcher.go +++ b/cmd/influxd/launcher/launcher.go @@ -393,7 +393,7 @@ func (m *Launcher) run(ctx context.Context) (err error) { return err } - executor := taskexecutor.NewAsyncQueryServiceExecutor(m.logger.With(zap.String("service", "task-executor")), m.queryController, boltStore) + executor := taskexecutor.NewAsyncQueryServiceExecutor(m.logger.With(zap.String("service", "task-executor")), m.queryController, authSvc, boltStore) lw := taskbackend.NewPointLogWriter(pointsWriter) m.scheduler = taskbackend.NewScheduler(boltStore, executor, lw, time.Now().UTC().Unix(), taskbackend.WithTicker(ctx, 100*time.Millisecond), taskbackend.WithLogger(m.logger)) diff --git a/task/backend/executor/executor.go b/task/backend/executor/executor.go index d70a40c05f..9d11905c36 100644 --- a/task/backend/executor/executor.go +++ b/task/backend/executor/executor.go @@ -9,6 +9,8 @@ import ( "github.com/influxdata/flux" "github.com/influxdata/flux/lang" + "github.com/influxdata/influxdb" + icontext "github.com/influxdata/influxdb/context" "github.com/influxdata/influxdb/logger" "github.com/influxdata/influxdb/query" "github.com/influxdata/influxdb/task/backend" @@ -17,7 +19,8 @@ import ( // queryServiceExecutor is an implementation of backend.Executor that depends on a QueryService. type queryServiceExecutor struct { - svc query.QueryService + qs query.QueryService + as influxdb.AuthorizationService st backend.Store logger *zap.Logger wg sync.WaitGroup @@ -28,17 +31,22 @@ var _ backend.Executor = (*queryServiceExecutor)(nil) // NewQueryServiceExecutor returns a new executor based on the given QueryService. // In general, you should prefer NewAsyncQueryServiceExecutor, as that code is smaller and simpler, // because asynchronous queries are more in line with the Executor interface. -func NewQueryServiceExecutor(logger *zap.Logger, svc query.QueryService, st backend.Store) backend.Executor { - return &queryServiceExecutor{logger: logger, svc: svc, st: st} +func NewQueryServiceExecutor(logger *zap.Logger, qs query.QueryService, as influxdb.AuthorizationService, st backend.Store) backend.Executor { + return &queryServiceExecutor{logger: logger, qs: qs, as: as, st: st} } func (e *queryServiceExecutor) Execute(ctx context.Context, run backend.QueuedRun) (backend.RunPromise, error) { - t, err := e.st.FindTaskByID(ctx, run.TaskID) + t, m, err := e.st.FindTaskByIDWithMeta(ctx, run.TaskID) if err != nil { return nil, err } - return newSyncRunPromise(ctx, run, e, t), nil + auth, err := e.as.FindAuthorizationByID(ctx, influxdb.ID(m.AuthorizationID)) + if err != nil { + return nil, err + } + + return newSyncRunPromise(icontext.SetAuthorizer(ctx, auth), run, e, t), nil } func (e *queryServiceExecutor) Wait() { @@ -48,7 +56,7 @@ func (e *queryServiceExecutor) Wait() { // syncRunPromise implements backend.RunPromise for a synchronous QueryService. type syncRunPromise struct { qr backend.QueuedRun - svc query.QueryService + qs query.QueryService t *backend.StoreTask ctx context.Context cancel context.CancelFunc @@ -69,7 +77,7 @@ func newSyncRunPromise(ctx context.Context, qr backend.QueuedRun, e *queryServic log, logEnd := logger.NewOperation(opLogger, "Executing task", "execute") rp := &syncRunPromise{ qr: qr, - svc: e.svc, + qs: e.qs, t: t, logger: log, logEnd: logEnd, @@ -108,7 +116,7 @@ func (p *syncRunPromise) finish(res *runResult, err error) { defer p.logEnd() // Always cancel p's context. - // If finish is called before p.svc.Query completes, the query will be interrupted. + // If finish is called before p.qs.Query completes, the query will be interrupted. // If afterwards, then p.cancel is just a resource cleanup. defer p.cancel() @@ -140,7 +148,7 @@ func (p *syncRunPromise) doQuery(wg *sync.WaitGroup) { Spec: spec, }, } - it, err := p.svc.Query(p.ctx, req) + it, err := p.qs.Query(p.ctx, req) if err != nil { // Assume the error should not be part of the runResult. p.finish(nil, err) @@ -177,7 +185,8 @@ func (p *syncRunPromise) cancelOnContextDone(wg *sync.WaitGroup) { // asyncQueryServiceExecutor is an implementation of backend.Executor that depends on an AsyncQueryService. type asyncQueryServiceExecutor struct { - svc query.AsyncQueryService + qs query.AsyncQueryService + as influxdb.AuthorizationService st backend.Store logger *zap.Logger wg sync.WaitGroup @@ -186,12 +195,17 @@ type asyncQueryServiceExecutor struct { var _ backend.Executor = (*asyncQueryServiceExecutor)(nil) // NewQueryServiceExecutor returns a new executor based on the given AsyncQueryService. -func NewAsyncQueryServiceExecutor(logger *zap.Logger, svc query.AsyncQueryService, st backend.Store) backend.Executor { - return &asyncQueryServiceExecutor{logger: logger, svc: svc, st: st} +func NewAsyncQueryServiceExecutor(logger *zap.Logger, qs query.AsyncQueryService, as influxdb.AuthorizationService, st backend.Store) backend.Executor { + return &asyncQueryServiceExecutor{logger: logger, qs: qs, as: as, st: st} } func (e *asyncQueryServiceExecutor) Execute(ctx context.Context, run backend.QueuedRun) (backend.RunPromise, error) { - t, err := e.st.FindTaskByID(ctx, run.TaskID) + t, m, err := e.st.FindTaskByIDWithMeta(ctx, run.TaskID) + if err != nil { + return nil, err + } + + auth, err := e.as.FindAuthorizationByID(ctx, influxdb.ID(m.AuthorizationID)) if err != nil { return nil, err } @@ -207,7 +221,8 @@ func (e *asyncQueryServiceExecutor) Execute(ctx context.Context, run backend.Que Spec: spec, }, } - q, err := e.svc.Query(ctx, req) + // Only set the authorizer on the context where we need it here. + q, err := e.qs.Query(icontext.SetAuthorizer(ctx, auth), req) if err != nil { return nil, err } diff --git a/task/backend/executor/executor_test.go b/task/backend/executor/executor_test.go index dd55dca8a5..682d7f7340 100644 --- a/task/backend/executor/executor_test.go +++ b/task/backend/executor/executor_test.go @@ -16,11 +16,12 @@ import ( "github.com/influxdata/flux/memory" "github.com/influxdata/flux/values" platform "github.com/influxdata/influxdb" + icontext "github.com/influxdata/influxdb/context" + "github.com/influxdata/influxdb/inmem" "github.com/influxdata/influxdb/query" _ "github.com/influxdata/influxdb/query/builtin" "github.com/influxdata/influxdb/task/backend" "github.com/influxdata/influxdb/task/backend/executor" - platformtesting "github.com/influxdata/influxdb/testing" "go.uber.org/zap" ) @@ -28,6 +29,9 @@ type fakeQueryService struct { mu sync.Mutex queries map[string]*fakeQuery queryErr error + // The most recent ctx received in the Query method. + // Used to validate that the executor applied the correct authorizer. + mostRecentCtx context.Context } var _ query.AsyncQueryService = (*fakeQueryService)(nil) @@ -55,6 +59,7 @@ func newFakeQueryService() *fakeQueryService { func (s *fakeQueryService) Query(ctx context.Context, req *query.Request) (flux.Query, error) { s.mu.Lock() defer s.mu.Unlock() + s.mostRecentCtx = ctx if s.queryErr != nil { err := s.queryErr s.queryErr = nil @@ -226,6 +231,9 @@ type system struct { svc *fakeQueryService st backend.Store ex backend.Executor + // We really just want an authorization service here, but we take a whole inmem service + // to ensure that the authorization service validates org and user existence properly. + i *inmem.Service } type createSysFn func() *system @@ -233,17 +241,20 @@ type createSysFn func() *system func createAsyncSystem() *system { svc := newFakeQueryService() st := backend.NewInMemStore() + i := inmem.NewService() return &system{ name: "AsyncExecutor", svc: svc, st: st, - ex: executor.NewAsyncQueryServiceExecutor(zap.NewNop(), svc, st), + ex: executor.NewAsyncQueryServiceExecutor(zap.NewNop(), svc, i, st), + i: i, } } func createSyncSystem() *system { svc := newFakeQueryService() st := backend.NewInMemStore() + i := inmem.NewService() return &system{ name: "SynchronousExecutor", svc: svc, @@ -253,8 +264,10 @@ func createSyncSystem() *system { query.QueryServiceBridge{ AsyncQueryService: svc, }, + i, st, ), + i: i, } } @@ -281,15 +294,13 @@ option task = { from(bucket: "one") |> http.to(url: "http://example.com")` func testExecutorQuerySuccess(t *testing.T, fn createSysFn) { - var orgID = platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa") - var userID = platformtesting.MustIDBase16("baaaaaaaaaaaaaab") - var authzID = platformtesting.MustIDBase16("caaaaaaaaaaaaaac") sys := fn() + tc := createCreds(t, sys.i) t.Run(sys.name+"/QuerySuccess", func(t *testing.T) { t.Parallel() script := fmt.Sprintf(fmtTestScript, t.Name()) - tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script}) + tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script}) if err != nil { t.Fatal(err) } @@ -333,18 +344,25 @@ func testExecutorQuerySuccess(t *testing.T, fn createSysFn) { if !reflect.DeepEqual(res, res2) { t.Fatalf("second call to wait returned a different result: %#v", res2) } + + // The query must have received the appropriate authorizer. + qa, err := icontext.GetAuthorizer(sys.svc.mostRecentCtx) + if err != nil { + t.Fatal(err) + } + if qa.Identifier() != tc.AuthzID { + t.Fatalf("expected query authorizer to have ID %v, got %v", tc.AuthzID, qa.Identifier()) + } }) } func testExecutorQueryFailure(t *testing.T, fn createSysFn) { - var orgID = platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa") - var userID = platformtesting.MustIDBase16("baaaaaaaaaaaaaab") - var authzID = platformtesting.MustIDBase16("caaaaaaaaaaaaaac") sys := fn() + tc := createCreds(t, sys.i) t.Run(sys.name+"/QueryFail", func(t *testing.T) { t.Parallel() script := fmt.Sprintf(fmtTestScript, t.Name()) - tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script}) + tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script}) if err != nil { t.Fatal(err) } @@ -368,14 +386,12 @@ func testExecutorQueryFailure(t *testing.T, fn createSysFn) { } func testExecutorPromiseCancel(t *testing.T, fn createSysFn) { - var orgID = platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa") - var userID = platformtesting.MustIDBase16("baaaaaaaaaaaaaab") - var authzID = platformtesting.MustIDBase16("caaaaaaaaaaaaaac") sys := fn() + tc := createCreds(t, sys.i) t.Run(sys.name+"/PromiseCancel", func(t *testing.T) { t.Parallel() script := fmt.Sprintf(fmtTestScript, t.Name()) - tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script}) + tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script}) if err != nil { t.Fatal(err) } @@ -398,14 +414,12 @@ func testExecutorPromiseCancel(t *testing.T, fn createSysFn) { } func testExecutorServiceError(t *testing.T, fn createSysFn) { - var orgID = platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa") - var userID = platformtesting.MustIDBase16("baaaaaaaaaaaaaab") - var authzID = platformtesting.MustIDBase16("caaaaaaaaaaaaaac") sys := fn() + tc := createCreds(t, sys.i) t.Run(sys.name+"/ServiceError", func(t *testing.T) { t.Parallel() script := fmt.Sprintf(fmtTestScript, t.Name()) - tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script}) + tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script}) if err != nil { t.Fatal(err) } @@ -440,10 +454,6 @@ func testExecutorWait(t *testing.T, createSys createSysFn) { // but it needs to be large-ish for slow machines running with the race detector. const waitCheckDelay = 100 * time.Millisecond - var orgID = platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa") - var userID = platformtesting.MustIDBase16("baaaaaaaaaaaaaab") - var authzID = platformtesting.MustIDBase16("caaaaaaaaaaaaaac") - // Other executor tests create a single sys and share it among subtests. // For this set of tests, we are testing Wait, which does not allow calling Execute concurrently, // so we make a new sys for each subtest. @@ -470,12 +480,13 @@ func testExecutorWait(t *testing.T, createSys createSysFn) { t.Run("cancel execute context", func(t *testing.T) { t.Parallel() sys := createSys() + tc := createCreds(t, sys.i) ctx, ctxCancel := context.WithCancel(context.Background()) defer ctxCancel() script := fmt.Sprintf(fmtTestScript, t.Name()) - tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script}) + tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script}) if err != nil { t.Fatal(err) } @@ -510,11 +521,12 @@ func testExecutorWait(t *testing.T, createSys createSysFn) { t.Run("cancel run promise", func(t *testing.T) { t.Parallel() sys := createSys() + tc := createCreds(t, sys.i) ctx := context.Background() script := fmt.Sprintf(fmtTestScript, t.Name()) - tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script}) + tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script}) if err != nil { t.Fatal(err) } @@ -550,11 +562,12 @@ func testExecutorWait(t *testing.T, createSys createSysFn) { t.Run("run success", func(t *testing.T) { t.Parallel() sys := createSys() + tc := createCreds(t, sys.i) ctx := context.Background() script := fmt.Sprintf(fmtTestScript, t.Name()) - tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script}) + tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script}) if err != nil { t.Fatal(err) } @@ -590,11 +603,12 @@ func testExecutorWait(t *testing.T, createSys createSysFn) { t.Run("run failure", func(t *testing.T) { t.Parallel() sys := createSys() + tc := createCreds(t, sys.i) ctx := context.Background() script := fmt.Sprintf(fmtTestScript, t.Name()) - tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script}) + tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script}) if err != nil { t.Fatal(err) } @@ -628,3 +642,40 @@ func testExecutorWait(t *testing.T, createSys createSysFn) { }) }) } + +type testCreds struct { + OrgID, UserID, AuthzID platform.ID +} + +func createCreds(t *testing.T, i *inmem.Service) testCreds { + t.Helper() + + org := &platform.Organization{Name: t.Name() + "-org"} + if err := i.CreateOrganization(context.Background(), org); err != nil { + t.Fatal(err) + } + + user := &platform.User{Name: t.Name() + "-user"} + if err := i.CreateUser(context.Background(), user); err != nil { + t.Fatal(err) + } + + readPerm, err := platform.NewGlobalPermission(platform.ReadAction, platform.BucketsResourceType) + if err != nil { + t.Fatal(err) + } + writePerm, err := platform.NewGlobalPermission(platform.WriteAction, platform.BucketsResourceType) + if err != nil { + t.Fatal(err) + } + auth := &platform.Authorization{ + OrgID: org.ID, + UserID: user.ID, + Permissions: []platform.Permission{*readPerm, *writePerm}, + } + if err := i.CreateAuthorization(context.Background(), auth); err != nil { + t.Fatal(err) + } + + return testCreds{OrgID: org.ID, UserID: user.ID, AuthzID: auth.ID} +}