fix(task): pass task's authorization to query system, if using sessions

The query system specifically expects an Authorization. When a request
comes in using a Session, use the target task's Authorization, if we are
allowed to read it, when executing a query against the system bucket.
pull/12121/head
Mark Rushakoff 2019-02-20 15:49:55 -08:00 committed by Mark Rushakoff
parent 17e318fd6d
commit f79d9cba4f
2 changed files with 646 additions and 113 deletions

View File

@ -769,6 +769,29 @@ func (h *TaskHandler) handleGetLogs(w http.ResponseWriter, r *http.Request) {
return
}
auth, err := pcontext.GetAuthorizer(ctx)
if err != nil {
err = &platform.Error{
Err: err,
Code: platform.EUnauthorized,
Msg: "failed to get authorizer",
}
EncodeError(ctx, err, w)
return
}
if k := auth.Kind(); k != platform.AuthorizationKind {
// Get the authorization for the task, if allowed.
authz, err := h.getAuthorizationForTask(ctx, req.filter.Task)
if err != nil {
EncodeError(ctx, err, w)
return
}
// We were able to access the authorizer for the task, so reassign that on the context for the rest of this call.
ctx = pcontext.SetAuthorizer(ctx, authz)
}
logs, _, err := h.TaskService.FindLogs(ctx, req.filter)
if err != nil {
err := &platform.Error{
@ -834,6 +857,29 @@ func (h *TaskHandler) handleGetRuns(w http.ResponseWriter, r *http.Request) {
return
}
auth, err := pcontext.GetAuthorizer(ctx)
if err != nil {
err = &platform.Error{
Err: err,
Code: platform.EUnauthorized,
Msg: "failed to get authorizer",
}
EncodeError(ctx, err, w)
return
}
if k := auth.Kind(); k != platform.AuthorizationKind {
// Get the authorization for the task, if allowed.
authz, err := h.getAuthorizationForTask(ctx, req.filter.Task)
if err != nil {
EncodeError(ctx, err, w)
return
}
// We were able to access the authorizer for the task, so reassign that on the context for the rest of this call.
ctx = pcontext.SetAuthorizer(ctx, authz)
}
runs, _, err := h.TaskService.FindRuns(ctx, req.filter)
if err != nil {
err := &platform.Error{
@ -1018,6 +1064,29 @@ func (h *TaskHandler) handleGetRun(w http.ResponseWriter, r *http.Request) {
return
}
auth, err := pcontext.GetAuthorizer(ctx)
if err != nil {
err = &platform.Error{
Err: err,
Code: platform.EUnauthorized,
Msg: "failed to get authorizer",
}
EncodeError(ctx, err, w)
return
}
if k := auth.Kind(); k != platform.AuthorizationKind {
// Get the authorization for the task, if allowed.
authz, err := h.getAuthorizationForTask(ctx, req.TaskID)
if err != nil {
EncodeError(ctx, err, w)
return
}
// We were able to access the authorizer for the task, so reassign that on the context for the rest of this call.
ctx = pcontext.SetAuthorizer(ctx, authz)
}
run, err := h.TaskService.FindRunByID(ctx, req.TaskID, req.RunID)
if err != nil {
err := &platform.Error{
@ -1152,6 +1221,29 @@ func (h *TaskHandler) handleRetryRun(w http.ResponseWriter, r *http.Request) {
return
}
auth, err := pcontext.GetAuthorizer(ctx)
if err != nil {
err = &platform.Error{
Err: err,
Code: platform.EUnauthorized,
Msg: "failed to get authorizer",
}
EncodeError(ctx, err, w)
return
}
if k := auth.Kind(); k != platform.AuthorizationKind {
// Get the authorization for the task, if allowed.
authz, err := h.getAuthorizationForTask(ctx, req.TaskID)
if err != nil {
EncodeError(ctx, err, w)
return
}
// We were able to access the authorizer for the task, so reassign that on the context for the rest of this call.
ctx = pcontext.SetAuthorizer(ctx, authz)
}
run, err := h.TaskService.RetryRun(ctx, req.TaskID, req.RunID)
if err != nil {
err := &platform.Error{
@ -1230,6 +1322,35 @@ func (h *TaskHandler) populateTaskCreateOrg(ctx context.Context, tc *platform.Ta
return nil
}
// getAuthorizationForTask looks up the authorization associated with taskID,
// ensuring that the authorizer on ctx is allowed to view the task and the authorization.
//
// This method returns a *platform.Error, suitable for directly passing to EncodeError.
func (h *TaskHandler) getAuthorizationForTask(ctx context.Context, taskID platform.ID) (*platform.Authorization, *platform.Error) {
// First look up the task, if we're allowed.
// This assumes h.TaskService validates access.
t, err := h.TaskService.FindTaskByID(ctx, taskID)
if err != nil {
return nil, &platform.Error{
Err: err,
Code: platform.EUnauthorized,
Msg: "task ID unknown or unauthorized",
}
}
// Explicitly check against an authorized authorization service.
authz, err := authorizer.NewAuthorizationService(h.AuthorizationService).FindAuthorizationByID(ctx, t.AuthorizationID)
if err != nil {
return nil, &platform.Error{
Err: err,
Code: platform.EUnauthorized,
Msg: "unable to access task authorization",
}
}
return authz, nil
}
// TaskService connects to Influx via HTTP using tokens to manage tasks.
type TaskService struct {
Addr string

View File

@ -377,7 +377,7 @@ func TestTaskHandler_handleGetRun(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest("GET", "http://any.url", nil)
r = r.WithContext(context.WithValue(
context.TODO(),
context.Background(),
httprouter.ParamsKey,
httprouter.Params{
{
@ -389,6 +389,7 @@ func TestTaskHandler_handleGetRun(t *testing.T) {
Value: tt.args.runID.String(),
},
}))
r = r.WithContext(pcontext.SetAuthorizer(r.Context(), &platform.Authorization{Permissions: platform.OperPermissions()}))
w := httptest.NewRecorder()
taskBackend := NewMockTaskBackend(t)
taskBackend.TaskService = tt.fields.taskService
@ -490,7 +491,7 @@ func TestTaskHandler_handleGetRuns(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest("GET", "http://any.url", nil)
r = r.WithContext(context.WithValue(
context.TODO(),
context.Background(),
httprouter.ParamsKey,
httprouter.Params{
{
@ -498,6 +499,7 @@ func TestTaskHandler_handleGetRuns(t *testing.T) {
Value: tt.args.taskID.String(),
},
}))
r = r.WithContext(pcontext.SetAuthorizer(r.Context(), &platform.Authorization{Permissions: platform.OperPermissions()}))
w := httptest.NewRecorder()
taskBackend := NewMockTaskBackend(t)
taskBackend.TaskService = tt.fields.taskService
@ -538,6 +540,9 @@ func TestTaskHandler_NotFoundStatus(t *testing.T) {
t.Fatal(err)
}
// Create a session to associate with the contexts, so authorization checks pass.
authz := &platform.Authorization{Permissions: platform.OperPermissions()}
const taskID, runID = platform.ID(0xCCCCCC), platform.ID(0xAAAAAA)
var (
@ -763,7 +768,9 @@ func TestTaskHandler_NotFoundStatus(t *testing.T) {
okPath := fmt.Sprintf(tc.pathFmt, tc.okPathArgs...)
t.Run("matching ID: "+tc.method+" "+okPath, func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(tc.method, "http://task.example/api/v2"+okPath, strings.NewReader(tc.body))
r := httptest.NewRequest(tc.method, "http://task.example/api/v2"+okPath, strings.NewReader(tc.body)).WithContext(
pcontext.SetAuthorizer(context.Background(), authz),
)
h.ServeHTTP(w, r)
@ -782,7 +789,9 @@ func TestTaskHandler_NotFoundStatus(t *testing.T) {
path := fmt.Sprintf(tc.pathFmt, nfa...)
t.Run(tc.method+" "+path, func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(tc.method, "http://task.example/api/v2"+path, strings.NewReader(tc.body))
r := httptest.NewRequest(tc.method, "http://task.example/api/v2"+path, strings.NewReader(tc.body)).WithContext(
pcontext.SetAuthorizer(context.Background(), authz),
)
h.ServeHTTP(w, r)
@ -899,40 +908,10 @@ func TestService_handlePostTaskLabel(t *testing.T) {
}
}
func TestTaskHandler_CreateTaskFromSession(t *testing.T) {
func TestTaskHandler_Sessions(t *testing.T) {
// Common setup to get a working base for using tasks.
i := inmem.NewService()
taskID := platform.ID(9)
var createdTasks []platform.TaskCreate
ts := &mock.TaskService{
CreateTaskFn: func(_ context.Context, tc platform.TaskCreate) (*platform.Task, error) {
createdTasks = append(createdTasks, tc)
// Task with fake IDs so it can be serialized.
return &platform.Task{ID: taskID, OrganizationID: 99, AuthorizationID: 999, Name: "x"}, nil
},
// Needed due to task authorization bootstrapping.
UpdateTaskFn: func(ctx context.Context, id platform.ID, tu platform.TaskUpdate) (*platform.Task, error) {
authz, err := i.FindAuthorizationByToken(ctx, tu.Token)
if err != nil {
t.Fatal(err)
}
return &platform.Task{ID: taskID, OrganizationID: 99, AuthorizationID: authz.ID, Name: "x"}, nil
},
}
h := NewTaskHandler(&TaskBackend{
Logger: zaptest.NewLogger(t),
TaskService: ts,
AuthorizationService: i,
OrganizationService: i,
UserResourceMappingService: i,
LabelService: i,
UserService: i,
BucketService: i,
})
ctx := context.Background()
// Set up user and org.
@ -965,13 +944,53 @@ func TestTaskHandler_CreateTaskFromSession(t *testing.T) {
t.Fatal(err)
}
// Create a session for use in authorizing context.
s := &platform.Session{
sessionAllPermsCtx := pcontext.SetAuthorizer(context.Background(), &platform.Session{
UserID: u.ID,
Permissions: platform.OperPermissions(),
ExpiresAt: time.Now().Add(24 * time.Hour),
})
sessionNoPermsCtx := pcontext.SetAuthorizer(context.Background(), &platform.Session{
UserID: u.ID,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
newHandler := func(t *testing.T, ts *mock.TaskService) *TaskHandler {
return NewTaskHandler(&TaskBackend{
Logger: zaptest.NewLogger(t),
TaskService: ts,
AuthorizationService: i,
OrganizationService: i,
UserResourceMappingService: i,
LabelService: i,
UserService: i,
BucketService: i,
})
}
t.Run("creating a task from a session", func(t *testing.T) {
taskID := platform.ID(9)
var createdTasks []platform.TaskCreate
ts := &mock.TaskService{
CreateTaskFn: func(_ context.Context, tc platform.TaskCreate) (*platform.Task, error) {
createdTasks = append(createdTasks, tc)
// Task with fake IDs so it can be serialized.
return &platform.Task{ID: taskID, OrganizationID: 99, AuthorizationID: 999, Name: "x"}, nil
},
// Needed due to task authorization bootstrapping.
UpdateTaskFn: func(ctx context.Context, id platform.ID, tu platform.TaskUpdate) (*platform.Task, error) {
authz, err := i.FindAuthorizationByToken(ctx, tu.Token)
if err != nil {
t.Fatal(err)
}
return &platform.Task{ID: taskID, OrganizationID: 99, AuthorizationID: authz.ID, Name: "x"}, nil
},
}
h := newHandler(t, ts)
url := "http://localhost:9999/api/v2/tasks"
b, err := json.Marshal(platform.TaskCreate{
Flux: `option task = {name:"x", every:1m} from(bucket:"b-src") |> range(start:-1m) |> to(bucket:"b-dst", org:"o")`,
OrganizationID: o.ID,
@ -980,10 +999,7 @@ func TestTaskHandler_CreateTaskFromSession(t *testing.T) {
t.Fatal(err)
}
sessionCtx := pcontext.SetAuthorizer(context.Background(), s)
url := fmt.Sprintf("http://localhost:9999/api/v2/tasks")
r := httptest.NewRequest("POST", url, bytes.NewReader(b)).WithContext(sessionCtx)
r := httptest.NewRequest("POST", url, bytes.NewReader(b)).WithContext(sessionAllPermsCtx)
w := httptest.NewRecorder()
h.handlePostTask(w, r)
@ -1011,12 +1027,14 @@ func TestTaskHandler_CreateTaskFromSession(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if authz.UserID != u.ID {
t.Fatalf("expected authorization to be associated with user %v, got %v", u.ID, authz.UserID)
}
if authz.OrgID != o.ID {
t.Fatalf("expected authorization to be associated with org %v, got %v", o.ID, authz.OrgID)
t.Fatalf("expected authorization to have org ID %v, got %v", o.ID, authz.OrgID)
}
if authz.UserID != u.ID {
t.Fatalf("expected authorization to have user ID %v, got %v", u.ID, authz.UserID)
}
const expDesc = `auto-generated authorization for task "x"`
if authz.Description != expDesc {
t.Fatalf("expected authorization to be created with description %q, got %q", expDesc, authz.Description)
@ -1056,4 +1074,398 @@ func TestTaskHandler_CreateTaskFromSession(t *testing.T) {
}) {
t.Fatalf("expected authorization to be allowed to read its task, but it wasn't allowed")
}
// Session without permissions should not be allowed to create task.
r = httptest.NewRequest("POST", url, bytes.NewReader(b)).WithContext(sessionNoPermsCtx)
w = httptest.NewRecorder()
h.handlePostTask(w, r)
res = w.Result()
body, err = ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusUnauthorized && res.StatusCode != http.StatusForbidden {
t.Logf("response body: %s", body)
t.Fatalf("expected status unauthorized or forbidden, got %v", res.StatusCode)
}
})
t.Run("get runs for a task", func(t *testing.T) {
// Unique authorization to associate with our fake task.
taskAuth := &platform.Authorization{OrgID: o.ID, UserID: u.ID}
if err := i.CreateAuthorization(ctx, taskAuth); err != nil {
t.Fatal(err)
}
const taskID = platform.ID(12345)
const runID = platform.ID(9876)
var findRunsCtx context.Context
ts := &mock.TaskService{
FindRunsFn: func(ctx context.Context, f platform.RunFilter) ([]*platform.Run, int, error) {
findRunsCtx = ctx
if f.Task != taskID {
t.Fatalf("expected task ID %v, got %v", taskID, f.Task)
}
return []*platform.Run{
{ID: runID, TaskID: taskID},
}, 1, nil
},
FindTaskByIDFn: func(ctx context.Context, id platform.ID) (*platform.Task, error) {
if id != taskID {
return nil, backend.ErrTaskNotFound
}
return &platform.Task{
ID: taskID,
OrganizationID: o.ID,
AuthorizationID: taskAuth.ID,
}, nil
},
}
h := newHandler(t, ts)
url := fmt.Sprintf("http://localhost:9999/api/v2/tasks/%s/runs", taskID)
valCtx := context.WithValue(sessionAllPermsCtx, httprouter.ParamsKey, httprouter.Params{{Key: "id", Value: taskID.String()}})
r := httptest.NewRequest("GET", url, nil).WithContext(valCtx)
w := httptest.NewRecorder()
h.handleGetRuns(w, r)
res := w.Result()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusOK {
t.Logf("response body: %s", body)
t.Fatalf("expected status OK, got %v", res.StatusCode)
}
// The context passed to TaskService.FindRuns must be a valid authorization (not a session).
authr, err := pcontext.GetAuthorizer(findRunsCtx)
if err != nil {
t.Fatal(err)
}
if authr.Kind() != platform.AuthorizationKind {
t.Fatalf("expected context's authorizer to be of kind %q, got %q", platform.AuthorizationKind, authr.Kind())
}
if authr.Identifier() != taskAuth.ID {
t.Fatalf("expected context's authorizer ID to be %v, got %v", taskAuth.ID, authr.Identifier())
}
// Other user without permissions on the task or authorization should be disallowed.
otherUser := &platform.User{Name: "other-" + t.Name()}
if err := i.CreateUser(ctx, otherUser); err != nil {
t.Fatal(err)
}
valCtx = pcontext.SetAuthorizer(valCtx, &platform.Session{
UserID: otherUser.ID,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
r = httptest.NewRequest("GET", url, nil).WithContext(valCtx)
w = httptest.NewRecorder()
h.handleGetRuns(w, r)
res = w.Result()
body, err = ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusUnauthorized {
t.Logf("response body: %s", body)
t.Fatalf("expected status unauthorized, got %v", res.StatusCode)
}
})
t.Run("get single run for a task", func(t *testing.T) {
// Unique authorization to associate with our fake task.
taskAuth := &platform.Authorization{OrgID: o.ID, UserID: u.ID}
if err := i.CreateAuthorization(ctx, taskAuth); err != nil {
t.Fatal(err)
}
const taskID = platform.ID(12345)
const runID = platform.ID(9876)
var findRunByIDCtx context.Context
ts := &mock.TaskService{
FindRunByIDFn: func(ctx context.Context, tid, rid platform.ID) (*platform.Run, error) {
findRunByIDCtx = ctx
if tid != taskID {
t.Fatalf("expected task ID %v, got %v", taskID, tid)
}
if rid != runID {
t.Fatalf("expected run ID %v, got %v", runID, rid)
}
return &platform.Run{ID: runID, TaskID: taskID}, nil
},
FindTaskByIDFn: func(ctx context.Context, id platform.ID) (*platform.Task, error) {
if id != taskID {
return nil, backend.ErrTaskNotFound
}
return &platform.Task{
ID: taskID,
OrganizationID: o.ID,
AuthorizationID: taskAuth.ID,
}, nil
},
}
h := newHandler(t, ts)
url := fmt.Sprintf("http://localhost:9999/api/v2/tasks/%s/runs/%s", taskID, runID)
valCtx := context.WithValue(sessionAllPermsCtx, httprouter.ParamsKey, httprouter.Params{
{Key: "id", Value: taskID.String()},
{Key: "rid", Value: runID.String()},
})
r := httptest.NewRequest("GET", url, nil).WithContext(valCtx)
w := httptest.NewRecorder()
h.handleGetRun(w, r)
res := w.Result()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusOK {
t.Logf("response body: %s", body)
t.Fatalf("expected status OK, got %v", res.StatusCode)
}
// The context passed to TaskService.FindRunByID must be a valid authorization (not a session).
authr, err := pcontext.GetAuthorizer(findRunByIDCtx)
if err != nil {
t.Fatal(err)
}
if authr.Kind() != platform.AuthorizationKind {
t.Fatalf("expected context's authorizer to be of kind %q, got %q", platform.AuthorizationKind, authr.Kind())
}
if authr.Identifier() != taskAuth.ID {
t.Fatalf("expected context's authorizer ID to be %v, got %v", taskAuth.ID, authr.Identifier())
}
// Other user without permissions on the task or authorization should be disallowed.
otherUser := &platform.User{Name: "other-" + t.Name()}
if err := i.CreateUser(ctx, otherUser); err != nil {
t.Fatal(err)
}
valCtx = pcontext.SetAuthorizer(valCtx, &platform.Session{
UserID: otherUser.ID,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
r = httptest.NewRequest("GET", url, nil).WithContext(valCtx)
w = httptest.NewRecorder()
h.handleGetRuns(w, r)
res = w.Result()
body, err = ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusUnauthorized {
t.Logf("response body: %s", body)
t.Fatalf("expected status unauthorized, got %v", res.StatusCode)
}
})
t.Run("get logs for a run", func(t *testing.T) {
// Unique authorization to associate with our fake task.
taskAuth := &platform.Authorization{OrgID: o.ID, UserID: u.ID}
if err := i.CreateAuthorization(ctx, taskAuth); err != nil {
t.Fatal(err)
}
const taskID = platform.ID(12345)
const runID = platform.ID(9876)
var findLogsCtx context.Context
ts := &mock.TaskService{
FindLogsFn: func(ctx context.Context, f platform.LogFilter) ([]*platform.Log, int, error) {
findLogsCtx = ctx
if f.Task != taskID {
t.Fatalf("expected task ID %v, got %v", taskID, f.Task)
}
if *f.Run != runID {
t.Fatalf("expected run ID %v, got %v", runID, *f.Run)
}
line := platform.Log("a log line")
return []*platform.Log{&line}, 1, nil
},
FindTaskByIDFn: func(ctx context.Context, id platform.ID) (*platform.Task, error) {
if id != taskID {
return nil, backend.ErrTaskNotFound
}
return &platform.Task{
ID: taskID,
OrganizationID: o.ID,
AuthorizationID: taskAuth.ID,
}, nil
},
}
h := newHandler(t, ts)
url := fmt.Sprintf("http://localhost:9999/api/v2/tasks/%s/runs/%s/logs", taskID, runID)
valCtx := context.WithValue(sessionAllPermsCtx, httprouter.ParamsKey, httprouter.Params{
{Key: "id", Value: taskID.String()},
{Key: "rid", Value: runID.String()},
})
r := httptest.NewRequest("GET", url, nil).WithContext(valCtx)
w := httptest.NewRecorder()
h.handleGetLogs(w, r)
res := w.Result()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusOK {
t.Logf("response body: %s", body)
t.Fatalf("expected status OK, got %v", res.StatusCode)
}
// The context passed to TaskService.FindLogs must be a valid authorization (not a session).
authr, err := pcontext.GetAuthorizer(findLogsCtx)
if err != nil {
t.Fatal(err)
}
if authr.Kind() != platform.AuthorizationKind {
t.Fatalf("expected context's authorizer to be of kind %q, got %q", platform.AuthorizationKind, authr.Kind())
}
if authr.Identifier() != taskAuth.ID {
t.Fatalf("expected context's authorizer ID to be %v, got %v", taskAuth.ID, authr.Identifier())
}
// Other user without permissions on the task or authorization should be disallowed.
otherUser := &platform.User{Name: "other-" + t.Name()}
if err := i.CreateUser(ctx, otherUser); err != nil {
t.Fatal(err)
}
valCtx = pcontext.SetAuthorizer(valCtx, &platform.Session{
UserID: otherUser.ID,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
r = httptest.NewRequest("GET", url, nil).WithContext(valCtx)
w = httptest.NewRecorder()
h.handleGetRuns(w, r)
res = w.Result()
body, err = ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusUnauthorized {
t.Logf("response body: %s", body)
t.Fatalf("expected status unauthorized, got %v", res.StatusCode)
}
})
t.Run("retry a run", func(t *testing.T) {
// Unique authorization to associate with our fake task.
taskAuth := &platform.Authorization{OrgID: o.ID, UserID: u.ID}
if err := i.CreateAuthorization(ctx, taskAuth); err != nil {
t.Fatal(err)
}
const taskID = platform.ID(12345)
const runID = platform.ID(9876)
var retryRunCtx context.Context
ts := &mock.TaskService{
RetryRunFn: func(ctx context.Context, tid, rid platform.ID) (*platform.Run, error) {
retryRunCtx = ctx
if tid != taskID {
t.Fatalf("expected task ID %v, got %v", taskID, tid)
}
if rid != runID {
t.Fatalf("expected run ID %v, got %v", runID, rid)
}
return &platform.Run{ID: 10 * runID, TaskID: taskID}, nil
},
FindTaskByIDFn: func(ctx context.Context, id platform.ID) (*platform.Task, error) {
if id != taskID {
return nil, backend.ErrTaskNotFound
}
return &platform.Task{
ID: taskID,
OrganizationID: o.ID,
AuthorizationID: taskAuth.ID,
}, nil
},
}
h := newHandler(t, ts)
url := fmt.Sprintf("http://localhost:9999/api/v2/tasks/%s/runs/%s/retry", taskID, runID)
valCtx := context.WithValue(sessionAllPermsCtx, httprouter.ParamsKey, httprouter.Params{
{Key: "id", Value: taskID.String()},
{Key: "rid", Value: runID.String()},
})
r := httptest.NewRequest("POST", url, nil).WithContext(valCtx)
w := httptest.NewRecorder()
h.handleRetryRun(w, r)
res := w.Result()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusOK {
t.Logf("response body: %s", body)
t.Fatalf("expected status OK, got %v", res.StatusCode)
}
// The context passed to TaskService.RetryRun must be a valid authorization (not a session).
authr, err := pcontext.GetAuthorizer(retryRunCtx)
if err != nil {
t.Fatal(err)
}
if authr.Kind() != platform.AuthorizationKind {
t.Fatalf("expected context's authorizer to be of kind %q, got %q", platform.AuthorizationKind, authr.Kind())
}
if authr.Identifier() != taskAuth.ID {
t.Fatalf("expected context's authorizer ID to be %v, got %v", taskAuth.ID, authr.Identifier())
}
// Other user without permissions on the task or authorization should be disallowed.
otherUser := &platform.User{Name: "other-" + t.Name()}
if err := i.CreateUser(ctx, otherUser); err != nil {
t.Fatal(err)
}
valCtx = pcontext.SetAuthorizer(valCtx, &platform.Session{
UserID: otherUser.ID,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
r = httptest.NewRequest("POST", url, nil).WithContext(valCtx)
w = httptest.NewRecorder()
h.handleGetRuns(w, r)
res = w.Result()
body, err = ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusUnauthorized {
t.Logf("response body: %s", body)
t.Fatalf("expected status unauthorized, got %v", res.StatusCode)
}
})
}