fix(http): return 404 when task or run is not found
For an operation that looks up a task or a run, when that operation fails, only set the status to 404 if that operation explicitly returns ErrTaskNotFound or ErrRunNotFound. It's possible that the operation could fail for a reason other than the ID being invalid: for example, if there was an IO error preventing the lookup from succeeding. Harden that behavior with tests for the task handler. Closes #11589.pull/11628/head
parent
f0371e6716
commit
3f0e40812e
|
@ -462,10 +462,13 @@ func (h *TaskHandler) handleUpdateTask(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
task, err := h.TaskService.UpdateTask(ctx, req.TaskID, req.Update)
|
||||
if err != nil {
|
||||
err = &platform.Error{
|
||||
err := &platform.Error{
|
||||
Err: err,
|
||||
Msg: "failed to update task",
|
||||
}
|
||||
if err.Err == backend.ErrTaskNotFound {
|
||||
err.Code = platform.ENotFound
|
||||
}
|
||||
EncodeError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
@ -541,10 +544,13 @@ func (h *TaskHandler) handleDeleteTask(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
if err := h.TaskService.DeleteTask(ctx, req.TaskID); err != nil {
|
||||
err = &platform.Error{
|
||||
err := &platform.Error{
|
||||
Err: err,
|
||||
Msg: "failed to delete task",
|
||||
}
|
||||
if err.Err == backend.ErrTaskNotFound {
|
||||
err.Code = platform.ENotFound
|
||||
}
|
||||
EncodeError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
@ -592,10 +598,13 @@ func (h *TaskHandler) handleGetLogs(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
logs, _, err := h.TaskService.FindLogs(ctx, req.filter)
|
||||
if err != nil {
|
||||
err = &platform.Error{
|
||||
err := &platform.Error{
|
||||
Err: err,
|
||||
Msg: "failed to find task logs",
|
||||
}
|
||||
if err.Err == backend.ErrTaskNotFound || err.Err == backend.ErrRunNotFound {
|
||||
err.Code = platform.ENotFound
|
||||
}
|
||||
EncodeError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
@ -671,11 +680,13 @@ func (h *TaskHandler) handleGetRuns(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
runs, _, err := h.TaskService.FindRuns(ctx, req.filter)
|
||||
if err != nil {
|
||||
err = &platform.Error{
|
||||
err := &platform.Error{
|
||||
Err: err,
|
||||
Code: platform.EInvalid,
|
||||
Msg: "failed to find runs",
|
||||
}
|
||||
if err.Err == backend.ErrTaskNotFound {
|
||||
err.Code = platform.ENotFound
|
||||
}
|
||||
EncodeError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
@ -792,12 +803,12 @@ func (h *TaskHandler) handleForceRun(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
run, err := h.TaskService.ForceRun(ctx, req.TaskID, req.Timestamp)
|
||||
if err != nil {
|
||||
if err == backend.ErrRunNotFound {
|
||||
err = &platform.Error{
|
||||
Code: platform.ENotFound,
|
||||
Msg: "failed to force run",
|
||||
err := &platform.Error{
|
||||
Err: err,
|
||||
Msg: "failed to force run",
|
||||
}
|
||||
if err.Err == backend.ErrTaskNotFound {
|
||||
err.Code = platform.ENotFound
|
||||
}
|
||||
EncodeError(ctx, err, w)
|
||||
return
|
||||
|
@ -868,12 +879,12 @@ func (h *TaskHandler) handleGetRun(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
run, err := h.TaskService.FindRunByID(ctx, req.TaskID, req.RunID)
|
||||
if err != nil {
|
||||
if err == backend.ErrRunNotFound {
|
||||
err = &platform.Error{
|
||||
err := &platform.Error{
|
||||
Err: err,
|
||||
Msg: "failed to find run",
|
||||
Code: platform.ENotFound,
|
||||
}
|
||||
if err.Err == backend.ErrTaskNotFound || err.Err == backend.ErrRunNotFound {
|
||||
err.Code = platform.ENotFound
|
||||
}
|
||||
EncodeError(ctx, err, w)
|
||||
return
|
||||
|
@ -974,10 +985,13 @@ func (h *TaskHandler) handleCancelRun(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
err = h.TaskService.CancelRun(ctx, req.TaskID, req.RunID)
|
||||
if err != nil {
|
||||
err = &platform.Error{
|
||||
err := &platform.Error{
|
||||
Err: err,
|
||||
Msg: "failed to cancel run",
|
||||
}
|
||||
if err.Err == backend.ErrTaskNotFound || err.Err == backend.ErrRunNotFound {
|
||||
err.Code = platform.ENotFound
|
||||
}
|
||||
EncodeError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
@ -999,12 +1013,12 @@ func (h *TaskHandler) handleRetryRun(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
run, err := h.TaskService.RetryRun(ctx, req.TaskID, req.RunID)
|
||||
if err != nil {
|
||||
if err == backend.ErrRunNotFound {
|
||||
err = &platform.Error{
|
||||
Code: platform.ENotFound,
|
||||
Msg: "failed to retry run",
|
||||
err := &platform.Error{
|
||||
Err: err,
|
||||
Msg: "failed to retry run",
|
||||
}
|
||||
if err.Err == backend.ErrTaskNotFound || err.Err == backend.ErrRunNotFound {
|
||||
err.Code = platform.ENotFound
|
||||
}
|
||||
EncodeError(ctx, err, w)
|
||||
return
|
||||
|
|
|
@ -4,19 +4,22 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
platform "github.com/influxdata/influxdb"
|
||||
pcontext "github.com/influxdata/influxdb/context"
|
||||
"github.com/influxdata/influxdb/logger"
|
||||
"github.com/influxdata/influxdb/inmem"
|
||||
"github.com/influxdata/influxdb/mock"
|
||||
_ "github.com/influxdata/influxdb/query/builtin"
|
||||
"github.com/influxdata/influxdb/task/backend"
|
||||
platformtesting "github.com/influxdata/influxdb/testing"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
func mockOrgService() platform.OrganizationService {
|
||||
|
@ -161,7 +164,7 @@ func TestTaskHandler_handleGetTasks(t *testing.T) {
|
|||
r := httptest.NewRequest("GET", "http://any.url", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), logger.New(os.Stdout), mock.NewUserService())
|
||||
h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), zaptest.NewLogger(t), mock.NewUserService())
|
||||
h.OrganizationService = mockOrgService()
|
||||
h.TaskService = tt.fields.taskService
|
||||
h.LabelService = tt.fields.labelService
|
||||
|
@ -262,7 +265,7 @@ func TestTaskHandler_handlePostTasks(t *testing.T) {
|
|||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), logger.New(os.Stdout), mock.NewUserService())
|
||||
h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), zaptest.NewLogger(t), mock.NewUserService())
|
||||
h.OrganizationService = mockOrgService()
|
||||
h.TaskService = tt.fields.taskService
|
||||
h.handlePostTask(w, r)
|
||||
|
@ -367,7 +370,7 @@ func TestTaskHandler_handleGetRun(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
w := httptest.NewRecorder()
|
||||
h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), logger.New(os.Stdout), mock.NewUserService())
|
||||
h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), zaptest.NewLogger(t), mock.NewUserService())
|
||||
h.OrganizationService = mockOrgService()
|
||||
h.TaskService = tt.fields.taskService
|
||||
h.handleGetRun(w, r)
|
||||
|
@ -476,7 +479,7 @@ func TestTaskHandler_handleGetRuns(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
w := httptest.NewRecorder()
|
||||
h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), logger.New(os.Stdout), mock.NewUserService())
|
||||
h := NewTaskHandler(mock.NewUserResourceMappingService(), mock.NewLabelService(), zaptest.NewLogger(t), mock.NewUserService())
|
||||
h.OrganizationService = mockOrgService()
|
||||
h.TaskService = tt.fields.taskService
|
||||
h.handleGetRuns(w, r)
|
||||
|
@ -497,3 +500,263 @@ func TestTaskHandler_handleGetRuns(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskHandler_NotFoundStatus(t *testing.T) {
|
||||
// Ensure that the HTTP handlers return 404s for missing resources, and OKs for matching.
|
||||
|
||||
im := inmem.NewService()
|
||||
h := NewTaskHandler(im, im, zaptest.NewLogger(t), im)
|
||||
h.OrganizationService = im
|
||||
|
||||
o := platform.Organization{Name: "o"}
|
||||
ctx := context.Background()
|
||||
if err := h.OrganizationService.CreateOrganization(ctx, &o); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
const taskID, runID = platform.ID(0xCCCCCC), platform.ID(0xAAAAAA)
|
||||
|
||||
var (
|
||||
okTask = []interface{}{taskID}
|
||||
okTaskRun = []interface{}{taskID, runID}
|
||||
|
||||
notFoundTask = [][]interface{}{
|
||||
{taskID + 1},
|
||||
}
|
||||
notFoundTaskRun = [][]interface{}{
|
||||
{taskID, runID + 1},
|
||||
{taskID + 1, runID},
|
||||
{taskID + 1, runID + 1},
|
||||
}
|
||||
)
|
||||
|
||||
tcs := []struct {
|
||||
name string
|
||||
svc *mock.TaskService
|
||||
method string
|
||||
body string
|
||||
pathFmt string
|
||||
okPathArgs []interface{}
|
||||
notFoundPathArgs [][]interface{}
|
||||
}{
|
||||
{
|
||||
name: "get task",
|
||||
svc: &mock.TaskService{
|
||||
FindTaskByIDFn: func(_ context.Context, id platform.ID) (*platform.Task, error) {
|
||||
if id == taskID {
|
||||
return &platform.Task{ID: taskID, Organization: "o"}, nil
|
||||
}
|
||||
|
||||
return nil, backend.ErrTaskNotFound
|
||||
},
|
||||
},
|
||||
method: http.MethodGet,
|
||||
pathFmt: "/tasks/%s",
|
||||
okPathArgs: okTask,
|
||||
notFoundPathArgs: notFoundTask,
|
||||
},
|
||||
{
|
||||
name: "update task",
|
||||
svc: &mock.TaskService{
|
||||
UpdateTaskFn: func(_ context.Context, id platform.ID, _ platform.TaskUpdate) (*platform.Task, error) {
|
||||
if id == taskID {
|
||||
return &platform.Task{ID: taskID, Organization: "o"}, nil
|
||||
}
|
||||
|
||||
return nil, backend.ErrTaskNotFound
|
||||
},
|
||||
},
|
||||
method: http.MethodPatch,
|
||||
body: "{}",
|
||||
pathFmt: "/tasks/%s",
|
||||
okPathArgs: okTask,
|
||||
notFoundPathArgs: notFoundTask,
|
||||
},
|
||||
{
|
||||
name: "delete task",
|
||||
svc: &mock.TaskService{
|
||||
DeleteTaskFn: func(_ context.Context, id platform.ID) error {
|
||||
if id == taskID {
|
||||
return nil
|
||||
}
|
||||
|
||||
return backend.ErrTaskNotFound
|
||||
},
|
||||
},
|
||||
method: http.MethodDelete,
|
||||
pathFmt: "/tasks/%s",
|
||||
okPathArgs: okTask,
|
||||
notFoundPathArgs: notFoundTask,
|
||||
},
|
||||
{
|
||||
name: "get task logs",
|
||||
svc: &mock.TaskService{
|
||||
FindLogsFn: func(_ context.Context, f platform.LogFilter) ([]*platform.Log, int, error) {
|
||||
if *f.Task == taskID {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
return nil, 0, backend.ErrTaskNotFound
|
||||
},
|
||||
},
|
||||
method: http.MethodGet,
|
||||
pathFmt: "/tasks/%s/logs",
|
||||
okPathArgs: okTask,
|
||||
notFoundPathArgs: notFoundTask,
|
||||
},
|
||||
{
|
||||
name: "get run logs",
|
||||
svc: &mock.TaskService{
|
||||
FindLogsFn: func(_ context.Context, f platform.LogFilter) ([]*platform.Log, int, error) {
|
||||
if *f.Task != taskID {
|
||||
return nil, 0, backend.ErrTaskNotFound
|
||||
}
|
||||
if *f.Run != runID {
|
||||
return nil, 0, backend.ErrRunNotFound
|
||||
}
|
||||
|
||||
return nil, 0, nil
|
||||
},
|
||||
},
|
||||
method: http.MethodGet,
|
||||
pathFmt: "/tasks/%s/runs/%s/logs",
|
||||
okPathArgs: okTaskRun,
|
||||
notFoundPathArgs: notFoundTaskRun,
|
||||
},
|
||||
{
|
||||
name: "get runs",
|
||||
svc: &mock.TaskService{
|
||||
FindRunsFn: func(_ context.Context, f platform.RunFilter) ([]*platform.Run, int, error) {
|
||||
if *f.Task != taskID {
|
||||
return nil, 0, backend.ErrTaskNotFound
|
||||
}
|
||||
|
||||
return nil, 0, nil
|
||||
},
|
||||
},
|
||||
method: http.MethodGet,
|
||||
pathFmt: "/tasks/%s/runs",
|
||||
okPathArgs: okTask,
|
||||
notFoundPathArgs: notFoundTask,
|
||||
},
|
||||
{
|
||||
name: "force run",
|
||||
svc: &mock.TaskService{
|
||||
ForceRunFn: func(_ context.Context, tid platform.ID, _ int64) (*platform.Run, error) {
|
||||
if tid != taskID {
|
||||
return nil, backend.ErrTaskNotFound
|
||||
}
|
||||
|
||||
return &platform.Run{ID: runID, TaskID: taskID, Status: backend.RunScheduled.String()}, nil
|
||||
},
|
||||
},
|
||||
method: http.MethodPost,
|
||||
body: "{}",
|
||||
pathFmt: "/tasks/%s/runs",
|
||||
okPathArgs: okTask,
|
||||
notFoundPathArgs: notFoundTask,
|
||||
},
|
||||
{
|
||||
name: "get run",
|
||||
svc: &mock.TaskService{
|
||||
FindRunByIDFn: func(_ context.Context, tid, rid platform.ID) (*platform.Run, error) {
|
||||
if tid != taskID {
|
||||
return nil, backend.ErrTaskNotFound
|
||||
}
|
||||
if rid != runID {
|
||||
return nil, backend.ErrRunNotFound
|
||||
}
|
||||
|
||||
return &platform.Run{ID: runID, TaskID: taskID, Status: backend.RunScheduled.String()}, nil
|
||||
},
|
||||
},
|
||||
method: http.MethodGet,
|
||||
pathFmt: "/tasks/%s/runs/%s",
|
||||
okPathArgs: okTaskRun,
|
||||
notFoundPathArgs: notFoundTaskRun,
|
||||
},
|
||||
{
|
||||
name: "retry run",
|
||||
svc: &mock.TaskService{
|
||||
RetryRunFn: func(_ context.Context, tid, rid platform.ID) (*platform.Run, error) {
|
||||
if tid != taskID {
|
||||
return nil, backend.ErrTaskNotFound
|
||||
}
|
||||
if rid != runID {
|
||||
return nil, backend.ErrRunNotFound
|
||||
}
|
||||
|
||||
return &platform.Run{ID: runID, TaskID: taskID, Status: backend.RunScheduled.String()}, nil
|
||||
},
|
||||
},
|
||||
method: http.MethodPost,
|
||||
pathFmt: "/tasks/%s/runs/%s/retry",
|
||||
okPathArgs: okTaskRun,
|
||||
notFoundPathArgs: notFoundTaskRun,
|
||||
},
|
||||
{
|
||||
name: "cancel run",
|
||||
svc: &mock.TaskService{
|
||||
CancelRunFn: func(_ context.Context, tid, rid platform.ID) error {
|
||||
if tid != taskID {
|
||||
return backend.ErrTaskNotFound
|
||||
}
|
||||
if rid != runID {
|
||||
return backend.ErrRunNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
method: http.MethodDelete,
|
||||
pathFmt: "/tasks/%s/runs/%s",
|
||||
okPathArgs: okTaskRun,
|
||||
notFoundPathArgs: notFoundTaskRun,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
h.TaskService = tc.svc
|
||||
|
||||
okPath := fmt.Sprintf(tc.pathFmt, tc.okPathArgs...)
|
||||
t.Run("matching ID: "+tc.method+" "+okPath, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(tc.method, "http://task.example/api/v2"+okPath, strings.NewReader(tc.body))
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode < 200 || res.StatusCode > 299 {
|
||||
t.Errorf("expected OK, got %d", res.StatusCode)
|
||||
b, _ := ioutil.ReadAll(res.Body)
|
||||
t.Fatalf("body: %s", string(b))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mismatched ID", func(t *testing.T) {
|
||||
for _, nfa := range tc.notFoundPathArgs {
|
||||
path := fmt.Sprintf(tc.pathFmt, nfa...)
|
||||
t.Run(tc.method+" "+path, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(tc.method, "http://task.example/api/v2"+path, strings.NewReader(tc.body))
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("expected Not Found, got %d", res.StatusCode)
|
||||
b, _ := ioutil.ReadAll(res.Body)
|
||||
t.Fatalf("body: %s", string(b))
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue