influxdb/http/task_service_test.go

838 lines
22 KiB
Go

package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"go.uber.org/zap"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
platform "github.com/influxdata/influxdb"
pcontext "github.com/influxdata/influxdb/context"
"github.com/influxdata/influxdb/inmem"
"github.com/influxdata/influxdb/mock"
_ "github.com/influxdata/influxdb/query/builtin"
"github.com/influxdata/influxdb/task/backend"
platformtesting "github.com/influxdata/influxdb/testing"
"github.com/julienschmidt/httprouter"
)
// NewMockTaskBackend builds a TaskBackend whose every dependency is a mock,
// so the task HTTP handlers can be exercised without real storage.
//
// The organization mock answers lookups by ID with an organization named
// "test" carrying the requested ID. It answers filter-based lookups with an
// organization populated from whichever of the filter's Name and ID are set;
// unset fields keep their zero values. Neither lookup ever returns an error.
func NewMockTaskBackend() *TaskBackend {
	orgService := &mock.OrganizationService{
		FindOrganizationByIDF: func(ctx context.Context, id platform.ID) (*platform.Organization, error) {
			return &platform.Organization{ID: id, Name: "test"}, nil
		},
		FindOrganizationF: func(ctx context.Context, filter platform.OrganizationFilter) (*platform.Organization, error) {
			var name string
			if filter.Name != nil {
				name = *filter.Name
			}
			var id platform.ID
			if filter.ID != nil {
				id = *filter.ID
			}
			return &platform.Organization{ID: id, Name: name}, nil
		},
	}

	logger := zap.NewNop().With(zap.String("handler", "task"))

	return &TaskBackend{
		Logger:                     logger,
		AuthorizationService:       mock.NewAuthorizationService(),
		TaskService:                &mock.TaskService{},
		OrganizationService:        orgService,
		UserResourceMappingService: mock.NewUserResourceMappingService(),
		LabelService:               mock.NewLabelService(),
		UserService:                mock.NewUserService(),
	}
}
// TestTaskHandler_handleGetTasks checks that handleGetTasks renders every task
// returned by the TaskService. Each rendered task carries its links, the labels
// returned by the LabelService, and the organization name resolved through the
// mock organization service from NewMockTaskBackend.
func TestTaskHandler_handleGetTasks(t *testing.T) {
	type fields struct {
		taskService  platform.TaskService
		labelService platform.LabelService
	}
	type wants struct {
		statusCode  int
		contentType string
		body        string
	}

	tests := []struct {
		name   string
		fields fields
		wants  wants
	}{
		{
			name: "get tasks",
			fields: fields{
				taskService: &mock.TaskService{
					FindTasksFn: func(ctx context.Context, f platform.TaskFilter) ([]*platform.Task, int, error) {
						tasks := []*platform.Task{
							{
								ID:             1,
								Name:           "task1",
								OrganizationID: 1,
								Owner:          platform.User{ID: 1, Name: "user1"},
							},
							{
								ID:             2,
								Name:           "task2",
								OrganizationID: 2,
								Owner:          platform.User{ID: 2, Name: "user2"},
							},
						}
						return tasks, len(tasks), nil
					},
				},
				labelService: &mock.LabelService{
					FindResourceLabelsFn: func(ctx context.Context, f platform.LabelMappingFilter) ([]*platform.Label, error) {
						labels := []*platform.Label{
							{
								ID:   platformtesting.MustIDBase16("fc3dc670a4be9b9a"),
								Name: "label",
								Properties: map[string]string{
									"color": "fff000",
								},
							},
						}
						return labels, nil
					},
				},
			},
			wants: wants{
				statusCode:  http.StatusOK,
				contentType: "application/json; charset=utf-8",
				body: `
{
"links": {
"self": "/api/v2/tasks"
},
"tasks": [
{
"links": {
"self": "/api/v2/tasks/0000000000000001",
"owners": "/api/v2/tasks/0000000000000001/owners",
"members": "/api/v2/tasks/0000000000000001/members",
"labels": "/api/v2/tasks/0000000000000001/labels",
"runs": "/api/v2/tasks/0000000000000001/runs",
"logs": "/api/v2/tasks/0000000000000001/logs"
},
"id": "0000000000000001",
"name": "task1",
"labels": [
{
"id": "fc3dc670a4be9b9a",
"name": "label",
"properties": {
"color": "fff000"
}
}
],
"orgID": "0000000000000001",
"org": "test",
"status": "",
"flux": ""
},
{
"links": {
"self": "/api/v2/tasks/0000000000000002",
"owners": "/api/v2/tasks/0000000000000002/owners",
"members": "/api/v2/tasks/0000000000000002/members",
"labels": "/api/v2/tasks/0000000000000002/labels",
"runs": "/api/v2/tasks/0000000000000002/runs",
"logs": "/api/v2/tasks/0000000000000002/logs"
},
"id": "0000000000000002",
"name": "task2",
"labels": [
{
"id": "fc3dc670a4be9b9a",
"name": "label",
"properties": {
"color": "fff000"
}
}
],
"orgID": "0000000000000002",
"org": "test",
"status": "",
"flux": ""
}
]
}`,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			r := httptest.NewRequest("GET", "http://any.url", nil)
			w := httptest.NewRecorder()

			taskBackend := NewMockTaskBackend()
			taskBackend.TaskService = tt.fields.taskService
			taskBackend.LabelService = tt.fields.labelService
			h := NewTaskHandler(taskBackend)
			h.handleGetTasks(w, r)

			res := w.Result()
			defer res.Body.Close()
			content := res.Header.Get("Content-Type")
			body, err := ioutil.ReadAll(res.Body)
			if err != nil {
				t.Fatalf("%q. reading handleGetTasks() response body: %v", tt.name, err)
			}

			if res.StatusCode != tt.wants.statusCode {
				t.Errorf("%q. handleGetTasks() = %v, want %v", tt.name, res.StatusCode, tt.wants.statusCode)
			}
			if tt.wants.contentType != "" && content != tt.wants.contentType {
				t.Errorf("%q. handleGetTasks() = %v, want %v", tt.name, content, tt.wants.contentType)
			}
			if tt.wants.body != "" {
				// Surface the comparison error (e.g. a non-JSON body) instead of
				// printing an empty diff.
				if eq, diff, err := jsonEqual(string(body), tt.wants.body); !eq {
					if err != nil {
						t.Errorf("%q. handleGetTasks() comparing body: %v; body: %s", tt.name, err, body)
					} else {
						t.Errorf("%q. handleGetTasks() = ***%s***", tt.name, diff)
					}
				}
			}
		})
	}
}
// TestTaskHandler_handlePostTasks checks that handlePostTask decodes a JSON
// task, hands it to the TaskService (whose mock assigns ID 1), and replies
// 201 Created with the rendered task. The reply includes an empty label list
// and the organization name from the mock organization service.
func TestTaskHandler_handlePostTasks(t *testing.T) {
	type args struct {
		task platform.Task
	}
	type fields struct {
		taskService platform.TaskService
	}
	type wants struct {
		statusCode  int
		contentType string
		body        string
	}

	tests := []struct {
		name   string
		args   args
		fields fields
		wants  wants
	}{
		{
			name: "create task",
			args: args{
				task: platform.Task{
					Name:           "task1",
					OrganizationID: 1,
					Owner: platform.User{
						ID:   1,
						Name: "user1",
					},
				},
			},
			fields: fields{
				taskService: &mock.TaskService{
					CreateTaskFn: func(ctx context.Context, t *platform.Task) error {
						t.ID = 1
						return nil
					},
				},
			},
			wants: wants{
				statusCode:  http.StatusCreated,
				contentType: "application/json; charset=utf-8",
				body: `
{
"links": {
"self": "/api/v2/tasks/0000000000000001",
"owners": "/api/v2/tasks/0000000000000001/owners",
"members": "/api/v2/tasks/0000000000000001/members",
"labels": "/api/v2/tasks/0000000000000001/labels",
"runs": "/api/v2/tasks/0000000000000001/runs",
"logs": "/api/v2/tasks/0000000000000001/logs"
},
"id": "0000000000000001",
"name": "task1",
"labels": [],
"orgID": "0000000000000001",
"org": "test",
"status": "",
"flux": ""
}
`,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			b, err := json.Marshal(tt.args.task)
			if err != nil {
				t.Fatalf("failed to marshal task: %v", err)
			}
			r := httptest.NewRequest("POST", "http://any.url", bytes.NewReader(b))
			// The handler requires an authorizer on the request context.
			ctx := pcontext.SetAuthorizer(context.TODO(), new(platform.Authorization))
			r = r.WithContext(ctx)
			w := httptest.NewRecorder()

			taskBackend := NewMockTaskBackend()
			taskBackend.TaskService = tt.fields.taskService
			h := NewTaskHandler(taskBackend)
			h.handlePostTask(w, r)

			res := w.Result()
			defer res.Body.Close()
			content := res.Header.Get("Content-Type")
			body, err := ioutil.ReadAll(res.Body)
			if err != nil {
				t.Fatalf("%q. reading handlePostTask() response body: %v", tt.name, err)
			}

			if res.StatusCode != tt.wants.statusCode {
				t.Errorf("%q. handlePostTask() = %v, want %v", tt.name, res.StatusCode, tt.wants.statusCode)
			}
			if tt.wants.contentType != "" && content != tt.wants.contentType {
				t.Errorf("%q. handlePostTask() = %v, want %v", tt.name, content, tt.wants.contentType)
			}
			if tt.wants.body != "" {
				// Surface the comparison error (e.g. a non-JSON body) instead of
				// printing an empty diff.
				if eq, diff, err := jsonEqual(string(body), tt.wants.body); !eq {
					if err != nil {
						t.Errorf("%q. handlePostTask() comparing body: %v; body: %s", tt.name, err, body)
					} else {
						t.Errorf("%q. handlePostTask() = ***%s***", tt.name, diff)
					}
				}
			}
		})
	}
}
// TestTaskHandler_handleGetRun checks that handleGetRun reads the task ("id")
// and run ("rid") IDs from the httprouter params and looks the run up via
// FindRunByID. It then checks that the run is rendered with its self, task,
// retry, and logs links.
func TestTaskHandler_handleGetRun(t *testing.T) {
	type fields struct {
		taskService platform.TaskService
	}
	type args struct {
		taskID platform.ID
		runID  platform.ID
	}
	type wants struct {
		statusCode  int
		contentType string
		body        string
	}

	tests := []struct {
		name   string
		fields fields
		args   args
		wants  wants
	}{
		{
			name: "get a run by id",
			fields: fields{
				taskService: &mock.TaskService{
					FindRunByIDFn: func(ctx context.Context, taskID platform.ID, runID platform.ID) (*platform.Run, error) {
						run := platform.Run{
							ID:           runID,
							TaskID:       taskID,
							Status:       "success",
							ScheduledFor: "2018-12-01T17:00:13Z",
							StartedAt:    "2018-12-01T17:00:03.155645Z",
							FinishedAt:   "2018-12-01T17:00:13.155645Z",
							RequestedAt:  "2018-12-01T17:00:13Z",
						}
						return &run, nil
					},
				},
			},
			args: args{
				taskID: 1,
				runID:  2,
			},
			wants: wants{
				statusCode:  http.StatusOK,
				contentType: "application/json; charset=utf-8",
				body: `
{
"links": {
"self": "/api/v2/tasks/0000000000000001/runs/0000000000000002",
"task": "/api/v2/tasks/0000000000000001",
"retry": "/api/v2/tasks/0000000000000001/runs/0000000000000002/retry",
"logs": "/api/v2/tasks/0000000000000001/runs/0000000000000002/logs"
},
"id": "0000000000000002",
"taskID": "0000000000000001",
"status": "success",
"scheduledFor": "2018-12-01T17:00:13Z",
"startedAt": "2018-12-01T17:00:03.155645Z",
"finishedAt": "2018-12-01T17:00:13.155645Z",
"requestedAt": "2018-12-01T17:00:13Z",
"log": ""
}`,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			r := httptest.NewRequest("GET", "http://any.url", nil)
			// The handler is called directly (bypassing the router), so inject
			// the route params it would normally receive.
			r = r.WithContext(context.WithValue(
				context.TODO(),
				httprouter.ParamsKey,
				httprouter.Params{
					{
						Key:   "id",
						Value: tt.args.taskID.String(),
					},
					{
						Key:   "rid",
						Value: tt.args.runID.String(),
					},
				}))
			w := httptest.NewRecorder()

			taskBackend := NewMockTaskBackend()
			taskBackend.TaskService = tt.fields.taskService
			h := NewTaskHandler(taskBackend)
			h.handleGetRun(w, r)

			res := w.Result()
			defer res.Body.Close()
			content := res.Header.Get("Content-Type")
			body, err := ioutil.ReadAll(res.Body)
			if err != nil {
				t.Fatalf("%q. reading handleGetRun() response body: %v", tt.name, err)
			}

			if res.StatusCode != tt.wants.statusCode {
				t.Errorf("%q. handleGetRun() = %v, want %v", tt.name, res.StatusCode, tt.wants.statusCode)
			}
			if tt.wants.contentType != "" && content != tt.wants.contentType {
				t.Errorf("%q. handleGetRun() = %v, want %v", tt.name, content, tt.wants.contentType)
			}
			if tt.wants.body != "" {
				// Surface the comparison error (e.g. a non-JSON body) instead of
				// printing an empty diff.
				if eq, diff, err := jsonEqual(string(body), tt.wants.body); !eq {
					if err != nil {
						t.Errorf("%q. handleGetRun() comparing body: %v; body: %s", tt.name, err, body)
					} else {
						t.Errorf("%q. handleGetRun() = ***%s***", tt.name, diff)
					}
				}
			}
		})
	}
}
// TestTaskHandler_handleGetRuns checks that handleGetRuns reads the task ID
// ("id") from the httprouter params and passes it to FindRuns in the
// RunFilter. It then checks that the returned runs are rendered with
// per-run links, under a collection carrying self and task links.
func TestTaskHandler_handleGetRuns(t *testing.T) {
	type fields struct {
		taskService platform.TaskService
	}
	type args struct {
		taskID platform.ID
	}
	type wants struct {
		statusCode  int
		contentType string
		body        string
	}

	tests := []struct {
		name   string
		fields fields
		args   args
		wants  wants
	}{
		{
			name: "get runs by task id",
			fields: fields{
				taskService: &mock.TaskService{
					FindRunsFn: func(ctx context.Context, f platform.RunFilter) ([]*platform.Run, int, error) {
						runs := []*platform.Run{
							{
								ID:           platform.ID(2),
								TaskID:       *f.Task,
								Status:       "success",
								ScheduledFor: "2018-12-01T17:00:13Z",
								StartedAt:    "2018-12-01T17:00:03.155645Z",
								FinishedAt:   "2018-12-01T17:00:13.155645Z",
								RequestedAt:  "2018-12-01T17:00:13Z",
							},
						}
						return runs, len(runs), nil
					},
				},
			},
			args: args{
				taskID: 1,
			},
			wants: wants{
				statusCode:  http.StatusOK,
				contentType: "application/json; charset=utf-8",
				body: `
{
"links": {
"self": "/api/v2/tasks/0000000000000001/runs",
"task": "/api/v2/tasks/0000000000000001"
},
"runs": [
{
"links": {
"self": "/api/v2/tasks/0000000000000001/runs/0000000000000002",
"task": "/api/v2/tasks/0000000000000001",
"retry": "/api/v2/tasks/0000000000000001/runs/0000000000000002/retry",
"logs": "/api/v2/tasks/0000000000000001/runs/0000000000000002/logs"
},
"id": "0000000000000002",
"taskID": "0000000000000001",
"status": "success",
"scheduledFor": "2018-12-01T17:00:13Z",
"startedAt": "2018-12-01T17:00:03.155645Z",
"finishedAt": "2018-12-01T17:00:13.155645Z",
"requestedAt": "2018-12-01T17:00:13Z",
"log": ""
}
]
}`,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			r := httptest.NewRequest("GET", "http://any.url", nil)
			// The handler is called directly (bypassing the router), so inject
			// the route params it would normally receive.
			r = r.WithContext(context.WithValue(
				context.TODO(),
				httprouter.ParamsKey,
				httprouter.Params{
					{
						Key:   "id",
						Value: tt.args.taskID.String(),
					},
				}))
			w := httptest.NewRecorder()

			taskBackend := NewMockTaskBackend()
			taskBackend.TaskService = tt.fields.taskService
			h := NewTaskHandler(taskBackend)
			h.handleGetRuns(w, r)

			res := w.Result()
			defer res.Body.Close()
			content := res.Header.Get("Content-Type")
			body, err := ioutil.ReadAll(res.Body)
			if err != nil {
				t.Fatalf("%q. reading handleGetRuns() response body: %v", tt.name, err)
			}

			if res.StatusCode != tt.wants.statusCode {
				t.Errorf("%q. handleGetRuns() = %v, want %v", tt.name, res.StatusCode, tt.wants.statusCode)
			}
			if tt.wants.contentType != "" && content != tt.wants.contentType {
				t.Errorf("%q. handleGetRuns() = %v, want %v", tt.name, content, tt.wants.contentType)
			}
			if tt.wants.body != "" {
				// Surface the comparison error (e.g. a non-JSON body) instead of
				// printing an empty diff.
				if eq, diff, err := jsonEqual(string(body), tt.wants.body); !eq {
					if err != nil {
						t.Errorf("%q. handleGetRuns() comparing body: %v; body: %s", tt.name, err, body)
					} else {
						t.Errorf("%q. handleGetRuns() = ***%s***", tt.name, diff)
					}
				}
			}
		})
	}
}
// TestTaskHandler_NotFoundStatus exercises every task and run route through
// the full router (ServeHTTP).
//
// Each case installs a mock TaskService that recognizes only taskID (and
// runID). It returns backend.ErrTaskNotFound or backend.ErrRunNotFound for
// anything else. A request naming the known IDs must get a 2xx status; each
// request naming an unknown task or run must get 404 Not Found.
func TestTaskHandler_NotFoundStatus(t *testing.T) {
	store := inmem.NewService()
	h := NewTaskHandler(NewMockTaskBackend())
	h.UserResourceMappingService = store
	h.LabelService = store
	h.UserService = store
	h.OrganizationService = store

	ctx := context.Background()
	org := platform.Organization{Name: "o"}
	if err := h.OrganizationService.CreateOrganization(ctx, &org); err != nil {
		t.Fatal(err)
	}

	const taskID, runID = platform.ID(0xCCCCCC), platform.ID(0xAAAAAA)

	// Path arguments that the mocks below accept...
	okTask := []interface{}{taskID}
	okTaskRun := []interface{}{taskID, runID}
	// ...and every combination in which at least one ID is unknown.
	missingTask := [][]interface{}{{taskID + 1}}
	missingTaskRun := [][]interface{}{
		{taskID, runID + 1},
		{taskID + 1, runID},
		{taskID + 1, runID + 1},
	}

	cases := []struct {
		name        string
		svc         *mock.TaskService
		method      string
		body        string
		pathFmt     string
		okArgs      []interface{}
		missingArgs [][]interface{}
	}{
		{
			name: "get task",
			svc: &mock.TaskService{
				FindTaskByIDFn: func(_ context.Context, id platform.ID) (*platform.Task, error) {
					if id != taskID {
						return nil, backend.ErrTaskNotFound
					}
					return &platform.Task{ID: taskID, Organization: "o"}, nil
				},
			},
			method:      http.MethodGet,
			pathFmt:     "/tasks/%s",
			okArgs:      okTask,
			missingArgs: missingTask,
		},
		{
			name: "update task",
			svc: &mock.TaskService{
				UpdateTaskFn: func(_ context.Context, id platform.ID, _ platform.TaskUpdate) (*platform.Task, error) {
					if id != taskID {
						return nil, backend.ErrTaskNotFound
					}
					return &platform.Task{ID: taskID, Organization: "o"}, nil
				},
			},
			method:      http.MethodPatch,
			body:        "{}",
			pathFmt:     "/tasks/%s",
			okArgs:      okTask,
			missingArgs: missingTask,
		},
		{
			name: "delete task",
			svc: &mock.TaskService{
				DeleteTaskFn: func(_ context.Context, id platform.ID) error {
					if id != taskID {
						return backend.ErrTaskNotFound
					}
					return nil
				},
			},
			method:      http.MethodDelete,
			pathFmt:     "/tasks/%s",
			okArgs:      okTask,
			missingArgs: missingTask,
		},
		{
			name: "get task logs",
			svc: &mock.TaskService{
				FindLogsFn: func(_ context.Context, f platform.LogFilter) ([]*platform.Log, int, error) {
					if *f.Task != taskID {
						return nil, 0, backend.ErrTaskNotFound
					}
					return nil, 0, nil
				},
			},
			method:      http.MethodGet,
			pathFmt:     "/tasks/%s/logs",
			okArgs:      okTask,
			missingArgs: missingTask,
		},
		{
			name: "get run logs",
			svc: &mock.TaskService{
				FindLogsFn: func(_ context.Context, f platform.LogFilter) ([]*platform.Log, int, error) {
					switch {
					case *f.Task != taskID:
						return nil, 0, backend.ErrTaskNotFound
					case *f.Run != runID:
						return nil, 0, backend.ErrRunNotFound
					default:
						return nil, 0, nil
					}
				},
			},
			method:      http.MethodGet,
			pathFmt:     "/tasks/%s/runs/%s/logs",
			okArgs:      okTaskRun,
			missingArgs: missingTaskRun,
		},
		{
			name: "get runs",
			svc: &mock.TaskService{
				FindRunsFn: func(_ context.Context, f platform.RunFilter) ([]*platform.Run, int, error) {
					if *f.Task != taskID {
						return nil, 0, backend.ErrTaskNotFound
					}
					return nil, 0, nil
				},
			},
			method:      http.MethodGet,
			pathFmt:     "/tasks/%s/runs",
			okArgs:      okTask,
			missingArgs: missingTask,
		},
		{
			name: "force run",
			svc: &mock.TaskService{
				ForceRunFn: func(_ context.Context, tid platform.ID, _ int64) (*platform.Run, error) {
					if tid != taskID {
						return nil, backend.ErrTaskNotFound
					}
					return &platform.Run{ID: runID, TaskID: taskID, Status: backend.RunScheduled.String()}, nil
				},
			},
			method:      http.MethodPost,
			body:        "{}",
			pathFmt:     "/tasks/%s/runs",
			okArgs:      okTask,
			missingArgs: missingTask,
		},
		{
			name: "get run",
			svc: &mock.TaskService{
				FindRunByIDFn: func(_ context.Context, tid, rid platform.ID) (*platform.Run, error) {
					switch {
					case tid != taskID:
						return nil, backend.ErrTaskNotFound
					case rid != runID:
						return nil, backend.ErrRunNotFound
					default:
						return &platform.Run{ID: runID, TaskID: taskID, Status: backend.RunScheduled.String()}, nil
					}
				},
			},
			method:      http.MethodGet,
			pathFmt:     "/tasks/%s/runs/%s",
			okArgs:      okTaskRun,
			missingArgs: missingTaskRun,
		},
		{
			name: "retry run",
			svc: &mock.TaskService{
				RetryRunFn: func(_ context.Context, tid, rid platform.ID) (*platform.Run, error) {
					switch {
					case tid != taskID:
						return nil, backend.ErrTaskNotFound
					case rid != runID:
						return nil, backend.ErrRunNotFound
					default:
						return &platform.Run{ID: runID, TaskID: taskID, Status: backend.RunScheduled.String()}, nil
					}
				},
			},
			method:      http.MethodPost,
			pathFmt:     "/tasks/%s/runs/%s/retry",
			okArgs:      okTaskRun,
			missingArgs: missingTaskRun,
		},
		{
			name: "cancel run",
			svc: &mock.TaskService{
				CancelRunFn: func(_ context.Context, tid, rid platform.ID) error {
					switch {
					case tid != taskID:
						return backend.ErrTaskNotFound
					case rid != runID:
						return backend.ErrRunNotFound
					default:
						return nil
					}
				},
			},
			method:      http.MethodDelete,
			pathFmt:     "/tasks/%s/runs/%s",
			okArgs:      okTaskRun,
			missingArgs: missingTaskRun,
		},
	}

	const baseURL = "http://task.example/api/v2"

	// serve routes one request through the shared handler and returns the
	// recorded response.
	serve := func(method, path, body string) *http.Response {
		rec := httptest.NewRecorder()
		h.ServeHTTP(rec, httptest.NewRequest(method, baseURL+path, strings.NewReader(body)))
		return rec.Result()
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			// Subtests run sequentially, so swapping the shared handler's
			// TaskService per case is safe.
			h.TaskService = tc.svc

			okPath := fmt.Sprintf(tc.pathFmt, tc.okArgs...)
			t.Run("matching ID: "+tc.method+" "+okPath, func(t *testing.T) {
				res := serve(tc.method, okPath, tc.body)
				defer res.Body.Close()
				if res.StatusCode >= 200 && res.StatusCode <= 299 {
					return
				}
				t.Errorf("expected OK, got %d", res.StatusCode)
				b, _ := ioutil.ReadAll(res.Body) // best effort: only used for the failure message
				t.Fatalf("body: %s", string(b))
			})

			t.Run("mismatched ID", func(t *testing.T) {
				for _, args := range tc.missingArgs {
					path := fmt.Sprintf(tc.pathFmt, args...)
					t.Run(tc.method+" "+path, func(t *testing.T) {
						res := serve(tc.method, path, tc.body)
						defer res.Body.Close()
						if res.StatusCode == http.StatusNotFound {
							return
						}
						t.Errorf("expected Not Found, got %d", res.StatusCode)
						b, _ := ioutil.ReadAll(res.Body) // best effort: only used for the failure message
						t.Fatalf("body: %s", string(b))
					})
				}
			})
		})
	}
}
// TestTaskUserResourceMap checks the user resource mapping lifecycle for a
// task. Creating a task via handlePostTask must record a mapping through the
// UserResourceMappingService. Deleting that task through the router must then
// delete a mapping for the same user and resource.
func TestTaskUserResourceMap(t *testing.T) {
	task := platform.Task{
		Name:           "task1",
		OrganizationID: 1,
	}
	b, err := json.Marshal(task)
	if err != nil {
		t.Fatalf("failed to marshal task: %v", err)
	}
	r := httptest.NewRequest("POST", "http://any.url/v1", bytes.NewReader(b))
	ctx := pcontext.SetAuthorizer(context.Background(), &platform.Authorization{UserID: 2})
	r = r.WithContext(ctx)

	// Record what the handler creates and deletes so the two can be compared.
	var created *platform.UserResourceMapping
	var deletedUser platform.ID
	var deletedResource platform.ID
	urms := &mock.UserResourceMappingService{
		CreateMappingFn: func(_ context.Context, urm *platform.UserResourceMapping) error { created = urm; return nil },
		DeleteMappingFn: func(_ context.Context, rid platform.ID, uid platform.ID) error {
			deletedUser = uid
			deletedResource = rid
			return nil
		},
		FindMappingsFn: func(context.Context, platform.UserResourceMappingFilter) ([]*platform.UserResourceMapping, int, error) {
			return []*platform.UserResourceMapping{created}, 1, nil
		},
	}

	taskBackend := NewMockTaskBackend()
	taskBackend.UserResourceMappingService = urms
	h := NewTaskHandler(taskBackend)
	taskID := platform.ID(1)
	h.TaskService = &mock.TaskService{
		CreateTaskFn: func(ctx context.Context, t *platform.Task) error {
			t.ID = taskID
			return nil
		},
		DeleteTaskFn: func(ctx context.Context, id platform.ID) error {
			return nil
		},
	}

	// Each request gets its own recorder so that each response can be
	// inspected on its own.
	postRec := httptest.NewRecorder()
	h.handlePostTask(postRec, r)
	if res := postRec.Result(); res.StatusCode < 200 || res.StatusCode > 299 {
		body, _ := ioutil.ReadAll(res.Body) // best effort: only used for the failure message
		t.Fatalf("handlePostTask() status = %d, want 2xx; body: %s", res.StatusCode, body)
	}
	// Fail cleanly instead of panicking on a nil mapping below.
	if created == nil {
		t.Fatal("handlePostTask() did not create a user resource mapping")
	}

	deleteRec := httptest.NewRecorder()
	r = httptest.NewRequest("DELETE", "http://any.url/api/v2/tasks/"+taskID.String(), nil)
	h.ServeHTTP(deleteRec, r)
	if res := deleteRec.Result(); res.StatusCode < 200 || res.StatusCode > 299 {
		body, _ := ioutil.ReadAll(res.Body) // best effort: only used for the failure message
		t.Fatalf("DELETE task status = %d, want 2xx; body: %s", res.StatusCode, body)
	}

	if created.UserID != deletedUser {
		t.Fatalf("deleted user (%s) doesn't match created user (%s)", deletedUser, created.UserID)
	}
	if created.ResourceID != deletedResource {
		t.Fatalf("deleted resource (%s) doesn't match created resource (%s)", deletedResource, created.ResourceID)
	}
}