310 lines
6.9 KiB
Go
310 lines
6.9 KiB
Go
package http_test
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/gob"
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/influxdata/influxdb/v2/kit/platform/errors"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"testing"
|
|
|
|
kithttp "github.com/influxdata/influxdb/v2/kit/transport/http"
|
|
"github.com/influxdata/influxdb/v2/pkg/testttp"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func Test_API(t *testing.T) {
|
|
t.Run("Decode", func(t *testing.T) {
|
|
t.Run("valid foo no errors", func(t *testing.T) {
|
|
expected := validatFoo{
|
|
Foo: "valid",
|
|
Bar: 10,
|
|
}
|
|
|
|
t.Run("json", func(t *testing.T) {
|
|
var api *kithttp.API // shows it is safe to use a nil value
|
|
|
|
var out validatFoo
|
|
err := api.DecodeJSON(encodeJSON(t, expected), &out)
|
|
if err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
|
|
if expected != out {
|
|
t.Fatalf("unexpected vals:\n\texpected: %#v\n\tgot: %#v", expected, out)
|
|
}
|
|
})
|
|
|
|
t.Run("gob", func(t *testing.T) {
|
|
var out validatFoo
|
|
err := kithttp.NewAPI().DecodeGob(encodeGob(t, expected), &out)
|
|
if err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
|
|
if expected != out {
|
|
t.Fatalf("unexpected vals:\n\texpected: %#v\n\tgot: %#v", expected, out)
|
|
}
|
|
})
|
|
})
|
|
|
|
t.Run("unmarshals fine with ok error", func(t *testing.T) {
|
|
badFoo := validatFoo{
|
|
Foo: "",
|
|
Bar: 0,
|
|
}
|
|
|
|
t.Run("json", func(t *testing.T) {
|
|
var out validatFoo
|
|
err := kithttp.NewAPI().DecodeJSON(encodeJSON(t, badFoo), &out)
|
|
if err == nil {
|
|
t.Fatal("expected an err")
|
|
}
|
|
})
|
|
|
|
t.Run("gob", func(t *testing.T) {
|
|
var out validatFoo
|
|
err := kithttp.NewAPI().DecodeGob(encodeGob(t, badFoo), &out)
|
|
if err == nil {
|
|
t.Fatal("expected an err")
|
|
}
|
|
})
|
|
})
|
|
|
|
t.Run("unmarshal error", func(t *testing.T) {
|
|
invalidBody := []byte("[}-{]")
|
|
|
|
var out validatFoo
|
|
err := kithttp.NewAPI().DecodeJSON(bytes.NewReader(invalidBody), &out)
|
|
if err == nil {
|
|
t.Fatal("expected an error")
|
|
}
|
|
})
|
|
|
|
t.Run("unmarshal err fn wraps unmarshalling error", func(t *testing.T) {
|
|
t.Run("json", func(t *testing.T) {
|
|
invalidBody := []byte("[}-{]")
|
|
|
|
api := kithttp.NewAPI(kithttp.WithUnmarshalErrFn(unmarshalErrFn))
|
|
|
|
var out validatFoo
|
|
err := api.DecodeJSON(bytes.NewReader(invalidBody), &out)
|
|
expectInfluxdbError(t, errors.EInvalid, err)
|
|
})
|
|
|
|
t.Run("gob", func(t *testing.T) {
|
|
invalidBody := []byte("[}-{]")
|
|
|
|
api := kithttp.NewAPI(kithttp.WithUnmarshalErrFn(unmarshalErrFn))
|
|
|
|
var out validatFoo
|
|
err := api.DecodeGob(bytes.NewReader(invalidBody), &out)
|
|
expectInfluxdbError(t, errors.EInvalid, err)
|
|
})
|
|
})
|
|
|
|
t.Run("ok error fn wraps ok error", func(t *testing.T) {
|
|
badFoo := validatFoo{Foo: ""}
|
|
|
|
t.Run("json", func(t *testing.T) {
|
|
api := kithttp.NewAPI(kithttp.WithOKErrFn(okErrFn))
|
|
|
|
var out validatFoo
|
|
err := api.DecodeJSON(encodeJSON(t, badFoo), &out)
|
|
expectInfluxdbError(t, errors.EUnprocessableEntity, err)
|
|
})
|
|
|
|
t.Run("gob", func(t *testing.T) {
|
|
api := kithttp.NewAPI(kithttp.WithOKErrFn(okErrFn))
|
|
|
|
var out validatFoo
|
|
err := api.DecodeGob(encodeGob(t, badFoo), &out)
|
|
expectInfluxdbError(t, errors.EUnprocessableEntity, err)
|
|
})
|
|
})
|
|
})
|
|
|
|
t.Run("Respond", func(t *testing.T) {
|
|
tests := []int{
|
|
http.StatusCreated,
|
|
http.StatusOK,
|
|
http.StatusNoContent,
|
|
http.StatusForbidden,
|
|
http.StatusInternalServerError,
|
|
}
|
|
|
|
for _, statusCode := range tests {
|
|
fn := func(t *testing.T) {
|
|
responder := kithttp.NewAPI()
|
|
|
|
svr := func(w http.ResponseWriter, r *http.Request) {
|
|
responder.Respond(w, r, statusCode, map[string]string{
|
|
"foo": "bar",
|
|
})
|
|
}
|
|
|
|
expectedBodyFn := func(body *bytes.Buffer) {
|
|
var resp map[string]string
|
|
require.NoError(t, json.NewDecoder(body).Decode(&resp))
|
|
assert.Equal(t, "bar", resp["foo"])
|
|
}
|
|
if statusCode == http.StatusNoContent {
|
|
expectedBodyFn = func(body *bytes.Buffer) {
|
|
require.Zero(t, body.Len())
|
|
}
|
|
}
|
|
|
|
testttp.
|
|
Get(t, "/foo").
|
|
Do(http.HandlerFunc(svr)).
|
|
ExpectStatus(statusCode).
|
|
ExpectBody(expectedBodyFn)
|
|
}
|
|
t.Run(http.StatusText(statusCode), fn)
|
|
}
|
|
})
|
|
|
|
t.Run("Err", func(t *testing.T) {
|
|
tests := []struct {
|
|
statusCode int
|
|
expectedErr *errors.Error
|
|
}{
|
|
{
|
|
statusCode: http.StatusBadRequest,
|
|
expectedErr: &errors.Error{
|
|
Code: errors.EInvalid,
|
|
Msg: "failed to unmarshal",
|
|
},
|
|
},
|
|
{
|
|
statusCode: http.StatusForbidden,
|
|
expectedErr: &errors.Error{
|
|
Code: errors.EForbidden,
|
|
Msg: "forbidden",
|
|
},
|
|
},
|
|
{
|
|
statusCode: http.StatusUnprocessableEntity,
|
|
expectedErr: &errors.Error{
|
|
Code: errors.EUnprocessableEntity,
|
|
Msg: "failed validation",
|
|
},
|
|
},
|
|
{
|
|
statusCode: http.StatusInternalServerError,
|
|
expectedErr: &errors.Error{
|
|
Code: errors.EInternal,
|
|
Msg: "internal error",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
fn := func(t *testing.T) {
|
|
responder := kithttp.NewAPI()
|
|
|
|
svr := func(w http.ResponseWriter, r *http.Request) {
|
|
responder.Err(w, r, tt.expectedErr)
|
|
}
|
|
|
|
testttp.
|
|
Get(t, "/foo").
|
|
Do(http.HandlerFunc(svr)).
|
|
ExpectStatus(tt.statusCode).
|
|
ExpectBody(func(body *bytes.Buffer) {
|
|
var err kithttp.ErrBody
|
|
require.NoError(t, json.NewDecoder(body).Decode(&err))
|
|
assert.Equal(t, tt.expectedErr.Msg, err.Msg)
|
|
assert.Equal(t, tt.expectedErr.Code, err.Code)
|
|
})
|
|
}
|
|
t.Run(http.StatusText(tt.statusCode), fn)
|
|
}
|
|
})
|
|
}
|
|
|
|
func expectInfluxdbError(t *testing.T, expectedCode string, err error) {
|
|
t.Helper()
|
|
|
|
if err == nil {
|
|
t.Fatal("expected an error")
|
|
}
|
|
iErr, ok := err.(*errors.Error)
|
|
if !ok {
|
|
t.Fatalf("expected an influxdb error; got=%#v", err)
|
|
}
|
|
|
|
if got := iErr.Code; expectedCode != got {
|
|
t.Fatalf("unexpected error code; expected=%s got=%s", expectedCode, got)
|
|
}
|
|
}
|
|
|
|
func encodeGob(t *testing.T, v interface{}) io.Reader {
|
|
t.Helper()
|
|
|
|
var buf bytes.Buffer
|
|
if err := gob.NewEncoder(&buf).Encode(v); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return &buf
|
|
}
|
|
|
|
func encodeJSON(t *testing.T, v interface{}) io.Reader {
|
|
t.Helper()
|
|
|
|
var buf bytes.Buffer
|
|
if err := json.NewEncoder(&buf).Encode(v); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return &buf
|
|
}
|
|
|
|
func okErrFn(err error) error {
|
|
return &errors.Error{
|
|
Code: errors.EUnprocessableEntity,
|
|
Msg: "failed validation",
|
|
Err: err,
|
|
}
|
|
}
|
|
|
|
func unmarshalErrFn(encoding string, err error) error {
|
|
return &errors.Error{
|
|
Code: errors.EInvalid,
|
|
Msg: fmt.Sprintf("invalid %s request body", encoding),
|
|
Err: err,
|
|
}
|
|
}
|
|
|
|
type validatFoo struct {
|
|
Foo string `gob:"foo"`
|
|
Bar int `gob:"bar"`
|
|
}
|
|
|
|
func (v *validatFoo) OK() error {
|
|
var errs multiErr
|
|
if v.Foo == "" {
|
|
errs = append(errs, "foo must be at least 1 char")
|
|
}
|
|
if v.Bar < 0 {
|
|
errs = append(errs, "bar must be a positive real number")
|
|
}
|
|
return errs.toError()
|
|
}
|
|
|
|
type multiErr []string
|
|
|
|
func (m multiErr) toError() error {
|
|
if len(m) > 0 {
|
|
return m
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m multiErr) Error() string {
|
|
return strings.Join(m, "; ")
|
|
}
|