influxdb/http/task_service_test.go

1805 lines
50 KiB
Go

package http
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
platform "github.com/influxdata/influxdb"
pcontext "github.com/influxdata/influxdb/context"
"github.com/influxdata/influxdb/inmem"
"github.com/influxdata/influxdb/mock"
_ "github.com/influxdata/influxdb/query/builtin"
"github.com/influxdata/influxdb/task/backend"
platformtesting "github.com/influxdata/influxdb/testing"
"github.com/julienschmidt/httprouter"
"go.uber.org/zap"
"go.uber.org/zap/zaptest"
)
// NewMockTaskBackend returns a TaskBackend with mock services.
func NewMockTaskBackend(t *testing.T) *TaskBackend {
return &TaskBackend{
Logger: zaptest.NewLogger(t).With(zap.String("handler", "task")),
AuthorizationService: mock.NewAuthorizationService(),
TaskService: &mock.TaskService{},
OrganizationService: &mock.OrganizationService{
FindOrganizationByIDF: func(ctx context.Context, id platform.ID) (*platform.Organization, error) {
return &platform.Organization{ID: id, Name: "test"}, nil
},
FindOrganizationF: func(ctx context.Context, filter platform.OrganizationFilter) (*platform.Organization, error) {
org := &platform.Organization{}
if filter.Name != nil {
if *filter.Name == "non-existent-org" {
return nil, &platform.Error{
Err: errors.New("org not found or unauthorized"),
Msg: "org " + *filter.Name + " not found or unauthorized",
Code: platform.ENotFound,
}
}
org.Name = *filter.Name
}
if filter.ID != nil {
org.ID = *filter.ID
}
return org, nil
},
},
UserResourceMappingService: inmem.NewService(),
LabelService: mock.NewLabelService(),
UserService: mock.NewUserService(),
}
}
func TestTaskHandler_handleGetTasks(t *testing.T) {
type fields struct {
taskService platform.TaskService
labelService platform.LabelService
}
type wants struct {
statusCode int
contentType string
body string
}
tests := []struct {
name string
getParams string
fields fields
wants wants
}{
{
name: "get tasks",
fields: fields{
taskService: &mock.TaskService{
FindTasksFn: func(ctx context.Context, f platform.TaskFilter) ([]*platform.Task, int, error) {
tasks := []*platform.Task{
{
ID: 1,
Name: "task1",
Description: "A little Task",
OrganizationID: 1,
Organization: "test",
AuthorizationID: 0x100,
},
{
ID: 2,
Name: "task2",
OrganizationID: 2,
Organization: "test",
AuthorizationID: 0x200,
},
}
return tasks, len(tasks), nil
},
},
labelService: &mock.LabelService{
FindResourceLabelsFn: func(ctx context.Context, f platform.LabelMappingFilter) ([]*platform.Label, error) {
labels := []*platform.Label{
{
ID: platformtesting.MustIDBase16("fc3dc670a4be9b9a"),
Name: "label",
Properties: map[string]string{
"color": "fff000",
},
},
}
return labels, nil
},
},
},
wants: wants{
statusCode: http.StatusOK,
contentType: "application/json; charset=utf-8",
body: `
{
"links": {
"self": "/api/v2/tasks?limit=100"
},
"tasks": [
{
"links": {
"self": "/api/v2/tasks/0000000000000001",
"owners": "/api/v2/tasks/0000000000000001/owners",
"members": "/api/v2/tasks/0000000000000001/members",
"labels": "/api/v2/tasks/0000000000000001/labels",
"runs": "/api/v2/tasks/0000000000000001/runs",
"logs": "/api/v2/tasks/0000000000000001/logs"
},
"id": "0000000000000001",
"name": "task1",
"description": "A little Task",
"labels": [
{
"id": "fc3dc670a4be9b9a",
"name": "label",
"properties": {
"color": "fff000"
}
}
],
"orgID": "0000000000000001",
"org": "test",
"status": "",
"authorizationID": "0000000000000100",
"flux": ""
},
{
"links": {
"self": "/api/v2/tasks/0000000000000002",
"owners": "/api/v2/tasks/0000000000000002/owners",
"members": "/api/v2/tasks/0000000000000002/members",
"labels": "/api/v2/tasks/0000000000000002/labels",
"runs": "/api/v2/tasks/0000000000000002/runs",
"logs": "/api/v2/tasks/0000000000000002/logs"
},
"id": "0000000000000002",
"name": "task2",
"labels": [
{
"id": "fc3dc670a4be9b9a",
"name": "label",
"properties": {
"color": "fff000"
}
}
],
"orgID": "0000000000000002",
"org": "test",
"status": "",
"authorizationID": "0000000000000200",
"flux": ""
}
]
}`,
},
},
{
name: "get tasks by after and limit",
getParams: "after=0000000000000001&limit=1",
fields: fields{
taskService: &mock.TaskService{
FindTasksFn: func(ctx context.Context, f platform.TaskFilter) ([]*platform.Task, int, error) {
tasks := []*platform.Task{
{
ID: 2,
Name: "task2",
OrganizationID: 2,
Organization: "test",
AuthorizationID: 0x200,
},
}
return tasks, len(tasks), nil
},
},
labelService: &mock.LabelService{
FindResourceLabelsFn: func(ctx context.Context, f platform.LabelMappingFilter) ([]*platform.Label, error) {
labels := []*platform.Label{
{
ID: platformtesting.MustIDBase16("fc3dc670a4be9b9a"),
Name: "label",
Properties: map[string]string{
"color": "fff000",
},
},
}
return labels, nil
},
},
},
wants: wants{
statusCode: http.StatusOK,
contentType: "application/json; charset=utf-8",
body: `
{
"links": {
"self": "/api/v2/tasks?after=0000000000000001&limit=1",
"next": "/api/v2/tasks?after=0000000000000002&limit=1"
},
"tasks": [
{
"links": {
"self": "/api/v2/tasks/0000000000000002",
"owners": "/api/v2/tasks/0000000000000002/owners",
"members": "/api/v2/tasks/0000000000000002/members",
"labels": "/api/v2/tasks/0000000000000002/labels",
"runs": "/api/v2/tasks/0000000000000002/runs",
"logs": "/api/v2/tasks/0000000000000002/logs"
},
"id": "0000000000000002",
"name": "task2",
"labels": [
{
"id": "fc3dc670a4be9b9a",
"name": "label",
"properties": {
"color": "fff000"
}
}
],
"orgID": "0000000000000002",
"org": "test",
"status": "",
"authorizationID": "0000000000000200",
"flux": ""
}
]
}`,
},
},
{
name: "get tasks by org name",
getParams: "org=test2",
fields: fields{
taskService: &mock.TaskService{
FindTasksFn: func(ctx context.Context, f platform.TaskFilter) ([]*platform.Task, int, error) {
tasks := []*platform.Task{
{
ID: 2,
Name: "task2",
OrganizationID: 2,
Organization: "test2",
AuthorizationID: 0x200,
},
}
return tasks, len(tasks), nil
},
},
labelService: &mock.LabelService{
FindResourceLabelsFn: func(ctx context.Context, f platform.LabelMappingFilter) ([]*platform.Label, error) {
labels := []*platform.Label{
{
ID: platformtesting.MustIDBase16("fc3dc670a4be9b9a"),
Name: "label",
Properties: map[string]string{
"color": "fff000",
},
},
}
return labels, nil
},
},
},
wants: wants{
statusCode: http.StatusOK,
contentType: "application/json; charset=utf-8",
body: `
{
"links": {
"self": "/api/v2/tasks?limit=100&org=test2"
},
"tasks": [
{
"links": {
"self": "/api/v2/tasks/0000000000000002",
"owners": "/api/v2/tasks/0000000000000002/owners",
"members": "/api/v2/tasks/0000000000000002/members",
"labels": "/api/v2/tasks/0000000000000002/labels",
"runs": "/api/v2/tasks/0000000000000002/runs",
"logs": "/api/v2/tasks/0000000000000002/logs"
},
"id": "0000000000000002",
"name": "task2",
"labels": [
{
"id": "fc3dc670a4be9b9a",
"name": "label",
"properties": {
"color": "fff000"
}
}
],
"orgID": "0000000000000002",
"org": "test2",
"status": "",
"authorizationID": "0000000000000200",
"flux": ""
}
]
}`,
},
},
{
name: "get tasks by org name bad",
getParams: "org=non-existent-org",
fields: fields{
taskService: &mock.TaskService{
FindTasksFn: func(ctx context.Context, f platform.TaskFilter) ([]*platform.Task, int, error) {
tasks := []*platform.Task{
{
ID: 1,
Name: "task1",
OrganizationID: 1,
Organization: "test2",
AuthorizationID: 0x100,
},
{
ID: 2,
Name: "task2",
OrganizationID: 2,
Organization: "test2",
AuthorizationID: 0x200,
},
}
return tasks, len(tasks), nil
},
},
labelService: &mock.LabelService{
FindResourceLabelsFn: func(ctx context.Context, f platform.LabelMappingFilter) ([]*platform.Label, error) {
labels := []*platform.Label{
{
ID: platformtesting.MustIDBase16("fc3dc670a4be9b9a"),
Name: "label",
Properties: map[string]string{
"color": "fff000",
},
},
}
return labels, nil
},
},
},
wants: wants{
statusCode: http.StatusBadRequest,
contentType: "application/json; charset=utf-8",
body: `{
"code": "invalid",
"error": {
"code": "not found",
"error": "org not found or unauthorized",
"message": "org non-existent-org not found or unauthorized"
},
"message": "failed to decode request"
}`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest("GET", "http://any.url?"+tt.getParams, nil)
w := httptest.NewRecorder()
taskBackend := NewMockTaskBackend(t)
taskBackend.HTTPErrorHandler = ErrorHandler(0)
taskBackend.TaskService = tt.fields.taskService
taskBackend.LabelService = tt.fields.labelService
h := NewTaskHandler(taskBackend)
h.handleGetTasks(w, r)
res := w.Result()
content := res.Header.Get("Content-Type")
body, _ := ioutil.ReadAll(res.Body)
if res.StatusCode != tt.wants.statusCode {
t.Errorf("%q. handleGetTasks() = %v, want %v", tt.name, res.StatusCode, tt.wants.statusCode)
}
if tt.wants.contentType != "" && content != tt.wants.contentType {
t.Errorf("%q. handleGetTasks() = %v, want %v", tt.name, content, tt.wants.contentType)
}
if tt.wants.body != "" {
if eq, diff, err := jsonEqual(string(body), tt.wants.body); err != nil {
t.Errorf("%q, handleGetTasks(). error unmarshaling json %v", tt.name, err)
} else if !eq {
t.Errorf("%q. handleGetTasks() = ***%s***", tt.name, diff)
}
}
})
}
}
func TestTaskHandler_handlePostTasks(t *testing.T) {
type args struct {
taskCreate platform.TaskCreate
}
type fields struct {
taskService platform.TaskService
}
type wants struct {
statusCode int
contentType string
body string
}
tests := []struct {
name string
args args
fields fields
wants wants
}{
{
name: "create task",
args: args{
taskCreate: platform.TaskCreate{
OrganizationID: 1,
Token: "mytoken",
Flux: "abc",
},
},
fields: fields{
taskService: &mock.TaskService{
CreateTaskFn: func(ctx context.Context, tc platform.TaskCreate) (*platform.Task, error) {
return &platform.Task{
ID: 1,
Name: "task1",
Description: "Brand New Task",
OrganizationID: 1,
Organization: "test",
AuthorizationID: 0x100,
Flux: "abc",
}, nil
},
},
},
wants: wants{
statusCode: http.StatusCreated,
contentType: "application/json; charset=utf-8",
body: `
{
"links": {
"self": "/api/v2/tasks/0000000000000001",
"owners": "/api/v2/tasks/0000000000000001/owners",
"members": "/api/v2/tasks/0000000000000001/members",
"labels": "/api/v2/tasks/0000000000000001/labels",
"runs": "/api/v2/tasks/0000000000000001/runs",
"logs": "/api/v2/tasks/0000000000000001/logs"
},
"id": "0000000000000001",
"name": "task1",
"description": "Brand New Task",
"labels": [],
"orgID": "0000000000000001",
"org": "test",
"status": "",
"authorizationID": "0000000000000100",
"flux": "abc"
}
`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b, err := json.Marshal(tt.args.taskCreate)
if err != nil {
t.Fatalf("failed to unmarshal task: %v", err)
}
r := httptest.NewRequest("POST", "http://any.url", bytes.NewReader(b))
ctx := pcontext.SetAuthorizer(context.TODO(), new(platform.Authorization))
r = r.WithContext(ctx)
w := httptest.NewRecorder()
taskBackend := NewMockTaskBackend(t)
taskBackend.HTTPErrorHandler = ErrorHandler(0)
taskBackend.TaskService = tt.fields.taskService
h := NewTaskHandler(taskBackend)
h.handlePostTask(w, r)
res := w.Result()
content := res.Header.Get("Content-Type")
body, _ := ioutil.ReadAll(res.Body)
if res.StatusCode != tt.wants.statusCode {
t.Errorf("%q. handlePostTask() = %v, want %v", tt.name, res.StatusCode, tt.wants.statusCode)
}
if tt.wants.contentType != "" && content != tt.wants.contentType {
t.Errorf("%q. handlePostTask() = %v, want %v", tt.name, content, tt.wants.contentType)
}
if tt.wants.body != "" {
if eq, diff, err := jsonEqual(string(body), tt.wants.body); err != nil {
t.Errorf("%q, handlePostTask(). error unmarshaling json %v", tt.name, err)
} else if !eq {
t.Errorf("%q. handlePostTask() = ***%s***", tt.name, diff)
}
}
})
}
}
func TestTaskHandler_handleGetRun(t *testing.T) {
type fields struct {
taskService platform.TaskService
}
type args struct {
taskID platform.ID
runID platform.ID
}
type wants struct {
statusCode int
contentType string
body string
}
tests := []struct {
name string
fields fields
args args
wants wants
}{
{
name: "get a run by id",
fields: fields{
taskService: &mock.TaskService{
FindRunByIDFn: func(ctx context.Context, taskID platform.ID, runID platform.ID) (*platform.Run, error) {
run := platform.Run{
ID: runID,
TaskID: taskID,
Status: "success",
ScheduledFor: "2018-12-01T17:00:13Z",
StartedAt: "2018-12-01T17:00:03.155645Z",
FinishedAt: "2018-12-01T17:00:13.155645Z",
RequestedAt: "2018-12-01T17:00:13Z",
}
return &run, nil
},
},
},
args: args{
taskID: 1,
runID: 2,
},
wants: wants{
statusCode: http.StatusOK,
contentType: "application/json; charset=utf-8",
body: `
{
"links": {
"self": "/api/v2/tasks/0000000000000001/runs/0000000000000002",
"task": "/api/v2/tasks/0000000000000001",
"retry": "/api/v2/tasks/0000000000000001/runs/0000000000000002/retry",
"logs": "/api/v2/tasks/0000000000000001/runs/0000000000000002/logs"
},
"id": "0000000000000002",
"taskID": "0000000000000001",
"status": "success",
"scheduledFor": "2018-12-01T17:00:13Z",
"startedAt": "2018-12-01T17:00:03.155645Z",
"finishedAt": "2018-12-01T17:00:13.155645Z",
"requestedAt": "2018-12-01T17:00:13Z"
}`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest("GET", "http://any.url", nil)
r = r.WithContext(context.WithValue(
context.Background(),
httprouter.ParamsKey,
httprouter.Params{
{
Key: "id",
Value: tt.args.taskID.String(),
},
{
Key: "rid",
Value: tt.args.runID.String(),
},
}))
r = r.WithContext(pcontext.SetAuthorizer(r.Context(), &platform.Authorization{Permissions: platform.OperPermissions()}))
w := httptest.NewRecorder()
taskBackend := NewMockTaskBackend(t)
taskBackend.HTTPErrorHandler = ErrorHandler(0)
taskBackend.TaskService = tt.fields.taskService
h := NewTaskHandler(taskBackend)
h.handleGetRun(w, r)
res := w.Result()
content := res.Header.Get("Content-Type")
body, _ := ioutil.ReadAll(res.Body)
if res.StatusCode != tt.wants.statusCode {
t.Errorf("%q. handleGetRun() = %v, want %v", tt.name, res.StatusCode, tt.wants.statusCode)
}
if tt.wants.contentType != "" && content != tt.wants.contentType {
t.Errorf("%q. handleGetRun() = %v, want %v", tt.name, content, tt.wants.contentType)
}
if tt.wants.body != "" {
if eq, diff, err := jsonEqual(string(body), tt.wants.body); err != nil {
t.Errorf("%q, handleGetRun(). error unmarshaling json %v", tt.name, err)
} else if !eq {
t.Errorf("%q. handleGetRun() = ***%s***", tt.name, diff)
}
}
})
}
}
func TestTaskHandler_handleGetRuns(t *testing.T) {
type fields struct {
taskService platform.TaskService
}
type args struct {
taskID platform.ID
}
type wants struct {
statusCode int
contentType string
body string
}
tests := []struct {
name string
fields fields
args args
wants wants
}{
{
name: "get runs by task id",
fields: fields{
taskService: &mock.TaskService{
FindRunsFn: func(ctx context.Context, f platform.RunFilter) ([]*platform.Run, int, error) {
runs := []*platform.Run{
{
ID: platform.ID(2),
TaskID: f.Task,
Status: "success",
ScheduledFor: "2018-12-01T17:00:13Z",
StartedAt: "2018-12-01T17:00:03.155645Z",
FinishedAt: "2018-12-01T17:00:13.155645Z",
RequestedAt: "2018-12-01T17:00:13Z",
},
}
return runs, len(runs), nil
},
},
},
args: args{
taskID: 1,
},
wants: wants{
statusCode: http.StatusOK,
contentType: "application/json; charset=utf-8",
body: `
{
"links": {
"self": "/api/v2/tasks/0000000000000001/runs",
"task": "/api/v2/tasks/0000000000000001"
},
"runs": [
{
"links": {
"self": "/api/v2/tasks/0000000000000001/runs/0000000000000002",
"task": "/api/v2/tasks/0000000000000001",
"retry": "/api/v2/tasks/0000000000000001/runs/0000000000000002/retry",
"logs": "/api/v2/tasks/0000000000000001/runs/0000000000000002/logs"
},
"id": "0000000000000002",
"taskID": "0000000000000001",
"status": "success",
"scheduledFor": "2018-12-01T17:00:13Z",
"startedAt": "2018-12-01T17:00:03.155645Z",
"finishedAt": "2018-12-01T17:00:13.155645Z",
"requestedAt": "2018-12-01T17:00:13Z"
}
]
}`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest("GET", "http://any.url", nil)
r = r.WithContext(context.WithValue(
context.Background(),
httprouter.ParamsKey,
httprouter.Params{
{
Key: "id",
Value: tt.args.taskID.String(),
},
}))
r = r.WithContext(pcontext.SetAuthorizer(r.Context(), &platform.Authorization{Permissions: platform.OperPermissions()}))
w := httptest.NewRecorder()
taskBackend := NewMockTaskBackend(t)
taskBackend.HTTPErrorHandler = ErrorHandler(0)
taskBackend.TaskService = tt.fields.taskService
h := NewTaskHandler(taskBackend)
h.handleGetRuns(w, r)
res := w.Result()
content := res.Header.Get("Content-Type")
body, _ := ioutil.ReadAll(res.Body)
if res.StatusCode != tt.wants.statusCode {
t.Errorf("%q. handleGetRuns() = %v, want %v", tt.name, res.StatusCode, tt.wants.statusCode)
}
if tt.wants.contentType != "" && content != tt.wants.contentType {
t.Errorf("%q. handleGetRuns() = %v, want %v", tt.name, content, tt.wants.contentType)
}
if tt.wants.body != "" {
if eq, diff, err := jsonEqual(string(body), tt.wants.body); err != nil {
t.Errorf("%q, handleGetRuns(). error unmarshaling json %v", tt.name, err)
} else if !eq {
t.Errorf("%q. handleGetRuns() = ***%s***", tt.name, diff)
}
}
})
}
}
func TestTaskHandler_NotFoundStatus(t *testing.T) {
// Ensure that the HTTP handlers return 404s for missing resources, and OKs for matching.
im := inmem.NewService()
taskBackend := NewMockTaskBackend(t)
taskBackend.HTTPErrorHandler = ErrorHandler(0)
h := NewTaskHandler(taskBackend)
h.UserResourceMappingService = im
h.LabelService = im
h.UserService = im
h.OrganizationService = im
o := platform.Organization{Name: "o"}
ctx := context.Background()
if err := h.OrganizationService.CreateOrganization(ctx, &o); err != nil {
t.Fatal(err)
}
// Create a session to associate with the contexts, so authorization checks pass.
authz := &platform.Authorization{Permissions: platform.OperPermissions()}
const taskID, runID = platform.ID(0xCCCCCC), platform.ID(0xAAAAAA)
var (
okTask = []interface{}{taskID}
okTaskRun = []interface{}{taskID, runID}
notFoundTask = [][]interface{}{
{taskID + 1},
}
notFoundTaskRun = [][]interface{}{
{taskID, runID + 1},
{taskID + 1, runID},
{taskID + 1, runID + 1},
}
)
tcs := []struct {
name string
svc *mock.TaskService
method string
body string
pathFmt string
okPathArgs []interface{}
notFoundPathArgs [][]interface{}
}{
{
name: "get task",
svc: &mock.TaskService{
FindTaskByIDFn: func(_ context.Context, id platform.ID) (*platform.Task, error) {
if id == taskID {
return &platform.Task{ID: taskID, Organization: "o"}, nil
}
return nil, platform.ErrTaskNotFound
},
},
method: http.MethodGet,
pathFmt: "/tasks/%s",
okPathArgs: okTask,
notFoundPathArgs: notFoundTask,
},
{
name: "update task",
svc: &mock.TaskService{
UpdateTaskFn: func(_ context.Context, id platform.ID, _ platform.TaskUpdate) (*platform.Task, error) {
if id == taskID {
return &platform.Task{ID: taskID, Organization: "o"}, nil
}
return nil, platform.ErrTaskNotFound
},
},
method: http.MethodPatch,
body: `{"status": "active"}`,
pathFmt: "/tasks/%s",
okPathArgs: okTask,
notFoundPathArgs: notFoundTask,
},
{
name: "delete task",
svc: &mock.TaskService{
DeleteTaskFn: func(_ context.Context, id platform.ID) error {
if id == taskID {
return nil
}
return platform.ErrTaskNotFound
},
},
method: http.MethodDelete,
pathFmt: "/tasks/%s",
okPathArgs: okTask,
notFoundPathArgs: notFoundTask,
},
{
name: "get task logs",
svc: &mock.TaskService{
FindLogsFn: func(_ context.Context, f platform.LogFilter) ([]*platform.Log, int, error) {
if f.Task == taskID {
return nil, 0, nil
}
return nil, 0, platform.ErrTaskNotFound
},
},
method: http.MethodGet,
pathFmt: "/tasks/%s/logs",
okPathArgs: okTask,
notFoundPathArgs: notFoundTask,
},
{
name: "get run logs",
svc: &mock.TaskService{
FindLogsFn: func(_ context.Context, f platform.LogFilter) ([]*platform.Log, int, error) {
if f.Task != taskID {
return nil, 0, platform.ErrTaskNotFound
}
if *f.Run != runID {
return nil, 0, platform.ErrNoRunsFound
}
return nil, 0, nil
},
},
method: http.MethodGet,
pathFmt: "/tasks/%s/runs/%s/logs",
okPathArgs: okTaskRun,
notFoundPathArgs: notFoundTaskRun,
},
{
name: "get runs: task not found",
svc: &mock.TaskService{
FindRunsFn: func(_ context.Context, f platform.RunFilter) ([]*platform.Run, int, error) {
if f.Task != taskID {
return nil, 0, platform.ErrTaskNotFound
}
return nil, 0, nil
},
},
method: http.MethodGet,
pathFmt: "/tasks/%s/runs",
okPathArgs: okTask,
notFoundPathArgs: notFoundTask,
},
{
name: "get runs: task found but no runs found",
svc: &mock.TaskService{
FindRunsFn: func(_ context.Context, f platform.RunFilter) ([]*platform.Run, int, error) {
if f.Task != taskID {
return nil, 0, platform.ErrNoRunsFound
}
return nil, 0, nil
},
},
method: http.MethodGet,
pathFmt: "/tasks/%s/runs",
okPathArgs: okTask,
notFoundPathArgs: notFoundTask,
},
{
name: "force run",
svc: &mock.TaskService{
ForceRunFn: func(_ context.Context, tid platform.ID, _ int64) (*platform.Run, error) {
if tid != taskID {
return nil, platform.ErrTaskNotFound
}
return &platform.Run{ID: runID, TaskID: taskID, Status: backend.RunScheduled.String()}, nil
},
},
method: http.MethodPost,
body: "{}",
pathFmt: "/tasks/%s/runs",
okPathArgs: okTask,
notFoundPathArgs: notFoundTask,
},
{
name: "get run",
svc: &mock.TaskService{
FindRunByIDFn: func(_ context.Context, tid, rid platform.ID) (*platform.Run, error) {
if tid != taskID {
return nil, platform.ErrTaskNotFound
}
if rid != runID {
return nil, platform.ErrRunNotFound
}
return &platform.Run{ID: runID, TaskID: taskID, Status: backend.RunScheduled.String()}, nil
},
},
method: http.MethodGet,
pathFmt: "/tasks/%s/runs/%s",
okPathArgs: okTaskRun,
notFoundPathArgs: notFoundTaskRun,
},
{
name: "retry run",
svc: &mock.TaskService{
RetryRunFn: func(_ context.Context, tid, rid platform.ID) (*platform.Run, error) {
if tid != taskID {
return nil, platform.ErrTaskNotFound
}
if rid != runID {
return nil, platform.ErrRunNotFound
}
return &platform.Run{ID: runID, TaskID: taskID, Status: backend.RunScheduled.String()}, nil
},
},
method: http.MethodPost,
pathFmt: "/tasks/%s/runs/%s/retry",
okPathArgs: okTaskRun,
notFoundPathArgs: notFoundTaskRun,
},
{
name: "cancel run",
svc: &mock.TaskService{
CancelRunFn: func(_ context.Context, tid, rid platform.ID) error {
if tid != taskID {
return platform.ErrTaskNotFound
}
if rid != runID {
return platform.ErrRunNotFound
}
return nil
},
},
method: http.MethodDelete,
pathFmt: "/tasks/%s/runs/%s",
okPathArgs: okTaskRun,
notFoundPathArgs: notFoundTaskRun,
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.name, func(t *testing.T) {
h.TaskService = tc.svc
okPath := fmt.Sprintf(tc.pathFmt, tc.okPathArgs...)
t.Run("matching ID: "+tc.method+" "+okPath, func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(tc.method, "http://task.example/api/v2"+okPath, strings.NewReader(tc.body)).WithContext(
pcontext.SetAuthorizer(context.Background(), authz),
)
h.ServeHTTP(w, r)
res := w.Result()
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode > 299 {
t.Errorf("expected OK, got %d", res.StatusCode)
b, _ := ioutil.ReadAll(res.Body)
t.Fatalf("body: %s", string(b))
}
})
t.Run("mismatched ID", func(t *testing.T) {
for _, nfa := range tc.notFoundPathArgs {
path := fmt.Sprintf(tc.pathFmt, nfa...)
t.Run(tc.method+" "+path, func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(tc.method, "http://task.example/api/v2"+path, strings.NewReader(tc.body)).WithContext(
pcontext.SetAuthorizer(context.Background(), authz),
)
h.ServeHTTP(w, r)
res := w.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusNotFound {
t.Errorf("expected Not Found, got %d", res.StatusCode)
b, _ := ioutil.ReadAll(res.Body)
t.Fatalf("body: %s", string(b))
}
})
}
})
})
}
}
func TestService_handlePostTaskLabel(t *testing.T) {
type fields struct {
LabelService platform.LabelService
}
type args struct {
labelMapping *platform.LabelMapping
taskID platform.ID
}
type wants struct {
statusCode int
contentType string
body string
}
tests := []struct {
name string
fields fields
args args
wants wants
}{
{
name: "add label to task",
fields: fields{
LabelService: &mock.LabelService{
FindLabelByIDFn: func(ctx context.Context, id platform.ID) (*platform.Label, error) {
return &platform.Label{
ID: 1,
Name: "label",
Properties: map[string]string{
"color": "fff000",
},
}, nil
},
CreateLabelMappingFn: func(ctx context.Context, m *platform.LabelMapping) error { return nil },
},
},
args: args{
labelMapping: &platform.LabelMapping{
ResourceID: 100,
LabelID: 1,
},
taskID: 100,
},
wants: wants{
statusCode: http.StatusCreated,
contentType: "application/json; charset=utf-8",
body: `
{
"label": {
"id": "0000000000000001",
"name": "label",
"properties": {
"color": "fff000"
}
},
"links": {
"self": "/api/v2/labels/0000000000000001"
}
}
`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
taskBE := NewMockTaskBackend(t)
taskBE.LabelService = tt.fields.LabelService
h := NewTaskHandler(taskBE)
b, err := json.Marshal(tt.args.labelMapping)
if err != nil {
t.Fatalf("failed to unmarshal label mapping: %v", err)
}
url := fmt.Sprintf("http://localhost:9999/api/v2/tasks/%s/labels", tt.args.taskID)
r := httptest.NewRequest("POST", url, bytes.NewReader(b))
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
res := w.Result()
content := res.Header.Get("Content-Type")
body, _ := ioutil.ReadAll(res.Body)
if res.StatusCode != tt.wants.statusCode {
t.Errorf("got %v, want %v", res.StatusCode, tt.wants.statusCode)
}
if tt.wants.contentType != "" && content != tt.wants.contentType {
t.Errorf("got %v, want %v", content, tt.wants.contentType)
}
if tt.wants.body != "" {
if eq, diff, err := jsonEqual(string(body), tt.wants.body); err != nil {
t.Errorf("%q, handlePostTaskLabel(). error unmarshaling json %v", tt.name, err)
} else if !eq {
t.Errorf("%q. handlePostTaskLabel() = ***%s***", tt.name, diff)
}
}
})
}
}
// Test that org name to org ID translation happens properly in the HTTP layer.
// Regression test for https://github.com/influxdata/influxdb/issues/12089.
func TestTaskHandler_CreateTaskWithOrgName(t *testing.T) {
i := inmem.NewService()
ctx := context.Background()
// Set up user and org.
u := &platform.User{Name: "u"}
if err := i.CreateUser(ctx, u); err != nil {
t.Fatal(err)
}
o := &platform.Organization{Name: "o"}
if err := i.CreateOrganization(ctx, o); err != nil {
t.Fatal(err)
}
// Source and destination buckets for use in task.
bSrc := platform.Bucket{OrgID: o.ID, Name: "b-src"}
if err := i.CreateBucket(ctx, &bSrc); err != nil {
t.Fatal(err)
}
bDst := platform.Bucket{OrgID: o.ID, Name: "b-dst"}
if err := i.CreateBucket(ctx, &bDst); err != nil {
t.Fatal(err)
}
authz := platform.Authorization{OrgID: o.ID, UserID: u.ID, Permissions: platform.OperPermissions()}
if err := i.CreateAuthorization(ctx, &authz); err != nil {
t.Fatal(err)
}
ts := &mock.TaskService{
CreateTaskFn: func(_ context.Context, tc platform.TaskCreate) (*platform.Task, error) {
if tc.OrganizationID != o.ID {
t.Fatalf("expected task to be created with org ID %s, got %s", o.ID, tc.OrganizationID)
}
if tc.Token != authz.Token {
t.Fatalf("expected task to be created with previous token %s, got %s", authz.Token, tc.Token)
}
return &platform.Task{ID: 9, OrganizationID: o.ID, AuthorizationID: authz.ID, Name: "x", Flux: tc.Flux}, nil
},
}
h := NewTaskHandler(&TaskBackend{
Logger: zaptest.NewLogger(t),
TaskService: ts,
AuthorizationService: i,
OrganizationService: i,
UserResourceMappingService: i,
LabelService: i,
UserService: i,
BucketService: i,
})
const script = `option task = {name:"x", every:1m} from(bucket:"b-src") |> range(start:-1m) |> to(bucket:"b-dst", org:"o")`
url := "http://localhost:9999/api/v2/tasks"
b, err := json.Marshal(platform.TaskCreate{
Flux: script,
Organization: o.Name,
Token: authz.Token,
})
if err != nil {
t.Fatal(err)
}
r := httptest.NewRequest("POST", url, bytes.NewReader(b)).WithContext(
pcontext.SetAuthorizer(ctx, &authz),
)
w := httptest.NewRecorder()
h.handlePostTask(w, r)
res := w.Result()
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusCreated {
t.Logf("response body: %s", body)
t.Fatalf("expected status created, got %v", res.StatusCode)
}
// The task should have been created with a valid token.
var createdTask platform.Task
if err := json.Unmarshal([]byte(body), &createdTask); err != nil {
t.Fatal(err)
}
if createdTask.Flux != script {
t.Fatalf("Unexpected script returned:\n got: %s\nwant: %s", createdTask.Flux, script)
}
}
func TestTaskHandler_Sessions(t *testing.T) {
// Common setup to get a working base for using tasks.
i := inmem.NewService()
ctx := context.Background()
// Set up user and org.
u := &platform.User{Name: "u"}
if err := i.CreateUser(ctx, u); err != nil {
t.Fatal(err)
}
o := &platform.Organization{Name: "o"}
if err := i.CreateOrganization(ctx, o); err != nil {
t.Fatal(err)
}
// Map user to org.
if err := i.CreateUserResourceMapping(ctx, &platform.UserResourceMapping{
ResourceType: platform.OrgsResourceType,
ResourceID: o.ID,
UserID: u.ID,
UserType: platform.Owner,
}); err != nil {
t.Fatal(err)
}
// Source and destination buckets for use in task.
bSrc := platform.Bucket{OrgID: o.ID, Name: "b-src"}
if err := i.CreateBucket(ctx, &bSrc); err != nil {
t.Fatal(err)
}
bDst := platform.Bucket{OrgID: o.ID, Name: "b-dst"}
if err := i.CreateBucket(ctx, &bDst); err != nil {
t.Fatal(err)
}
sessionAllPermsCtx := pcontext.SetAuthorizer(context.Background(), &platform.Session{
UserID: u.ID,
Permissions: platform.OperPermissions(),
ExpiresAt: time.Now().Add(24 * time.Hour),
})
sessionNoPermsCtx := pcontext.SetAuthorizer(context.Background(), &platform.Session{
UserID: u.ID,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
newHandler := func(t *testing.T, ts *mock.TaskService) *TaskHandler {
return NewTaskHandler(&TaskBackend{
HTTPErrorHandler: ErrorHandler(0),
Logger: zaptest.NewLogger(t),
TaskService: ts,
AuthorizationService: i,
OrganizationService: i,
UserResourceMappingService: i,
LabelService: i,
UserService: i,
BucketService: i,
})
}
t.Run("creating a task from a session", func(t *testing.T) {
taskID := platform.ID(9)
var createdTasks []platform.TaskCreate
ts := &mock.TaskService{
CreateTaskFn: func(_ context.Context, tc platform.TaskCreate) (*platform.Task, error) {
createdTasks = append(createdTasks, tc)
// Task with fake IDs so it can be serialized.
return &platform.Task{ID: taskID, OrganizationID: 99, AuthorizationID: 999, Name: "x"}, nil
},
// Needed due to task authorization bootstrapping.
UpdateTaskFn: func(ctx context.Context, id platform.ID, tu platform.TaskUpdate) (*platform.Task, error) {
authz, err := i.FindAuthorizationByToken(ctx, tu.Token)
if err != nil {
t.Fatal(err)
}
return &platform.Task{ID: taskID, OrganizationID: 99, AuthorizationID: authz.ID, Name: "x"}, nil
},
}
h := newHandler(t, ts)
url := "http://localhost:9999/api/v2/tasks"
b, err := json.Marshal(platform.TaskCreate{
Flux: `option task = {name:"x", every:1m} from(bucket:"b-src") |> range(start:-1m) |> to(bucket:"b-dst", org:"o")`,
OrganizationID: o.ID,
})
if err != nil {
t.Fatal(err)
}
r := httptest.NewRequest("POST", url, bytes.NewReader(b)).WithContext(sessionAllPermsCtx)
w := httptest.NewRecorder()
h.handlePostTask(w, r)
res := w.Result()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusCreated {
t.Logf("response body: %s", body)
t.Fatalf("expected status created, got %v", res.StatusCode)
}
if len(createdTasks) != 1 {
t.Fatalf("didn't create task; got %#v", createdTasks)
}
// The task should have been created with a valid token.
var createdTask platform.Task
if err := json.Unmarshal([]byte(body), &createdTask); err != nil {
t.Fatal(err)
}
authz, err := i.FindAuthorizationByID(ctx, createdTask.AuthorizationID)
if err != nil {
t.Fatal(err)
}
if authz.OrgID != o.ID {
t.Fatalf("expected authorization to have org ID %v, got %v", o.ID, authz.OrgID)
}
if authz.UserID != u.ID {
t.Fatalf("expected authorization to have user ID %v, got %v", u.ID, authz.UserID)
}
const expDesc = `auto-generated authorization for task "x"`
if authz.Description != expDesc {
t.Fatalf("expected authorization to be created with description %q, got %q", expDesc, authz.Description)
}
// The authorization should be allowed to read and write the target buckets,
// and it should be allowed to read its task.
if !authz.Allowed(platform.Permission{
Action: platform.ReadAction,
Resource: platform.Resource{
Type: platform.BucketsResourceType,
OrgID: &o.ID,
ID: &bSrc.ID,
},
}) {
t.Logf("WARNING: permissions on `from` buckets not yet accessible: update test after https://github.com/influxdata/flux/issues/114 is fixed.")
}
if !authz.Allowed(platform.Permission{
Action: platform.WriteAction,
Resource: platform.Resource{
Type: platform.BucketsResourceType,
OrgID: &o.ID,
ID: &bDst.ID,
},
}) {
t.Fatalf("expected authorization to be allowed write access to destination bucket, but it wasn't allowed")
}
if !authz.Allowed(platform.Permission{
Action: platform.ReadAction,
Resource: platform.Resource{
Type: platform.TasksResourceType,
OrgID: &o.ID,
ID: &taskID,
},
}) {
t.Fatalf("expected authorization to be allowed to read its task, but it wasn't allowed")
}
// Session without permissions should not be allowed to create task.
r = httptest.NewRequest("POST", url, bytes.NewReader(b)).WithContext(sessionNoPermsCtx)
w = httptest.NewRecorder()
h.handlePostTask(w, r)
res = w.Result()
body, err = ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusUnauthorized && res.StatusCode != http.StatusForbidden {
t.Logf("response body: %s", body)
t.Fatalf("expected status unauthorized or forbidden, got %v", res.StatusCode)
}
})
t.Run("get runs for a task", func(t *testing.T) {
// Unique authorization to associate with our fake task.
taskAuth := &platform.Authorization{OrgID: o.ID, UserID: u.ID}
if err := i.CreateAuthorization(ctx, taskAuth); err != nil {
t.Fatal(err)
}
const taskID = platform.ID(12345)
const runID = platform.ID(9876)
var findRunsCtx context.Context
ts := &mock.TaskService{
FindRunsFn: func(ctx context.Context, f platform.RunFilter) ([]*platform.Run, int, error) {
findRunsCtx = ctx
if f.Task != taskID {
t.Fatalf("expected task ID %v, got %v", taskID, f.Task)
}
return []*platform.Run{
{ID: runID, TaskID: taskID},
}, 1, nil
},
FindTaskByIDFn: func(ctx context.Context, id platform.ID) (*platform.Task, error) {
if id != taskID {
return nil, platform.ErrTaskNotFound
}
return &platform.Task{
ID: taskID,
OrganizationID: o.ID,
AuthorizationID: taskAuth.ID,
}, nil
},
}
h := newHandler(t, ts)
url := fmt.Sprintf("http://localhost:9999/api/v2/tasks/%s/runs", taskID)
valCtx := context.WithValue(sessionAllPermsCtx, httprouter.ParamsKey, httprouter.Params{{Key: "id", Value: taskID.String()}})
r := httptest.NewRequest("GET", url, nil).WithContext(valCtx)
w := httptest.NewRecorder()
h.handleGetRuns(w, r)
res := w.Result()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusOK {
t.Logf("response body: %s", body)
t.Fatalf("expected status OK, got %v", res.StatusCode)
}
// The context passed to TaskService.FindRuns must be a valid authorization (not a session).
authr, err := pcontext.GetAuthorizer(findRunsCtx)
if err != nil {
t.Fatal(err)
}
if authr.Kind() != platform.AuthorizationKind {
t.Fatalf("expected context's authorizer to be of kind %q, got %q", platform.AuthorizationKind, authr.Kind())
}
if authr.Identifier() != taskAuth.ID {
t.Fatalf("expected context's authorizer ID to be %v, got %v", taskAuth.ID, authr.Identifier())
}
// Other user without permissions on the task or authorization should be disallowed.
otherUser := &platform.User{Name: "other-" + t.Name()}
if err := i.CreateUser(ctx, otherUser); err != nil {
t.Fatal(err)
}
valCtx = pcontext.SetAuthorizer(valCtx, &platform.Session{
UserID: otherUser.ID,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
r = httptest.NewRequest("GET", url, nil).WithContext(valCtx)
w = httptest.NewRecorder()
h.handleGetRuns(w, r)
res = w.Result()
body, err = ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusUnauthorized {
t.Logf("response body: %s", body)
t.Fatalf("expected status unauthorized, got %v", res.StatusCode)
}
})
t.Run("get single run for a task", func(t *testing.T) {
// Unique authorization to associate with our fake task.
taskAuth := &platform.Authorization{OrgID: o.ID, UserID: u.ID}
if err := i.CreateAuthorization(ctx, taskAuth); err != nil {
t.Fatal(err)
}
const taskID = platform.ID(12345)
const runID = platform.ID(9876)
var findRunByIDCtx context.Context
ts := &mock.TaskService{
FindRunByIDFn: func(ctx context.Context, tid, rid platform.ID) (*platform.Run, error) {
findRunByIDCtx = ctx
if tid != taskID {
t.Fatalf("expected task ID %v, got %v", taskID, tid)
}
if rid != runID {
t.Fatalf("expected run ID %v, got %v", runID, rid)
}
return &platform.Run{ID: runID, TaskID: taskID}, nil
},
FindTaskByIDFn: func(ctx context.Context, id platform.ID) (*platform.Task, error) {
if id != taskID {
return nil, platform.ErrTaskNotFound
}
return &platform.Task{
ID: taskID,
OrganizationID: o.ID,
AuthorizationID: taskAuth.ID,
}, nil
},
}
h := newHandler(t, ts)
url := fmt.Sprintf("http://localhost:9999/api/v2/tasks/%s/runs/%s", taskID, runID)
valCtx := context.WithValue(sessionAllPermsCtx, httprouter.ParamsKey, httprouter.Params{
{Key: "id", Value: taskID.String()},
{Key: "rid", Value: runID.String()},
})
r := httptest.NewRequest("GET", url, nil).WithContext(valCtx)
w := httptest.NewRecorder()
h.handleGetRun(w, r)
res := w.Result()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusOK {
t.Logf("response body: %s", body)
t.Fatalf("expected status OK, got %v", res.StatusCode)
}
// The context passed to TaskService.FindRunByID must be a valid authorization (not a session).
authr, err := pcontext.GetAuthorizer(findRunByIDCtx)
if err != nil {
t.Fatal(err)
}
if authr.Kind() != platform.AuthorizationKind {
t.Fatalf("expected context's authorizer to be of kind %q, got %q", platform.AuthorizationKind, authr.Kind())
}
if authr.Identifier() != taskAuth.ID {
t.Fatalf("expected context's authorizer ID to be %v, got %v", taskAuth.ID, authr.Identifier())
}
// Other user without permissions on the task or authorization should be disallowed.
otherUser := &platform.User{Name: "other-" + t.Name()}
if err := i.CreateUser(ctx, otherUser); err != nil {
t.Fatal(err)
}
valCtx = pcontext.SetAuthorizer(valCtx, &platform.Session{
UserID: otherUser.ID,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
r = httptest.NewRequest("GET", url, nil).WithContext(valCtx)
w = httptest.NewRecorder()
h.handleGetRuns(w, r)
res = w.Result()
body, err = ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusUnauthorized {
t.Logf("response body: %s", body)
t.Fatalf("expected status unauthorized, got %v", res.StatusCode)
}
})
t.Run("get logs for a run", func(t *testing.T) {
// Unique authorization to associate with our fake task.
taskAuth := &platform.Authorization{OrgID: o.ID, UserID: u.ID}
if err := i.CreateAuthorization(ctx, taskAuth); err != nil {
t.Fatal(err)
}
const taskID = platform.ID(12345)
const runID = platform.ID(9876)
var findLogsCtx context.Context
ts := &mock.TaskService{
FindLogsFn: func(ctx context.Context, f platform.LogFilter) ([]*platform.Log, int, error) {
findLogsCtx = ctx
if f.Task != taskID {
t.Fatalf("expected task ID %v, got %v", taskID, f.Task)
}
if *f.Run != runID {
t.Fatalf("expected run ID %v, got %v", runID, *f.Run)
}
line := platform.Log{Time: "time", Message: "a log line"}
return []*platform.Log{&line}, 1, nil
},
FindTaskByIDFn: func(ctx context.Context, id platform.ID) (*platform.Task, error) {
if id != taskID {
return nil, platform.ErrTaskNotFound
}
return &platform.Task{
ID: taskID,
OrganizationID: o.ID,
AuthorizationID: taskAuth.ID,
}, nil
},
}
h := newHandler(t, ts)
url := fmt.Sprintf("http://localhost:9999/api/v2/tasks/%s/runs/%s/logs", taskID, runID)
valCtx := context.WithValue(sessionAllPermsCtx, httprouter.ParamsKey, httprouter.Params{
{Key: "id", Value: taskID.String()},
{Key: "rid", Value: runID.String()},
})
r := httptest.NewRequest("GET", url, nil).WithContext(valCtx)
w := httptest.NewRecorder()
h.handleGetLogs(w, r)
res := w.Result()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusOK {
t.Logf("response body: %s", body)
t.Fatalf("expected status OK, got %v", res.StatusCode)
}
// The context passed to TaskService.FindLogs must be a valid authorization (not a session).
authr, err := pcontext.GetAuthorizer(findLogsCtx)
if err != nil {
t.Fatal(err)
}
if authr.Kind() != platform.AuthorizationKind {
t.Fatalf("expected context's authorizer to be of kind %q, got %q", platform.AuthorizationKind, authr.Kind())
}
if authr.Identifier() != taskAuth.ID {
t.Fatalf("expected context's authorizer ID to be %v, got %v", taskAuth.ID, authr.Identifier())
}
// Other user without permissions on the task or authorization should be disallowed.
otherUser := &platform.User{Name: "other-" + t.Name()}
if err := i.CreateUser(ctx, otherUser); err != nil {
t.Fatal(err)
}
valCtx = pcontext.SetAuthorizer(valCtx, &platform.Session{
UserID: otherUser.ID,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
r = httptest.NewRequest("GET", url, nil).WithContext(valCtx)
w = httptest.NewRecorder()
h.handleGetRuns(w, r)
res = w.Result()
body, err = ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusUnauthorized {
t.Logf("response body: %s", body)
t.Fatalf("expected status unauthorized, got %v", res.StatusCode)
}
})
t.Run("retry a run", func(t *testing.T) {
// Unique authorization to associate with our fake task.
taskAuth := &platform.Authorization{OrgID: o.ID, UserID: u.ID}
if err := i.CreateAuthorization(ctx, taskAuth); err != nil {
t.Fatal(err)
}
const taskID = platform.ID(12345)
const runID = platform.ID(9876)
var retryRunCtx context.Context
ts := &mock.TaskService{
RetryRunFn: func(ctx context.Context, tid, rid platform.ID) (*platform.Run, error) {
retryRunCtx = ctx
if tid != taskID {
t.Fatalf("expected task ID %v, got %v", taskID, tid)
}
if rid != runID {
t.Fatalf("expected run ID %v, got %v", runID, rid)
}
return &platform.Run{ID: 10 * runID, TaskID: taskID}, nil
},
FindTaskByIDFn: func(ctx context.Context, id platform.ID) (*platform.Task, error) {
if id != taskID {
return nil, platform.ErrTaskNotFound
}
return &platform.Task{
ID: taskID,
OrganizationID: o.ID,
AuthorizationID: taskAuth.ID,
}, nil
},
}
h := newHandler(t, ts)
url := fmt.Sprintf("http://localhost:9999/api/v2/tasks/%s/runs/%s/retry", taskID, runID)
valCtx := context.WithValue(sessionAllPermsCtx, httprouter.ParamsKey, httprouter.Params{
{Key: "id", Value: taskID.String()},
{Key: "rid", Value: runID.String()},
})
r := httptest.NewRequest("POST", url, nil).WithContext(valCtx)
w := httptest.NewRecorder()
h.handleRetryRun(w, r)
res := w.Result()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusOK {
t.Logf("response body: %s", body)
t.Fatalf("expected status OK, got %v", res.StatusCode)
}
// The context passed to TaskService.RetryRun must be a valid authorization (not a session).
authr, err := pcontext.GetAuthorizer(retryRunCtx)
if err != nil {
t.Fatal(err)
}
if authr.Kind() != platform.AuthorizationKind {
t.Fatalf("expected context's authorizer to be of kind %q, got %q", platform.AuthorizationKind, authr.Kind())
}
if authr.Identifier() != taskAuth.ID {
t.Fatalf("expected context's authorizer ID to be %v, got %v", taskAuth.ID, authr.Identifier())
}
// Other user without permissions on the task or authorization should be disallowed.
otherUser := &platform.User{Name: "other-" + t.Name()}
if err := i.CreateUser(ctx, otherUser); err != nil {
t.Fatal(err)
}
valCtx = pcontext.SetAuthorizer(valCtx, &platform.Session{
UserID: otherUser.ID,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
r = httptest.NewRequest("POST", url, nil).WithContext(valCtx)
w = httptest.NewRecorder()
h.handleGetRuns(w, r)
res = w.Result()
body, err = ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusUnauthorized {
t.Logf("response body: %s", body)
t.Fatalf("expected status unauthorized, got %v", res.StatusCode)
}
})
}