influxdb/kit/transport/http/api_test.go

310 lines
6.9 KiB
Go

package http_test
import (
"bytes"
"encoding/gob"
"encoding/json"
"fmt"
"github.com/influxdata/influxdb/v2/kit/platform/errors"
"io"
"net/http"
"strings"
"testing"
kithttp "github.com/influxdata/influxdb/v2/kit/transport/http"
"github.com/influxdata/influxdb/v2/pkg/testttp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_API(t *testing.T) {
t.Run("Decode", func(t *testing.T) {
t.Run("valid foo no errors", func(t *testing.T) {
expected := validatFoo{
Foo: "valid",
Bar: 10,
}
t.Run("json", func(t *testing.T) {
var api *kithttp.API // shows it is safe to use a nil value
var out validatFoo
err := api.DecodeJSON(encodeJSON(t, expected), &out)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if expected != out {
t.Fatalf("unexpected vals:\n\texpected: %#v\n\tgot: %#v", expected, out)
}
})
t.Run("gob", func(t *testing.T) {
var out validatFoo
err := kithttp.NewAPI().DecodeGob(encodeGob(t, expected), &out)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if expected != out {
t.Fatalf("unexpected vals:\n\texpected: %#v\n\tgot: %#v", expected, out)
}
})
})
t.Run("unmarshals fine with ok error", func(t *testing.T) {
badFoo := validatFoo{
Foo: "",
Bar: 0,
}
t.Run("json", func(t *testing.T) {
var out validatFoo
err := kithttp.NewAPI().DecodeJSON(encodeJSON(t, badFoo), &out)
if err == nil {
t.Fatal("expected an err")
}
})
t.Run("gob", func(t *testing.T) {
var out validatFoo
err := kithttp.NewAPI().DecodeGob(encodeGob(t, badFoo), &out)
if err == nil {
t.Fatal("expected an err")
}
})
})
t.Run("unmarshal error", func(t *testing.T) {
invalidBody := []byte("[}-{]")
var out validatFoo
err := kithttp.NewAPI().DecodeJSON(bytes.NewReader(invalidBody), &out)
if err == nil {
t.Fatal("expected an error")
}
})
t.Run("unmarshal err fn wraps unmarshalling error", func(t *testing.T) {
t.Run("json", func(t *testing.T) {
invalidBody := []byte("[}-{]")
api := kithttp.NewAPI(kithttp.WithUnmarshalErrFn(unmarshalErrFn))
var out validatFoo
err := api.DecodeJSON(bytes.NewReader(invalidBody), &out)
expectInfluxdbError(t, errors.EInvalid, err)
})
t.Run("gob", func(t *testing.T) {
invalidBody := []byte("[}-{]")
api := kithttp.NewAPI(kithttp.WithUnmarshalErrFn(unmarshalErrFn))
var out validatFoo
err := api.DecodeGob(bytes.NewReader(invalidBody), &out)
expectInfluxdbError(t, errors.EInvalid, err)
})
})
t.Run("ok error fn wraps ok error", func(t *testing.T) {
badFoo := validatFoo{Foo: ""}
t.Run("json", func(t *testing.T) {
api := kithttp.NewAPI(kithttp.WithOKErrFn(okErrFn))
var out validatFoo
err := api.DecodeJSON(encodeJSON(t, badFoo), &out)
expectInfluxdbError(t, errors.EUnprocessableEntity, err)
})
t.Run("gob", func(t *testing.T) {
api := kithttp.NewAPI(kithttp.WithOKErrFn(okErrFn))
var out validatFoo
err := api.DecodeGob(encodeGob(t, badFoo), &out)
expectInfluxdbError(t, errors.EUnprocessableEntity, err)
})
})
})
t.Run("Respond", func(t *testing.T) {
tests := []int{
http.StatusCreated,
http.StatusOK,
http.StatusNoContent,
http.StatusForbidden,
http.StatusInternalServerError,
}
for _, statusCode := range tests {
fn := func(t *testing.T) {
responder := kithttp.NewAPI()
svr := func(w http.ResponseWriter, r *http.Request) {
responder.Respond(w, r, statusCode, map[string]string{
"foo": "bar",
})
}
expectedBodyFn := func(body *bytes.Buffer) {
var resp map[string]string
require.NoError(t, json.NewDecoder(body).Decode(&resp))
assert.Equal(t, "bar", resp["foo"])
}
if statusCode == http.StatusNoContent {
expectedBodyFn = func(body *bytes.Buffer) {
require.Zero(t, body.Len())
}
}
testttp.
Get(t, "/foo").
Do(http.HandlerFunc(svr)).
ExpectStatus(statusCode).
ExpectBody(expectedBodyFn)
}
t.Run(http.StatusText(statusCode), fn)
}
})
t.Run("Err", func(t *testing.T) {
tests := []struct {
statusCode int
expectedErr *errors.Error
}{
{
statusCode: http.StatusBadRequest,
expectedErr: &errors.Error{
Code: errors.EInvalid,
Msg: "failed to unmarshal",
},
},
{
statusCode: http.StatusForbidden,
expectedErr: &errors.Error{
Code: errors.EForbidden,
Msg: "forbidden",
},
},
{
statusCode: http.StatusUnprocessableEntity,
expectedErr: &errors.Error{
Code: errors.EUnprocessableEntity,
Msg: "failed validation",
},
},
{
statusCode: http.StatusInternalServerError,
expectedErr: &errors.Error{
Code: errors.EInternal,
Msg: "internal error",
},
},
}
for _, tt := range tests {
fn := func(t *testing.T) {
responder := kithttp.NewAPI()
svr := func(w http.ResponseWriter, r *http.Request) {
responder.Err(w, r, tt.expectedErr)
}
testttp.
Get(t, "/foo").
Do(http.HandlerFunc(svr)).
ExpectStatus(tt.statusCode).
ExpectBody(func(body *bytes.Buffer) {
var err kithttp.ErrBody
require.NoError(t, json.NewDecoder(body).Decode(&err))
assert.Equal(t, tt.expectedErr.Msg, err.Msg)
assert.Equal(t, tt.expectedErr.Code, err.Code)
})
}
t.Run(http.StatusText(tt.statusCode), fn)
}
})
}
// expectInfluxdbError fails the test unless err is a non-nil *errors.Error
// whose Code equals expectedCode.
func expectInfluxdbError(t *testing.T, expectedCode string, err error) {
	t.Helper()

	if err == nil {
		t.Fatal("expected an error")
	}

	influxErr, isInflux := err.(*errors.Error)
	switch {
	case !isInflux:
		t.Fatalf("expected an influxdb error; got=%#v", err)
	case influxErr.Code != expectedCode:
		t.Fatalf("unexpected error code; expected=%s got=%s", expectedCode, influxErr.Code)
	}
}
// encodeGob gob-encodes v into an in-memory buffer and hands the buffer back
// as a reader. An encoding failure aborts the calling test.
func encodeGob(t *testing.T, v interface{}) io.Reader {
	t.Helper()

	out := new(bytes.Buffer)
	enc := gob.NewEncoder(out)
	encErr := enc.Encode(v)
	if encErr != nil {
		t.Fatal(encErr)
	}
	return out
}
// encodeJSON JSON-encodes v (json.Encoder appends a trailing newline) into an
// in-memory buffer and returns it as a reader. An encoding failure aborts the
// calling test.
func encodeJSON(t *testing.T, v interface{}) io.Reader {
	t.Helper()

	out := new(bytes.Buffer)
	enc := json.NewEncoder(out)
	encErr := enc.Encode(v)
	if encErr != nil {
		t.Fatal(encErr)
	}
	return out
}
// okErrFn is the hook handed to kithttp.WithOKErrFn in these tests. It wraps
// err (presumably the failure returned by the decoded value's OK method) in an
// EUnprocessableEntity influxdb error.
func okErrFn(err error) error {
	wrapped := &errors.Error{
		Err:  err,
		Msg:  "failed validation",
		Code: errors.EUnprocessableEntity,
	}
	return wrapped
}
// unmarshalErrFn is the hook handed to kithttp.WithUnmarshalErrFn in these
// tests. It wraps a body-decoding failure in an EInvalid influxdb error whose
// message names the encoding that failed.
func unmarshalErrFn(encoding string, err error) error {
	msg := fmt.Sprintf("invalid %s request body", encoding)
	wrapped := &errors.Error{
		Err:  err,
		Msg:  msg,
		Code: errors.EInvalid,
	}
	return wrapped
}
// validatFoo is the decode target used by the API tests. Its OK method gives
// the API a validation step to run after decoding.
//
// The fields carry no struct tags on purpose: encoding/gob ignores tags
// entirely (it matches fields by Go name), and encoding/json falls back to the
// Go field names, which both sides of every round-trip in this file share.
type validatFoo struct {
	Foo string
	Bar int
}

// OK validates v and reports every violated rule at once: Foo must be
// non-empty and Bar must not be negative. It returns nil when v is valid.
func (v *validatFoo) OK() error {
	var errs multiErr
	if v.Foo == "" {
		errs = append(errs, "foo must be at least 1 char")
	}
	// Zero is accepted, so the rule is "non-negative", not "positive".
	if v.Bar < 0 {
		errs = append(errs, "bar must be a non-negative integer")
	}
	return errs.toError()
}

// multiErr collects validation messages and reports them as a single error.
type multiErr []string

// toError returns m as an error, or an untyped nil when m is empty. Returning
// the literal nil avoids handing callers a non-nil interface that wraps an
// empty multiErr.
func (m multiErr) toError() error {
	if len(m) > 0 {
		return m
	}
	return nil
}

// Error joins the collected messages with "; ".
func (m multiErr) Error() string {
	return strings.Join(m, "; ")
}