feat(task): pass authorizer to query service

Immediately before the executor calls out to the query service, the
executor loads the authorizer associated with the task, and associates
that authorizer with the context used to execute the query.
pull/11899/head
Mark Rushakoff 2019-02-07 16:21:50 -08:00 committed by Mark Rushakoff
parent caf08b5078
commit d562d6bdde
3 changed files with 107 additions and 41 deletions

View File

@ -393,7 +393,7 @@ func (m *Launcher) run(ctx context.Context) (err error) {
return err
}
executor := taskexecutor.NewAsyncQueryServiceExecutor(m.logger.With(zap.String("service", "task-executor")), m.queryController, boltStore)
executor := taskexecutor.NewAsyncQueryServiceExecutor(m.logger.With(zap.String("service", "task-executor")), m.queryController, authSvc, boltStore)
lw := taskbackend.NewPointLogWriter(pointsWriter)
m.scheduler = taskbackend.NewScheduler(boltStore, executor, lw, time.Now().UTC().Unix(), taskbackend.WithTicker(ctx, 100*time.Millisecond), taskbackend.WithLogger(m.logger))

View File

@ -9,6 +9,8 @@ import (
"github.com/influxdata/flux"
"github.com/influxdata/flux/lang"
"github.com/influxdata/influxdb"
icontext "github.com/influxdata/influxdb/context"
"github.com/influxdata/influxdb/logger"
"github.com/influxdata/influxdb/query"
"github.com/influxdata/influxdb/task/backend"
@ -17,7 +19,8 @@ import (
// queryServiceExecutor is an implementation of backend.Executor that depends on a QueryService.
type queryServiceExecutor struct {
svc query.QueryService
qs query.QueryService
as influxdb.AuthorizationService
st backend.Store
logger *zap.Logger
wg sync.WaitGroup
@ -28,17 +31,22 @@ var _ backend.Executor = (*queryServiceExecutor)(nil)
// NewQueryServiceExecutor returns a new executor based on the given QueryService.
// In general, you should prefer NewAsyncQueryServiceExecutor, as that code is smaller and simpler,
// because asynchronous queries are more in line with the Executor interface.
func NewQueryServiceExecutor(logger *zap.Logger, svc query.QueryService, st backend.Store) backend.Executor {
return &queryServiceExecutor{logger: logger, svc: svc, st: st}
func NewQueryServiceExecutor(logger *zap.Logger, qs query.QueryService, as influxdb.AuthorizationService, st backend.Store) backend.Executor {
return &queryServiceExecutor{logger: logger, qs: qs, as: as, st: st}
}
func (e *queryServiceExecutor) Execute(ctx context.Context, run backend.QueuedRun) (backend.RunPromise, error) {
t, err := e.st.FindTaskByID(ctx, run.TaskID)
t, m, err := e.st.FindTaskByIDWithMeta(ctx, run.TaskID)
if err != nil {
return nil, err
}
return newSyncRunPromise(ctx, run, e, t), nil
auth, err := e.as.FindAuthorizationByID(ctx, influxdb.ID(m.AuthorizationID))
if err != nil {
return nil, err
}
return newSyncRunPromise(icontext.SetAuthorizer(ctx, auth), run, e, t), nil
}
func (e *queryServiceExecutor) Wait() {
@ -48,7 +56,7 @@ func (e *queryServiceExecutor) Wait() {
// syncRunPromise implements backend.RunPromise for a synchronous QueryService.
type syncRunPromise struct {
qr backend.QueuedRun
svc query.QueryService
qs query.QueryService
t *backend.StoreTask
ctx context.Context
cancel context.CancelFunc
@ -69,7 +77,7 @@ func newSyncRunPromise(ctx context.Context, qr backend.QueuedRun, e *queryServic
log, logEnd := logger.NewOperation(opLogger, "Executing task", "execute")
rp := &syncRunPromise{
qr: qr,
svc: e.svc,
qs: e.qs,
t: t,
logger: log,
logEnd: logEnd,
@ -108,7 +116,7 @@ func (p *syncRunPromise) finish(res *runResult, err error) {
defer p.logEnd()
// Always cancel p's context.
// If finish is called before p.svc.Query completes, the query will be interrupted.
// If finish is called before p.qs.Query completes, the query will be interrupted.
// If afterwards, then p.cancel is just a resource cleanup.
defer p.cancel()
@ -140,7 +148,7 @@ func (p *syncRunPromise) doQuery(wg *sync.WaitGroup) {
Spec: spec,
},
}
it, err := p.svc.Query(p.ctx, req)
it, err := p.qs.Query(p.ctx, req)
if err != nil {
// Assume the error should not be part of the runResult.
p.finish(nil, err)
@ -177,7 +185,8 @@ func (p *syncRunPromise) cancelOnContextDone(wg *sync.WaitGroup) {
// asyncQueryServiceExecutor is an implementation of backend.Executor that depends on an AsyncQueryService.
type asyncQueryServiceExecutor struct {
svc query.AsyncQueryService
qs query.AsyncQueryService
as influxdb.AuthorizationService
st backend.Store
logger *zap.Logger
wg sync.WaitGroup
@ -186,12 +195,17 @@ type asyncQueryServiceExecutor struct {
var _ backend.Executor = (*asyncQueryServiceExecutor)(nil)
// NewQueryServiceExecutor returns a new executor based on the given AsyncQueryService.
func NewAsyncQueryServiceExecutor(logger *zap.Logger, svc query.AsyncQueryService, st backend.Store) backend.Executor {
return &asyncQueryServiceExecutor{logger: logger, svc: svc, st: st}
func NewAsyncQueryServiceExecutor(logger *zap.Logger, qs query.AsyncQueryService, as influxdb.AuthorizationService, st backend.Store) backend.Executor {
return &asyncQueryServiceExecutor{logger: logger, qs: qs, as: as, st: st}
}
func (e *asyncQueryServiceExecutor) Execute(ctx context.Context, run backend.QueuedRun) (backend.RunPromise, error) {
t, err := e.st.FindTaskByID(ctx, run.TaskID)
t, m, err := e.st.FindTaskByIDWithMeta(ctx, run.TaskID)
if err != nil {
return nil, err
}
auth, err := e.as.FindAuthorizationByID(ctx, influxdb.ID(m.AuthorizationID))
if err != nil {
return nil, err
}
@ -207,7 +221,8 @@ func (e *asyncQueryServiceExecutor) Execute(ctx context.Context, run backend.Que
Spec: spec,
},
}
q, err := e.svc.Query(ctx, req)
// Only set the authorizer on the context where we need it here.
q, err := e.qs.Query(icontext.SetAuthorizer(ctx, auth), req)
if err != nil {
return nil, err
}

View File

@ -16,11 +16,12 @@ import (
"github.com/influxdata/flux/memory"
"github.com/influxdata/flux/values"
platform "github.com/influxdata/influxdb"
icontext "github.com/influxdata/influxdb/context"
"github.com/influxdata/influxdb/inmem"
"github.com/influxdata/influxdb/query"
_ "github.com/influxdata/influxdb/query/builtin"
"github.com/influxdata/influxdb/task/backend"
"github.com/influxdata/influxdb/task/backend/executor"
platformtesting "github.com/influxdata/influxdb/testing"
"go.uber.org/zap"
)
@ -28,6 +29,9 @@ type fakeQueryService struct {
mu sync.Mutex
queries map[string]*fakeQuery
queryErr error
// The most recent ctx received in the Query method.
// Used to validate that the executor applied the correct authorizer.
mostRecentCtx context.Context
}
var _ query.AsyncQueryService = (*fakeQueryService)(nil)
@ -55,6 +59,7 @@ func newFakeQueryService() *fakeQueryService {
func (s *fakeQueryService) Query(ctx context.Context, req *query.Request) (flux.Query, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.mostRecentCtx = ctx
if s.queryErr != nil {
err := s.queryErr
s.queryErr = nil
@ -226,6 +231,9 @@ type system struct {
svc *fakeQueryService
st backend.Store
ex backend.Executor
// We really just want an authorization service here, but we take a whole inmem service
// to ensure that the authorization service validates org and user existence properly.
i *inmem.Service
}
type createSysFn func() *system
@ -233,17 +241,20 @@ type createSysFn func() *system
func createAsyncSystem() *system {
svc := newFakeQueryService()
st := backend.NewInMemStore()
i := inmem.NewService()
return &system{
name: "AsyncExecutor",
svc: svc,
st: st,
ex: executor.NewAsyncQueryServiceExecutor(zap.NewNop(), svc, st),
ex: executor.NewAsyncQueryServiceExecutor(zap.NewNop(), svc, i, st),
i: i,
}
}
func createSyncSystem() *system {
svc := newFakeQueryService()
st := backend.NewInMemStore()
i := inmem.NewService()
return &system{
name: "SynchronousExecutor",
svc: svc,
@ -253,8 +264,10 @@ func createSyncSystem() *system {
query.QueryServiceBridge{
AsyncQueryService: svc,
},
i,
st,
),
i: i,
}
}
@ -281,15 +294,13 @@ option task = {
from(bucket: "one") |> http.to(url: "http://example.com")`
func testExecutorQuerySuccess(t *testing.T, fn createSysFn) {
var orgID = platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa")
var userID = platformtesting.MustIDBase16("baaaaaaaaaaaaaab")
var authzID = platformtesting.MustIDBase16("caaaaaaaaaaaaaac")
sys := fn()
tc := createCreds(t, sys.i)
t.Run(sys.name+"/QuerySuccess", func(t *testing.T) {
t.Parallel()
script := fmt.Sprintf(fmtTestScript, t.Name())
tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script})
tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script})
if err != nil {
t.Fatal(err)
}
@ -333,18 +344,25 @@ func testExecutorQuerySuccess(t *testing.T, fn createSysFn) {
if !reflect.DeepEqual(res, res2) {
t.Fatalf("second call to wait returned a different result: %#v", res2)
}
// The query must have received the appropriate authorizer.
qa, err := icontext.GetAuthorizer(sys.svc.mostRecentCtx)
if err != nil {
t.Fatal(err)
}
if qa.Identifier() != tc.AuthzID {
t.Fatalf("expected query authorizer to have ID %v, got %v", tc.AuthzID, qa.Identifier())
}
})
}
func testExecutorQueryFailure(t *testing.T, fn createSysFn) {
var orgID = platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa")
var userID = platformtesting.MustIDBase16("baaaaaaaaaaaaaab")
var authzID = platformtesting.MustIDBase16("caaaaaaaaaaaaaac")
sys := fn()
tc := createCreds(t, sys.i)
t.Run(sys.name+"/QueryFail", func(t *testing.T) {
t.Parallel()
script := fmt.Sprintf(fmtTestScript, t.Name())
tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script})
tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script})
if err != nil {
t.Fatal(err)
}
@ -368,14 +386,12 @@ func testExecutorQueryFailure(t *testing.T, fn createSysFn) {
}
func testExecutorPromiseCancel(t *testing.T, fn createSysFn) {
var orgID = platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa")
var userID = platformtesting.MustIDBase16("baaaaaaaaaaaaaab")
var authzID = platformtesting.MustIDBase16("caaaaaaaaaaaaaac")
sys := fn()
tc := createCreds(t, sys.i)
t.Run(sys.name+"/PromiseCancel", func(t *testing.T) {
t.Parallel()
script := fmt.Sprintf(fmtTestScript, t.Name())
tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script})
tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script})
if err != nil {
t.Fatal(err)
}
@ -398,14 +414,12 @@ func testExecutorPromiseCancel(t *testing.T, fn createSysFn) {
}
func testExecutorServiceError(t *testing.T, fn createSysFn) {
var orgID = platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa")
var userID = platformtesting.MustIDBase16("baaaaaaaaaaaaaab")
var authzID = platformtesting.MustIDBase16("caaaaaaaaaaaaaac")
sys := fn()
tc := createCreds(t, sys.i)
t.Run(sys.name+"/ServiceError", func(t *testing.T) {
t.Parallel()
script := fmt.Sprintf(fmtTestScript, t.Name())
tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script})
tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script})
if err != nil {
t.Fatal(err)
}
@ -440,10 +454,6 @@ func testExecutorWait(t *testing.T, createSys createSysFn) {
// but it needs to be large-ish for slow machines running with the race detector.
const waitCheckDelay = 100 * time.Millisecond
var orgID = platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa")
var userID = platformtesting.MustIDBase16("baaaaaaaaaaaaaab")
var authzID = platformtesting.MustIDBase16("caaaaaaaaaaaaaac")
// Other executor tests create a single sys and share it among subtests.
// For this set of tests, we are testing Wait, which does not allow calling Execute concurrently,
// so we make a new sys for each subtest.
@ -470,12 +480,13 @@ func testExecutorWait(t *testing.T, createSys createSysFn) {
t.Run("cancel execute context", func(t *testing.T) {
t.Parallel()
sys := createSys()
tc := createCreds(t, sys.i)
ctx, ctxCancel := context.WithCancel(context.Background())
defer ctxCancel()
script := fmt.Sprintf(fmtTestScript, t.Name())
tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script})
tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script})
if err != nil {
t.Fatal(err)
}
@ -510,11 +521,12 @@ func testExecutorWait(t *testing.T, createSys createSysFn) {
t.Run("cancel run promise", func(t *testing.T) {
t.Parallel()
sys := createSys()
tc := createCreds(t, sys.i)
ctx := context.Background()
script := fmt.Sprintf(fmtTestScript, t.Name())
tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script})
tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script})
if err != nil {
t.Fatal(err)
}
@ -550,11 +562,12 @@ func testExecutorWait(t *testing.T, createSys createSysFn) {
t.Run("run success", func(t *testing.T) {
t.Parallel()
sys := createSys()
tc := createCreds(t, sys.i)
ctx := context.Background()
script := fmt.Sprintf(fmtTestScript, t.Name())
tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script})
tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script})
if err != nil {
t.Fatal(err)
}
@ -590,11 +603,12 @@ func testExecutorWait(t *testing.T, createSys createSysFn) {
t.Run("run failure", func(t *testing.T) {
t.Parallel()
sys := createSys()
tc := createCreds(t, sys.i)
ctx := context.Background()
script := fmt.Sprintf(fmtTestScript, t.Name())
tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script})
tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script})
if err != nil {
t.Fatal(err)
}
@ -628,3 +642,40 @@ func testExecutorWait(t *testing.T, createSys createSysFn) {
})
})
}
type testCreds struct {
OrgID, UserID, AuthzID platform.ID
}
func createCreds(t *testing.T, i *inmem.Service) testCreds {
t.Helper()
org := &platform.Organization{Name: t.Name() + "-org"}
if err := i.CreateOrganization(context.Background(), org); err != nil {
t.Fatal(err)
}
user := &platform.User{Name: t.Name() + "-user"}
if err := i.CreateUser(context.Background(), user); err != nil {
t.Fatal(err)
}
readPerm, err := platform.NewGlobalPermission(platform.ReadAction, platform.BucketsResourceType)
if err != nil {
t.Fatal(err)
}
writePerm, err := platform.NewGlobalPermission(platform.WriteAction, platform.BucketsResourceType)
if err != nil {
t.Fatal(err)
}
auth := &platform.Authorization{
OrgID: org.ID,
UserID: user.ID,
Permissions: []platform.Permission{*readPerm, *writePerm},
}
if err := i.CreateAuthorization(context.Background(), auth); err != nil {
t.Fatal(err)
}
return testCreds{OrgID: org.ID, UserID: user.ID, AuthzID: auth.ID}
}