fix(http): return 404 when task or run is not found

For an operation that looks up a task or a run, when that operation
fails, only set the status to 404 if that operation explicitly returns
ErrTaskNotFound or ErrRunNotFound.

It's possible that the operation could fail for a reason other than the
ID being invalid: for example, if there was an IO error preventing the
lookup from succeeding.

Harden that behavior with tests for the task handler.

Closes #11589.
pull/11628/head
Mark Rushakoff 2019-01-25 15:17:30 -08:00 committed by Mark Rushakoff
parent f0371e6716
commit 3f0e40812e
2 changed files with 309 additions and 32 deletions

View File

@ -462,10 +462,13 @@ func (h *TaskHandler) handleUpdateTask(w http.ResponseWriter, r *http.Request) {
}
task, err := h.TaskService.UpdateTask(ctx, req.TaskID, req.Update)
if err != nil {
err = &platform.Error{
err := &platform.Error{
Err: err,
Msg: "failed to update task",
}
if err.Err == backend.ErrTaskNotFound {
err.Code = platform.ENotFound
}
EncodeError(ctx, err, w)
return
}
@ -541,10 +544,13 @@ func (h *TaskHandler) handleDeleteTask(w http.ResponseWriter, r *http.Request) {
}
if err := h.TaskService.DeleteTask(ctx, req.TaskID); err != nil {
err = &platform.Error{
err := &platform.Error{
Err: err,
Msg: "failed to delete task",
}
if err.Err == backend.ErrTaskNotFound {
err.Code = platform.ENotFound
}
EncodeError(ctx, err, w)
return
}
@ -592,10 +598,13 @@ func (h *TaskHandler) handleGetLogs(w http.ResponseWriter, r *http.Request) {
logs, _, err := h.TaskService.FindLogs(ctx, req.filter)
if err != nil {
err = &platform.Error{
err := &platform.Error{
Err: err,
Msg: "failed to find task logs",
}
if err.Err == backend.ErrTaskNotFound || err.Err == backend.ErrRunNotFound {
err.Code = platform.ENotFound
}
EncodeError(ctx, err, w)
return
}
@ -671,10 +680,12 @@ func (h *TaskHandler) handleGetRuns(w http.ResponseWriter, r *http.Request) {
runs, _, err := h.TaskService.FindRuns(ctx, req.filter)
if err != nil {
err = &platform.Error{
Err: err,
Code: platform.EInvalid,
Msg: "failed to find runs",
err := &platform.Error{
Err: err,
Msg: "failed to find runs",
}
if err.Err == backend.ErrTaskNotFound {
err.Code = platform.ENotFound
}
EncodeError(ctx, err, w)
return
@ -792,12 +803,12 @@ func (h *TaskHandler) handleForceRun(w http.ResponseWriter, r *http.Request) {
run, err := h.TaskService.ForceRun(ctx, req.TaskID, req.Timestamp)
if err != nil {
if err == backend.ErrRunNotFound {
err = &platform.Error{
Code: platform.ENotFound,
Msg: "failed to force run",
Err: err,
}
err := &platform.Error{
Err: err,
Msg: "failed to force run",
}
if err.Err == backend.ErrTaskNotFound {
err.Code = platform.ENotFound
}
EncodeError(ctx, err, w)
return
@ -868,12 +879,12 @@ func (h *TaskHandler) handleGetRun(w http.ResponseWriter, r *http.Request) {
run, err := h.TaskService.FindRunByID(ctx, req.TaskID, req.RunID)
if err != nil {
if err == backend.ErrRunNotFound {
err = &platform.Error{
Err: err,
Msg: "failed to find run",
Code: platform.ENotFound,
}
err := &platform.Error{
Err: err,
Msg: "failed to find run",
}
if err.Err == backend.ErrTaskNotFound || err.Err == backend.ErrRunNotFound {
err.Code = platform.ENotFound
}
EncodeError(ctx, err, w)
return
@ -974,10 +985,13 @@ func (h *TaskHandler) handleCancelRun(w http.ResponseWriter, r *http.Request) {
err = h.TaskService.CancelRun(ctx, req.TaskID, req.RunID)
if err != nil {
err = &platform.Error{
err := &platform.Error{
Err: err,
Msg: "failed to cancel run",
}
if err.Err == backend.ErrTaskNotFound || err.Err == backend.ErrRunNotFound {
err.Code = platform.ENotFound
}
EncodeError(ctx, err, w)
return
}
@ -999,12 +1013,12 @@ func (h *TaskHandler) handleRetryRun(w http.ResponseWriter, r *http.Request) {
run, err := h.TaskService.RetryRun(ctx, req.TaskID, req.RunID)
if err != nil {
if err == backend.ErrRunNotFound {
err = &platform.Error{
Code: platform.ENotFound,
Msg: "failed to retry run",
Err: err,
}
err := &platform.Error{
Err: err,
Msg: "failed to retry run",
}
if err.Err == backend.ErrTaskNotFound || err.Err == backend.ErrRunNotFound {
err.Code = platform.ENotFound
}
EncodeError(ctx, err, w)
return

View File

@ -4,19 +4,22 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
platform "github.com/influxdata/influxdb"
pcontext "github.com/influxdata/influxdb/context"
"github.com/influxdata/influxdb/logger"
"github.com/influxdata/influxdb/inmem"
"github.com/influxdata/influxdb/mock"
_ "github.com/influxdata/influxdb/query/builtin"
"github.com/influxdata/influxdb/task/backend"
platformtesting "github.com/influxdata/influxdb/testing"
"github.com/julienschmidt/httprouter"
"go.uber.org/zap/zaptest"
)
func mockOrgService() platform.OrganizationService {
@ -161,7 +164,7 @@ func TestTaskHandler_handleGetTasks(t *testing.T) {
r := httptest.NewRequest("GET", "http://any.url", nil)
w := httptest.NewRecorder()
h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), logger.New(os.Stdout), mock.NewUserService())
h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), zaptest.NewLogger(t), mock.NewUserService())
h.OrganizationService = mockOrgService()
h.TaskService = tt.fields.taskService
h.LabelService = tt.fields.labelService
@ -262,7 +265,7 @@ func TestTaskHandler_handlePostTasks(t *testing.T) {
w := httptest.NewRecorder()
h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), logger.New(os.Stdout), mock.NewUserService())
h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), zaptest.NewLogger(t), mock.NewUserService())
h.OrganizationService = mockOrgService()
h.TaskService = tt.fields.taskService
h.handlePostTask(w, r)
@ -367,7 +370,7 @@ func TestTaskHandler_handleGetRun(t *testing.T) {
},
}))
w := httptest.NewRecorder()
h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), logger.New(os.Stdout), mock.NewUserService())
h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), zaptest.NewLogger(t), mock.NewUserService())
h.OrganizationService = mockOrgService()
h.TaskService = tt.fields.taskService
h.handleGetRun(w, r)
@ -476,7 +479,7 @@ func TestTaskHandler_handleGetRuns(t *testing.T) {
},
}))
w := httptest.NewRecorder()
h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), logger.New(os.Stdout), mock.NewUserService())
h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), zaptest.NewLogger(t), mock.NewUserService())
h.OrganizationService = mockOrgService()
h.TaskService = tt.fields.taskService
h.handleGetRuns(w, r)
@ -497,3 +500,263 @@ func TestTaskHandler_handleGetRuns(t *testing.T) {
})
}
}
func TestTaskHandler_NotFoundStatus(t *testing.T) {
// Ensure that the HTTP handlers return 404s for missing resources, and OKs for matching.
im := inmem.NewService()
h := NewTaskHandler(im, im, zaptest.NewLogger(t), im)
h.OrganizationService = im
o := platform.Organization{Name: "o"}
ctx := context.Background()
if err := h.OrganizationService.CreateOrganization(ctx, &o); err != nil {
t.Fatal(err)
}
const taskID, runID = platform.ID(0xCCCCCC), platform.ID(0xAAAAAA)
var (
okTask = []interface{}{taskID}
okTaskRun = []interface{}{taskID, runID}
notFoundTask = [][]interface{}{
{taskID + 1},
}
notFoundTaskRun = [][]interface{}{
{taskID, runID + 1},
{taskID + 1, runID},
{taskID + 1, runID + 1},
}
)
tcs := []struct {
name string
svc *mock.TaskService
method string
body string
pathFmt string
okPathArgs []interface{}
notFoundPathArgs [][]interface{}
}{
{
name: "get task",
svc: &mock.TaskService{
FindTaskByIDFn: func(_ context.Context, id platform.ID) (*platform.Task, error) {
if id == taskID {
return &platform.Task{ID: taskID, Organization: "o"}, nil
}
return nil, backend.ErrTaskNotFound
},
},
method: http.MethodGet,
pathFmt: "/tasks/%s",
okPathArgs: okTask,
notFoundPathArgs: notFoundTask,
},
{
name: "update task",
svc: &mock.TaskService{
UpdateTaskFn: func(_ context.Context, id platform.ID, _ platform.TaskUpdate) (*platform.Task, error) {
if id == taskID {
return &platform.Task{ID: taskID, Organization: "o"}, nil
}
return nil, backend.ErrTaskNotFound
},
},
method: http.MethodPatch,
body: "{}",
pathFmt: "/tasks/%s",
okPathArgs: okTask,
notFoundPathArgs: notFoundTask,
},
{
name: "delete task",
svc: &mock.TaskService{
DeleteTaskFn: func(_ context.Context, id platform.ID) error {
if id == taskID {
return nil
}
return backend.ErrTaskNotFound
},
},
method: http.MethodDelete,
pathFmt: "/tasks/%s",
okPathArgs: okTask,
notFoundPathArgs: notFoundTask,
},
{
name: "get task logs",
svc: &mock.TaskService{
FindLogsFn: func(_ context.Context, f platform.LogFilter) ([]*platform.Log, int, error) {
if *f.Task == taskID {
return nil, 0, nil
}
return nil, 0, backend.ErrTaskNotFound
},
},
method: http.MethodGet,
pathFmt: "/tasks/%s/logs",
okPathArgs: okTask,
notFoundPathArgs: notFoundTask,
},
{
name: "get run logs",
svc: &mock.TaskService{
FindLogsFn: func(_ context.Context, f platform.LogFilter) ([]*platform.Log, int, error) {
if *f.Task != taskID {
return nil, 0, backend.ErrTaskNotFound
}
if *f.Run != runID {
return nil, 0, backend.ErrRunNotFound
}
return nil, 0, nil
},
},
method: http.MethodGet,
pathFmt: "/tasks/%s/runs/%s/logs",
okPathArgs: okTaskRun,
notFoundPathArgs: notFoundTaskRun,
},
{
name: "get runs",
svc: &mock.TaskService{
FindRunsFn: func(_ context.Context, f platform.RunFilter) ([]*platform.Run, int, error) {
if *f.Task != taskID {
return nil, 0, backend.ErrTaskNotFound
}
return nil, 0, nil
},
},
method: http.MethodGet,
pathFmt: "/tasks/%s/runs",
okPathArgs: okTask,
notFoundPathArgs: notFoundTask,
},
{
name: "force run",
svc: &mock.TaskService{
ForceRunFn: func(_ context.Context, tid platform.ID, _ int64) (*platform.Run, error) {
if tid != taskID {
return nil, backend.ErrTaskNotFound
}
return &platform.Run{ID: runID, TaskID: taskID, Status: backend.RunScheduled.String()}, nil
},
},
method: http.MethodPost,
body: "{}",
pathFmt: "/tasks/%s/runs",
okPathArgs: okTask,
notFoundPathArgs: notFoundTask,
},
{
name: "get run",
svc: &mock.TaskService{
FindRunByIDFn: func(_ context.Context, tid, rid platform.ID) (*platform.Run, error) {
if tid != taskID {
return nil, backend.ErrTaskNotFound
}
if rid != runID {
return nil, backend.ErrRunNotFound
}
return &platform.Run{ID: runID, TaskID: taskID, Status: backend.RunScheduled.String()}, nil
},
},
method: http.MethodGet,
pathFmt: "/tasks/%s/runs/%s",
okPathArgs: okTaskRun,
notFoundPathArgs: notFoundTaskRun,
},
{
name: "retry run",
svc: &mock.TaskService{
RetryRunFn: func(_ context.Context, tid, rid platform.ID) (*platform.Run, error) {
if tid != taskID {
return nil, backend.ErrTaskNotFound
}
if rid != runID {
return nil, backend.ErrRunNotFound
}
return &platform.Run{ID: runID, TaskID: taskID, Status: backend.RunScheduled.String()}, nil
},
},
method: http.MethodPost,
pathFmt: "/tasks/%s/runs/%s/retry",
okPathArgs: okTaskRun,
notFoundPathArgs: notFoundTaskRun,
},
{
name: "cancel run",
svc: &mock.TaskService{
CancelRunFn: func(_ context.Context, tid, rid platform.ID) error {
if tid != taskID {
return backend.ErrTaskNotFound
}
if rid != runID {
return backend.ErrRunNotFound
}
return nil
},
},
method: http.MethodDelete,
pathFmt: "/tasks/%s/runs/%s",
okPathArgs: okTaskRun,
notFoundPathArgs: notFoundTaskRun,
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.name, func(t *testing.T) {
h.TaskService = tc.svc
okPath := fmt.Sprintf(tc.pathFmt, tc.okPathArgs...)
t.Run("matching ID: "+tc.method+" "+okPath, func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(tc.method, "http://task.example/api/v2"+okPath, strings.NewReader(tc.body))
h.ServeHTTP(w, r)
res := w.Result()
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode > 299 {
t.Errorf("expected OK, got %d", res.StatusCode)
b, _ := ioutil.ReadAll(res.Body)
t.Fatalf("body: %s", string(b))
}
})
t.Run("mismatched ID", func(t *testing.T) {
for _, nfa := range tc.notFoundPathArgs {
path := fmt.Sprintf(tc.pathFmt, nfa...)
t.Run(tc.method+" "+path, func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(tc.method, "http://task.example/api/v2"+path, strings.NewReader(tc.body))
h.ServeHTTP(w, r)
res := w.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusNotFound {
t.Errorf("expected Not Found, got %d", res.StatusCode)
b, _ := ioutil.ReadAll(res.Body)
t.Fatalf("body: %s", string(b))
}
})
}
})
})
}
}