feat(task): pass authorizer to query service
Immediately before the executor calls out to the query service, the executor loads the authorizer associated with the task, and associates that authorizer with the context used to execute the query.pull/11899/head
parent
caf08b5078
commit
d562d6bdde
|
|
@ -393,7 +393,7 @@ func (m *Launcher) run(ctx context.Context) (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
executor := taskexecutor.NewAsyncQueryServiceExecutor(m.logger.With(zap.String("service", "task-executor")), m.queryController, boltStore)
|
||||
executor := taskexecutor.NewAsyncQueryServiceExecutor(m.logger.With(zap.String("service", "task-executor")), m.queryController, authSvc, boltStore)
|
||||
|
||||
lw := taskbackend.NewPointLogWriter(pointsWriter)
|
||||
m.scheduler = taskbackend.NewScheduler(boltStore, executor, lw, time.Now().UTC().Unix(), taskbackend.WithTicker(ctx, 100*time.Millisecond), taskbackend.WithLogger(m.logger))
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ import (
|
|||
|
||||
"github.com/influxdata/flux"
|
||||
"github.com/influxdata/flux/lang"
|
||||
"github.com/influxdata/influxdb"
|
||||
icontext "github.com/influxdata/influxdb/context"
|
||||
"github.com/influxdata/influxdb/logger"
|
||||
"github.com/influxdata/influxdb/query"
|
||||
"github.com/influxdata/influxdb/task/backend"
|
||||
|
|
@ -17,7 +19,8 @@ import (
|
|||
|
||||
// queryServiceExecutor is an implementation of backend.Executor that depends on a QueryService.
|
||||
type queryServiceExecutor struct {
|
||||
svc query.QueryService
|
||||
qs query.QueryService
|
||||
as influxdb.AuthorizationService
|
||||
st backend.Store
|
||||
logger *zap.Logger
|
||||
wg sync.WaitGroup
|
||||
|
|
@ -28,17 +31,22 @@ var _ backend.Executor = (*queryServiceExecutor)(nil)
|
|||
// NewQueryServiceExecutor returns a new executor based on the given QueryService.
|
||||
// In general, you should prefer NewAsyncQueryServiceExecutor, as that code is smaller and simpler,
|
||||
// because asynchronous queries are more in line with the Executor interface.
|
||||
func NewQueryServiceExecutor(logger *zap.Logger, svc query.QueryService, st backend.Store) backend.Executor {
|
||||
return &queryServiceExecutor{logger: logger, svc: svc, st: st}
|
||||
func NewQueryServiceExecutor(logger *zap.Logger, qs query.QueryService, as influxdb.AuthorizationService, st backend.Store) backend.Executor {
|
||||
return &queryServiceExecutor{logger: logger, qs: qs, as: as, st: st}
|
||||
}
|
||||
|
||||
func (e *queryServiceExecutor) Execute(ctx context.Context, run backend.QueuedRun) (backend.RunPromise, error) {
|
||||
t, err := e.st.FindTaskByID(ctx, run.TaskID)
|
||||
t, m, err := e.st.FindTaskByIDWithMeta(ctx, run.TaskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newSyncRunPromise(ctx, run, e, t), nil
|
||||
auth, err := e.as.FindAuthorizationByID(ctx, influxdb.ID(m.AuthorizationID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newSyncRunPromise(icontext.SetAuthorizer(ctx, auth), run, e, t), nil
|
||||
}
|
||||
|
||||
func (e *queryServiceExecutor) Wait() {
|
||||
|
|
@ -48,7 +56,7 @@ func (e *queryServiceExecutor) Wait() {
|
|||
// syncRunPromise implements backend.RunPromise for a synchronous QueryService.
|
||||
type syncRunPromise struct {
|
||||
qr backend.QueuedRun
|
||||
svc query.QueryService
|
||||
qs query.QueryService
|
||||
t *backend.StoreTask
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
|
@ -69,7 +77,7 @@ func newSyncRunPromise(ctx context.Context, qr backend.QueuedRun, e *queryServic
|
|||
log, logEnd := logger.NewOperation(opLogger, "Executing task", "execute")
|
||||
rp := &syncRunPromise{
|
||||
qr: qr,
|
||||
svc: e.svc,
|
||||
qs: e.qs,
|
||||
t: t,
|
||||
logger: log,
|
||||
logEnd: logEnd,
|
||||
|
|
@ -108,7 +116,7 @@ func (p *syncRunPromise) finish(res *runResult, err error) {
|
|||
defer p.logEnd()
|
||||
|
||||
// Always cancel p's context.
|
||||
// If finish is called before p.svc.Query completes, the query will be interrupted.
|
||||
// If finish is called before p.qs.Query completes, the query will be interrupted.
|
||||
// If afterwards, then p.cancel is just a resource cleanup.
|
||||
defer p.cancel()
|
||||
|
||||
|
|
@ -140,7 +148,7 @@ func (p *syncRunPromise) doQuery(wg *sync.WaitGroup) {
|
|||
Spec: spec,
|
||||
},
|
||||
}
|
||||
it, err := p.svc.Query(p.ctx, req)
|
||||
it, err := p.qs.Query(p.ctx, req)
|
||||
if err != nil {
|
||||
// Assume the error should not be part of the runResult.
|
||||
p.finish(nil, err)
|
||||
|
|
@ -177,7 +185,8 @@ func (p *syncRunPromise) cancelOnContextDone(wg *sync.WaitGroup) {
|
|||
|
||||
// asyncQueryServiceExecutor is an implementation of backend.Executor that depends on an AsyncQueryService.
|
||||
type asyncQueryServiceExecutor struct {
|
||||
svc query.AsyncQueryService
|
||||
qs query.AsyncQueryService
|
||||
as influxdb.AuthorizationService
|
||||
st backend.Store
|
||||
logger *zap.Logger
|
||||
wg sync.WaitGroup
|
||||
|
|
@ -186,12 +195,17 @@ type asyncQueryServiceExecutor struct {
|
|||
var _ backend.Executor = (*asyncQueryServiceExecutor)(nil)
|
||||
|
||||
// NewQueryServiceExecutor returns a new executor based on the given AsyncQueryService.
|
||||
func NewAsyncQueryServiceExecutor(logger *zap.Logger, svc query.AsyncQueryService, st backend.Store) backend.Executor {
|
||||
return &asyncQueryServiceExecutor{logger: logger, svc: svc, st: st}
|
||||
func NewAsyncQueryServiceExecutor(logger *zap.Logger, qs query.AsyncQueryService, as influxdb.AuthorizationService, st backend.Store) backend.Executor {
|
||||
return &asyncQueryServiceExecutor{logger: logger, qs: qs, as: as, st: st}
|
||||
}
|
||||
|
||||
func (e *asyncQueryServiceExecutor) Execute(ctx context.Context, run backend.QueuedRun) (backend.RunPromise, error) {
|
||||
t, err := e.st.FindTaskByID(ctx, run.TaskID)
|
||||
t, m, err := e.st.FindTaskByIDWithMeta(ctx, run.TaskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
auth, err := e.as.FindAuthorizationByID(ctx, influxdb.ID(m.AuthorizationID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -207,7 +221,8 @@ func (e *asyncQueryServiceExecutor) Execute(ctx context.Context, run backend.Que
|
|||
Spec: spec,
|
||||
},
|
||||
}
|
||||
q, err := e.svc.Query(ctx, req)
|
||||
// Only set the authorizer on the context where we need it here.
|
||||
q, err := e.qs.Query(icontext.SetAuthorizer(ctx, auth), req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,11 +16,12 @@ import (
|
|||
"github.com/influxdata/flux/memory"
|
||||
"github.com/influxdata/flux/values"
|
||||
platform "github.com/influxdata/influxdb"
|
||||
icontext "github.com/influxdata/influxdb/context"
|
||||
"github.com/influxdata/influxdb/inmem"
|
||||
"github.com/influxdata/influxdb/query"
|
||||
_ "github.com/influxdata/influxdb/query/builtin"
|
||||
"github.com/influxdata/influxdb/task/backend"
|
||||
"github.com/influxdata/influxdb/task/backend/executor"
|
||||
platformtesting "github.com/influxdata/influxdb/testing"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
|
|
@ -28,6 +29,9 @@ type fakeQueryService struct {
|
|||
mu sync.Mutex
|
||||
queries map[string]*fakeQuery
|
||||
queryErr error
|
||||
// The most recent ctx received in the Query method.
|
||||
// Used to validate that the executor applied the correct authorizer.
|
||||
mostRecentCtx context.Context
|
||||
}
|
||||
|
||||
var _ query.AsyncQueryService = (*fakeQueryService)(nil)
|
||||
|
|
@ -55,6 +59,7 @@ func newFakeQueryService() *fakeQueryService {
|
|||
func (s *fakeQueryService) Query(ctx context.Context, req *query.Request) (flux.Query, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.mostRecentCtx = ctx
|
||||
if s.queryErr != nil {
|
||||
err := s.queryErr
|
||||
s.queryErr = nil
|
||||
|
|
@ -226,6 +231,9 @@ type system struct {
|
|||
svc *fakeQueryService
|
||||
st backend.Store
|
||||
ex backend.Executor
|
||||
// We really just want an authorization service here, but we take a whole inmem service
|
||||
// to ensure that the authorization service validates org and user existence properly.
|
||||
i *inmem.Service
|
||||
}
|
||||
|
||||
type createSysFn func() *system
|
||||
|
|
@ -233,17 +241,20 @@ type createSysFn func() *system
|
|||
func createAsyncSystem() *system {
|
||||
svc := newFakeQueryService()
|
||||
st := backend.NewInMemStore()
|
||||
i := inmem.NewService()
|
||||
return &system{
|
||||
name: "AsyncExecutor",
|
||||
svc: svc,
|
||||
st: st,
|
||||
ex: executor.NewAsyncQueryServiceExecutor(zap.NewNop(), svc, st),
|
||||
ex: executor.NewAsyncQueryServiceExecutor(zap.NewNop(), svc, i, st),
|
||||
i: i,
|
||||
}
|
||||
}
|
||||
|
||||
func createSyncSystem() *system {
|
||||
svc := newFakeQueryService()
|
||||
st := backend.NewInMemStore()
|
||||
i := inmem.NewService()
|
||||
return &system{
|
||||
name: "SynchronousExecutor",
|
||||
svc: svc,
|
||||
|
|
@ -253,8 +264,10 @@ func createSyncSystem() *system {
|
|||
query.QueryServiceBridge{
|
||||
AsyncQueryService: svc,
|
||||
},
|
||||
i,
|
||||
st,
|
||||
),
|
||||
i: i,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -281,15 +294,13 @@ option task = {
|
|||
from(bucket: "one") |> http.to(url: "http://example.com")`
|
||||
|
||||
func testExecutorQuerySuccess(t *testing.T, fn createSysFn) {
|
||||
var orgID = platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa")
|
||||
var userID = platformtesting.MustIDBase16("baaaaaaaaaaaaaab")
|
||||
var authzID = platformtesting.MustIDBase16("caaaaaaaaaaaaaac")
|
||||
sys := fn()
|
||||
tc := createCreds(t, sys.i)
|
||||
t.Run(sys.name+"/QuerySuccess", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
script := fmt.Sprintf(fmtTestScript, t.Name())
|
||||
tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script})
|
||||
tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -333,18 +344,25 @@ func testExecutorQuerySuccess(t *testing.T, fn createSysFn) {
|
|||
if !reflect.DeepEqual(res, res2) {
|
||||
t.Fatalf("second call to wait returned a different result: %#v", res2)
|
||||
}
|
||||
|
||||
// The query must have received the appropriate authorizer.
|
||||
qa, err := icontext.GetAuthorizer(sys.svc.mostRecentCtx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if qa.Identifier() != tc.AuthzID {
|
||||
t.Fatalf("expected query authorizer to have ID %v, got %v", tc.AuthzID, qa.Identifier())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func testExecutorQueryFailure(t *testing.T, fn createSysFn) {
|
||||
var orgID = platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa")
|
||||
var userID = platformtesting.MustIDBase16("baaaaaaaaaaaaaab")
|
||||
var authzID = platformtesting.MustIDBase16("caaaaaaaaaaaaaac")
|
||||
sys := fn()
|
||||
tc := createCreds(t, sys.i)
|
||||
t.Run(sys.name+"/QueryFail", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
script := fmt.Sprintf(fmtTestScript, t.Name())
|
||||
tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script})
|
||||
tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -368,14 +386,12 @@ func testExecutorQueryFailure(t *testing.T, fn createSysFn) {
|
|||
}
|
||||
|
||||
func testExecutorPromiseCancel(t *testing.T, fn createSysFn) {
|
||||
var orgID = platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa")
|
||||
var userID = platformtesting.MustIDBase16("baaaaaaaaaaaaaab")
|
||||
var authzID = platformtesting.MustIDBase16("caaaaaaaaaaaaaac")
|
||||
sys := fn()
|
||||
tc := createCreds(t, sys.i)
|
||||
t.Run(sys.name+"/PromiseCancel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
script := fmt.Sprintf(fmtTestScript, t.Name())
|
||||
tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script})
|
||||
tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -398,14 +414,12 @@ func testExecutorPromiseCancel(t *testing.T, fn createSysFn) {
|
|||
}
|
||||
|
||||
func testExecutorServiceError(t *testing.T, fn createSysFn) {
|
||||
var orgID = platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa")
|
||||
var userID = platformtesting.MustIDBase16("baaaaaaaaaaaaaab")
|
||||
var authzID = platformtesting.MustIDBase16("caaaaaaaaaaaaaac")
|
||||
sys := fn()
|
||||
tc := createCreds(t, sys.i)
|
||||
t.Run(sys.name+"/ServiceError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
script := fmt.Sprintf(fmtTestScript, t.Name())
|
||||
tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script})
|
||||
tid, err := sys.st.CreateTask(context.Background(), backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -440,10 +454,6 @@ func testExecutorWait(t *testing.T, createSys createSysFn) {
|
|||
// but it needs to be large-ish for slow machines running with the race detector.
|
||||
const waitCheckDelay = 100 * time.Millisecond
|
||||
|
||||
var orgID = platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa")
|
||||
var userID = platformtesting.MustIDBase16("baaaaaaaaaaaaaab")
|
||||
var authzID = platformtesting.MustIDBase16("caaaaaaaaaaaaaac")
|
||||
|
||||
// Other executor tests create a single sys and share it among subtests.
|
||||
// For this set of tests, we are testing Wait, which does not allow calling Execute concurrently,
|
||||
// so we make a new sys for each subtest.
|
||||
|
|
@ -470,12 +480,13 @@ func testExecutorWait(t *testing.T, createSys createSysFn) {
|
|||
t.Run("cancel execute context", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
sys := createSys()
|
||||
tc := createCreds(t, sys.i)
|
||||
|
||||
ctx, ctxCancel := context.WithCancel(context.Background())
|
||||
defer ctxCancel()
|
||||
|
||||
script := fmt.Sprintf(fmtTestScript, t.Name())
|
||||
tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script})
|
||||
tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -510,11 +521,12 @@ func testExecutorWait(t *testing.T, createSys createSysFn) {
|
|||
t.Run("cancel run promise", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
sys := createSys()
|
||||
tc := createCreds(t, sys.i)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
script := fmt.Sprintf(fmtTestScript, t.Name())
|
||||
tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script})
|
||||
tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -550,11 +562,12 @@ func testExecutorWait(t *testing.T, createSys createSysFn) {
|
|||
t.Run("run success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
sys := createSys()
|
||||
tc := createCreds(t, sys.i)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
script := fmt.Sprintf(fmtTestScript, t.Name())
|
||||
tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script})
|
||||
tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -590,11 +603,12 @@ func testExecutorWait(t *testing.T, createSys createSysFn) {
|
|||
t.Run("run failure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
sys := createSys()
|
||||
tc := createCreds(t, sys.i)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
script := fmt.Sprintf(fmtTestScript, t.Name())
|
||||
tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: orgID, User: userID, AuthorizationID: authzID, Script: script})
|
||||
tid, err := sys.st.CreateTask(ctx, backend.CreateTaskRequest{Org: tc.OrgID, User: tc.UserID, AuthorizationID: tc.AuthzID, Script: script})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -628,3 +642,40 @@ func testExecutorWait(t *testing.T, createSys createSysFn) {
|
|||
})
|
||||
})
|
||||
}
|
||||
|
||||
type testCreds struct {
|
||||
OrgID, UserID, AuthzID platform.ID
|
||||
}
|
||||
|
||||
func createCreds(t *testing.T, i *inmem.Service) testCreds {
|
||||
t.Helper()
|
||||
|
||||
org := &platform.Organization{Name: t.Name() + "-org"}
|
||||
if err := i.CreateOrganization(context.Background(), org); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user := &platform.User{Name: t.Name() + "-user"}
|
||||
if err := i.CreateUser(context.Background(), user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
readPerm, err := platform.NewGlobalPermission(platform.ReadAction, platform.BucketsResourceType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
writePerm, err := platform.NewGlobalPermission(platform.WriteAction, platform.BucketsResourceType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
auth := &platform.Authorization{
|
||||
OrgID: org.ID,
|
||||
UserID: user.ID,
|
||||
Permissions: []platform.Permission{*readPerm, *writePerm},
|
||||
}
|
||||
if err := i.CreateAuthorization(context.Background(), auth); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return testCreds{OrgID: org.ID, UserID: user.ID, AuthzID: auth.ID}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue