From 3f0e40812e0e4a3bec260341a456ad6850f72723 Mon Sep 17 00:00:00 2001 From: Mark Rushakoff Date: Fri, 25 Jan 2019 15:17:30 -0800 Subject: [PATCH] fix(http): return 404 when task or run is not found For an operation that looks up a task or a run, when that operation fails, only set the status to 404 if that operation explicitly returns ErrTaskNotFound or ErrRunNotFound. It's possible that the operation could fail for a reason other than the ID being invalid: for example, if there was an IO error preventing the lookup from succeeding. Harden that behavior with tests for the task handler. Closes #11589. --- http/task_service.go | 66 +++++---- http/task_service_test.go | 275 +++++++++++++++++++++++++++++++++++++- 2 files changed, 309 insertions(+), 32 deletions(-) diff --git a/http/task_service.go b/http/task_service.go index 41cf73034e..b9b0445809 100644 --- a/http/task_service.go +++ b/http/task_service.go @@ -462,10 +462,13 @@ func (h *TaskHandler) handleUpdateTask(w http.ResponseWriter, r *http.Request) { } task, err := h.TaskService.UpdateTask(ctx, req.TaskID, req.Update) if err != nil { - err = &platform.Error{ + err := &platform.Error{ Err: err, Msg: "failed to update task", } + if err.Err == backend.ErrTaskNotFound { + err.Code = platform.ENotFound + } EncodeError(ctx, err, w) return } @@ -541,10 +544,13 @@ func (h *TaskHandler) handleDeleteTask(w http.ResponseWriter, r *http.Request) { } if err := h.TaskService.DeleteTask(ctx, req.TaskID); err != nil { - err = &platform.Error{ + err := &platform.Error{ Err: err, Msg: "failed to delete task", } + if err.Err == backend.ErrTaskNotFound { + err.Code = platform.ENotFound + } EncodeError(ctx, err, w) return } @@ -592,10 +598,13 @@ func (h *TaskHandler) handleGetLogs(w http.ResponseWriter, r *http.Request) { logs, _, err := h.TaskService.FindLogs(ctx, req.filter) if err != nil { - err = &platform.Error{ + err := &platform.Error{ Err: err, Msg: "failed to find task logs", } + if err.Err == backend.ErrTaskNotFound || err.Err == backend.ErrRunNotFound { + err.Code = platform.ENotFound + } EncodeError(ctx, err, w) return } @@ -671,10 +680,12 @@ func (h *TaskHandler) handleGetRuns(w http.ResponseWriter, r *http.Request) { runs, _, err := h.TaskService.FindRuns(ctx, req.filter) if err != nil { - err = &platform.Error{ - Err: err, - Code: platform.EInvalid, - Msg: "failed to find runs", + err := &platform.Error{ + Err: err, + Msg: "failed to find runs", + } + if err.Err == backend.ErrTaskNotFound { + err.Code = platform.ENotFound } EncodeError(ctx, err, w) return @@ -792,12 +803,12 @@ func (h *TaskHandler) handleForceRun(w http.ResponseWriter, r *http.Request) { run, err := h.TaskService.ForceRun(ctx, req.TaskID, req.Timestamp) if err != nil { - if err == backend.ErrRunNotFound { - err = &platform.Error{ - Code: platform.ENotFound, - Msg: "failed to force run", - Err: err, - } + err := &platform.Error{ + Err: err, + Msg: "failed to force run", + } + if err.Err == backend.ErrTaskNotFound { + err.Code = platform.ENotFound } EncodeError(ctx, err, w) return @@ -868,12 +879,12 @@ func (h *TaskHandler) handleGetRun(w http.ResponseWriter, r *http.Request) { run, err := h.TaskService.FindRunByID(ctx, req.TaskID, req.RunID) if err != nil { - if err == backend.ErrRunNotFound { - err = &platform.Error{ - Err: err, - Msg: "failed to find run", - Code: platform.ENotFound, - } + err := &platform.Error{ + Err: err, + Msg: "failed to find run", + } + if err.Err == backend.ErrTaskNotFound || err.Err == backend.ErrRunNotFound { + err.Code = platform.ENotFound } EncodeError(ctx, err, w) return @@ -974,10 +985,13 @@ func (h *TaskHandler) handleCancelRun(w http.ResponseWriter, r *http.Request) { err = h.TaskService.CancelRun(ctx, req.TaskID, req.RunID) if err != nil { - err = &platform.Error{ + err := &platform.Error{ Err: err, Msg: "failed to cancel run", } + if err.Err == backend.ErrTaskNotFound || err.Err == backend.ErrRunNotFound { + err.Code = platform.ENotFound + } EncodeError(ctx, err, w) return } @@ -999,12 +1013,12 @@ func (h *TaskHandler) handleRetryRun(w http.ResponseWriter, r *http.Request) { run, err := h.TaskService.RetryRun(ctx, req.TaskID, req.RunID) if err != nil { - if err == backend.ErrRunNotFound { - err = &platform.Error{ - Code: platform.ENotFound, - Msg: "failed to retry run", - Err: err, - } + err := &platform.Error{ + Err: err, + Msg: "failed to retry run", + } + if err.Err == backend.ErrTaskNotFound || err.Err == backend.ErrRunNotFound { + err.Code = platform.ENotFound } EncodeError(ctx, err, w) return diff --git a/http/task_service_test.go b/http/task_service_test.go index 0ba937b01a..56f128d8bd 100644 --- a/http/task_service_test.go +++ b/http/task_service_test.go @@ -4,19 +4,22 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io/ioutil" "net/http" "net/http/httptest" - "os" + "strings" "testing" platform "github.com/influxdata/influxdb" pcontext "github.com/influxdata/influxdb/context" - "github.com/influxdata/influxdb/logger" + "github.com/influxdata/influxdb/inmem" "github.com/influxdata/influxdb/mock" _ "github.com/influxdata/influxdb/query/builtin" + "github.com/influxdata/influxdb/task/backend" platformtesting "github.com/influxdata/influxdb/testing" "github.com/julienschmidt/httprouter" + "go.uber.org/zap/zaptest" ) func mockOrgService() platform.OrganizationService { @@ -161,7 +164,7 @@ func TestTaskHandler_handleGetTasks(t *testing.T) { r := httptest.NewRequest("GET", "http://any.url", nil) w := httptest.NewRecorder() - h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), logger.New(os.Stdout), mock.NewUserService()) + h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), zaptest.NewLogger(t), mock.NewUserService()) h.OrganizationService = mockOrgService() h.TaskService = tt.fields.taskService h.LabelService = tt.fields.labelService @@ -262,7 +265,7 @@ func TestTaskHandler_handlePostTasks(t *testing.T) { w := httptest.NewRecorder() - h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), logger.New(os.Stdout), mock.NewUserService()) + h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), zaptest.NewLogger(t), mock.NewUserService()) h.OrganizationService = mockOrgService() h.TaskService = tt.fields.taskService h.handlePostTask(w, r) @@ -367,7 +370,7 @@ func TestTaskHandler_handleGetRun(t *testing.T) { }, })) w := httptest.NewRecorder() - h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), logger.New(os.Stdout), mock.NewUserService()) + h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), zaptest.NewLogger(t), mock.NewUserService()) h.OrganizationService = mockOrgService() h.TaskService = tt.fields.taskService h.handleGetRun(w, r) @@ -476,7 +479,7 @@ func TestTaskHandler_handleGetRuns(t *testing.T) { }, })) w := httptest.NewRecorder() - h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), logger.New(os.Stdout), mock.NewUserService()) + h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), zaptest.NewLogger(t), mock.NewUserService()) h.OrganizationService = mockOrgService() h.TaskService = tt.fields.taskService h.handleGetRuns(w, r) @@ -497,3 +500,263 @@ func TestTaskHandler_handleGetRuns(t *testing.T) { }) } } + +func TestTaskHandler_NotFoundStatus(t *testing.T) { + // Ensure that the HTTP handlers return 404s for missing resources, and OKs for matching. + + im := inmem.NewService() + h := NewTaskHandler(im, im, zaptest.NewLogger(t), im) + h.OrganizationService = im + + o := platform.Organization{Name: "o"} + ctx := context.Background() + if err := h.OrganizationService.CreateOrganization(ctx, &o); err != nil { + t.Fatal(err) + } + + const taskID, runID = platform.ID(0xCCCCCC), platform.ID(0xAAAAAA) + + var ( + okTask = []interface{}{taskID} + okTaskRun = []interface{}{taskID, runID} + + notFoundTask = [][]interface{}{ + {taskID + 1}, + } + notFoundTaskRun = [][]interface{}{ + {taskID, runID + 1}, + {taskID + 1, runID}, + {taskID + 1, runID + 1}, + } + ) + + tcs := []struct { + name string + svc *mock.TaskService + method string + body string + pathFmt string + okPathArgs []interface{} + notFoundPathArgs [][]interface{} + }{ + { + name: "get task", + svc: &mock.TaskService{ + FindTaskByIDFn: func(_ context.Context, id platform.ID) (*platform.Task, error) { + if id == taskID { + return &platform.Task{ID: taskID, Organization: "o"}, nil + } + + return nil, backend.ErrTaskNotFound + }, + }, + method: http.MethodGet, + pathFmt: "/tasks/%s", + okPathArgs: okTask, + notFoundPathArgs: notFoundTask, + }, + { + name: "update task", + svc: &mock.TaskService{ + UpdateTaskFn: func(_ context.Context, id platform.ID, _ platform.TaskUpdate) (*platform.Task, error) { + if id == taskID { + return &platform.Task{ID: taskID, Organization: "o"}, nil + } + + return nil, backend.ErrTaskNotFound + }, + }, + method: http.MethodPatch, + body: "{}", + pathFmt: "/tasks/%s", + okPathArgs: okTask, + notFoundPathArgs: notFoundTask, + }, + { + name: "delete task", + svc: &mock.TaskService{ + DeleteTaskFn: func(_ context.Context, id platform.ID) error { + if id == taskID { + return nil + } + + return backend.ErrTaskNotFound + }, + }, + method: http.MethodDelete, + pathFmt: "/tasks/%s", + okPathArgs: okTask, + notFoundPathArgs: notFoundTask, + }, + { + name: "get task logs", + svc: &mock.TaskService{ + FindLogsFn: func(_ context.Context, f platform.LogFilter) ([]*platform.Log, int, error) { + if *f.Task == taskID { + return nil, 0, nil + } + + return nil, 0, backend.ErrTaskNotFound + }, + }, + method: http.MethodGet, + pathFmt: "/tasks/%s/logs", + okPathArgs: okTask, + notFoundPathArgs: notFoundTask, + }, + { + name: "get run logs", + svc: &mock.TaskService{ + FindLogsFn: func(_ context.Context, f platform.LogFilter) ([]*platform.Log, int, error) { + if *f.Task != taskID { + return nil, 0, backend.ErrTaskNotFound + } + if *f.Run != runID { + return nil, 0, backend.ErrRunNotFound + } + + return nil, 0, nil + }, + }, + method: http.MethodGet, + pathFmt: "/tasks/%s/runs/%s/logs", + okPathArgs: okTaskRun, + notFoundPathArgs: notFoundTaskRun, + }, + { + name: "get runs", + svc: &mock.TaskService{ + FindRunsFn: func(_ context.Context, f platform.RunFilter) ([]*platform.Run, int, error) { + if *f.Task != taskID { + return nil, 0, backend.ErrTaskNotFound + } + + return nil, 0, nil + }, + }, + method: http.MethodGet, + pathFmt: "/tasks/%s/runs", + okPathArgs: okTask, + notFoundPathArgs: notFoundTask, + }, + { + name: "force run", + svc: &mock.TaskService{ + ForceRunFn: func(_ context.Context, tid platform.ID, _ int64) (*platform.Run, error) { + if tid != taskID { + return nil, backend.ErrTaskNotFound + } + + return &platform.Run{ID: runID, TaskID: taskID, Status: backend.RunScheduled.String()}, nil + }, + }, + method: http.MethodPost, + body: "{}", + pathFmt: "/tasks/%s/runs", + okPathArgs: okTask, + notFoundPathArgs: notFoundTask, + }, + { + name: "get run", + svc: &mock.TaskService{ + FindRunByIDFn: func(_ context.Context, tid, rid platform.ID) (*platform.Run, error) { + if tid != taskID { + return nil, backend.ErrTaskNotFound + } + if rid != runID { + return nil, backend.ErrRunNotFound + } + + return &platform.Run{ID: runID, TaskID: taskID, Status: backend.RunScheduled.String()}, nil + }, + }, + method: http.MethodGet, + pathFmt: "/tasks/%s/runs/%s", + okPathArgs: okTaskRun, + notFoundPathArgs: notFoundTaskRun, + }, + { + name: "retry run", + svc: &mock.TaskService{ + RetryRunFn: func(_ context.Context, tid, rid platform.ID) (*platform.Run, error) { + if tid != taskID { + return nil, backend.ErrTaskNotFound + } + if rid != runID { + return nil, backend.ErrRunNotFound + } + + return &platform.Run{ID: runID, TaskID: taskID, Status: backend.RunScheduled.String()}, nil + }, + }, + method: http.MethodPost, + pathFmt: "/tasks/%s/runs/%s/retry", + okPathArgs: okTaskRun, + notFoundPathArgs: notFoundTaskRun, + }, + { + name: "cancel run", + svc: &mock.TaskService{ + CancelRunFn: func(_ context.Context, tid, rid platform.ID) error { + if tid != taskID { + return backend.ErrTaskNotFound + } + if rid != runID { + return backend.ErrRunNotFound + } + + return nil + }, + }, + method: http.MethodDelete, + pathFmt: "/tasks/%s/runs/%s", + okPathArgs: okTaskRun, + notFoundPathArgs: notFoundTaskRun, + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + h.TaskService = tc.svc + + okPath := fmt.Sprintf(tc.pathFmt, tc.okPathArgs...) + t.Run("matching ID: "+tc.method+" "+okPath, func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(tc.method, "http://task.example/api/v2"+okPath, strings.NewReader(tc.body)) + + h.ServeHTTP(w, r) + + res := w.Result() + defer res.Body.Close() + + if res.StatusCode < 200 || res.StatusCode > 299 { + t.Errorf("expected OK, got %d", res.StatusCode) + b, _ := ioutil.ReadAll(res.Body) + t.Fatalf("body: %s", string(b)) + } + }) + + t.Run("mismatched ID", func(t *testing.T) { + for _, nfa := range tc.notFoundPathArgs { + path := fmt.Sprintf(tc.pathFmt, nfa...) + t.Run(tc.method+" "+path, func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(tc.method, "http://task.example/api/v2"+path, strings.NewReader(tc.body)) + + h.ServeHTTP(w, r) + + res := w.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusNotFound { + t.Errorf("expected Not Found, got %d", res.StatusCode) + b, _ := ioutil.ReadAll(res.Body) + t.Fatalf("body: %s", string(b)) + } + }) + } + }) + }) + } +}