influxdb/pkg/httpc/req.go

199 lines
4.5 KiB
Go

package httpc
import (
	"bufio"
	"compress/gzip"
	"context"
	"encoding/gob"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"strings"

	"github.com/influxdata/influxdb/v2"
)
// Canonical HTTP header names used by requests and response decoding in this package.
const (
	// headerContentEncoding names the header that describes how a body is compressed.
	headerContentEncoding = "Content-Encoding"
	// headerContentType names the header that describes a body's media type.
	headerContentType = "Content-Type"
)
// Req is a request type. It is configured through chained method calls
// (Header, QueryParams, Decode, StatusFn, ...) and executed with Do.
type Req struct {
// client sends the request; do calls client.Do with the context-bound request.
client doer
// req is the underlying request that the builder methods mutate in place.
req *http.Request
// authFn is applied to req by Do immediately before the request is sent.
authFn func(*http.Request)
// decodeFn, when set, runs last on the response; do wraps its error as an
// influxdb.EInvalid error.
decodeFn func(*http.Response) error
// respFn, when set, runs first to inspect the raw response.
respFn func(*http.Response) error
// statusFn, when set, runs after respFn to validate the response status.
statusFn func(*http.Response) error
// err is a deferred construction error: while it is non-nil, most builder
// methods are no-ops and Do returns it without sending the request.
err error
}
// Accept adds an Accept header asking the server to respond with contentType.
// It delegates to Header, so it is a no-op once the request has an error.
func (r *Req) Accept(contentType string) *Req {
	const acceptHeader = "Accept"
	return r.Header(acceptHeader, contentType)
}
// ContentType sets the Content-Type header to the provided content type on the request.
// It delegates to Header, so it is a no-op once the request has an error.
func (r *Req) ContentType(contentType string) *Req {
	// Use the package constant rather than repeating the literal header name.
	return r.Header(headerContentType, contentType)
}
// Decode registers fn as the response decoder. The decoder runs only after the
// response and status functions have run, and it is skipped when either of them
// returns an error. Calling Decode again replaces the previous decoder. When the
// request already carries an error, fn is not recorded.
func (r *Req) Decode(fn func(resp *http.Response) error) *Req {
	if r.err == nil {
		r.decodeFn = fn
	}
	return r
}
// DecodeGob sets the decoding functionality to decode gob for the request.
// The response body is first wrapped by decodeReader according to its
// Content-Encoding header. As with gob.Decoder.Decode, v should be a pointer
// to a value of the expected type.
func (r *Req) DecodeGob(v interface{}) *Req {
	return r.Decode(func(resp *http.Response) error {
		// Named body (not r) so the closure does not shadow the *Req receiver.
		body := decodeReader(resp.Body, resp.Header)
		return gob.NewDecoder(body).Decode(v)
	})
}
// DecodeJSON sets the decoding functionality to decode json for the request.
// The response body is first wrapped by decodeReader according to its
// Content-Encoding header. As with json.Decoder.Decode, v should be a non-nil
// pointer.
func (r *Req) DecodeJSON(v interface{}) *Req {
	return r.Decode(func(resp *http.Response) error {
		// Named body (not r) so the closure does not shadow the *Req receiver.
		body := decodeReader(resp.Body, resp.Header)
		return json.NewDecoder(body).Decode(v)
	})
}
// Header appends value v to header k on the underlying request; existing values
// for k are kept. Nothing is added when the request already carries an error.
func (r *Req) Header(k, v string) *Req {
	if r.err == nil {
		r.req.Header.Add(k, v)
	}
	return r
}
// Headers adds every value of every header in m to the request, one Header call
// per value. Entries with an empty header name are skipped. Nothing is added
// when the request already carries an error.
func (r *Req) Headers(m map[string][]string) *Req {
	if r.err != nil {
		return r
	}
	for name, values := range m {
		if name != "" {
			for i := range values {
				r.Header(name, values[i])
			}
		}
	}
	return r
}
// QueryParams appends each [key, value] pair to the request URL's query string.
// Existing parameters are preserved, and the whole query is re-encoded via
// url.Values.Encode. The URL is left untouched when the request already carries
// an error or when no pairs are given.
func (r *Req) QueryParams(pairs ...[2]string) *Req {
	if r.err != nil {
		return r
	}
	if len(pairs) == 0 {
		return r
	}
	query := r.req.URL.Query()
	for i := range pairs {
		key, value := pairs[i][0], pairs[i][1]
		query.Add(key, value)
	}
	r.req.URL.RawQuery = query.Encode()
	return r
}
// RespFn provides a means to inspect the entire http response. This function runs first
// before the status and decode functions are called. Like the other builder methods,
// it is a no-op when the request already carries an error; Do returns that error
// without sending anything.
func (r *Req) RespFn(fn func(*http.Response) error) *Req {
	if r.err != nil {
		return r
	}
	r.respFn = fn
	return r
}
// StatusFn sets a status check function. This runs after the resp func
// but before the decode fn; if it returns an error, decoding is skipped and
// that error is returned from Do. Like the other builder methods, it is a
// no-op when the request already carries an error.
func (r *Req) StatusFn(fn func(*http.Response) error) *Req {
	if r.err != nil {
		return r
	}
	r.statusFn = fn
	return r
}
// Do makes the HTTP request. Any error encountered while building the Req is
// returned here first, in place of the call, which makes it safe to call Do at
// any time. Otherwise the auth function (if any) is applied to the request and
// the request is sent with ctx.
func (r *Req) Do(ctx context.Context) error {
	if r.err != nil {
		return r.err
	}
	// Guard against a Req built without an auth function: calling a nil func
	// would panic, so treat it as "no authentication".
	if r.authFn != nil {
		r.authFn(r.req)
	}
	// TODO(@jsteenb2): wrap do with retry/backoff policy.
	return r.do(ctx)
}
// do sends the request bound to ctx and then runs the configured hooks in order:
// respFn, statusFn, decodeFn. Unset hooks are skipped and the first failing hook
// stops the chain. A respFn or statusFn error is returned as is; a decodeFn error
// is wrapped as an influxdb.EInvalid error. The response body is always drained
// and closed before returning.
func (r *Req) do(ctx context.Context) error {
	resp, err := r.client.Do(r.req.WithContext(ctx))
	if err != nil {
		return err
	}
	defer func() {
		// Read the body to EOF before closing so the transport can reuse the
		// connection. Errors are deliberately ignored: this is best-effort cleanup.
		io.Copy(ioutil.Discard, resp.Body)
		resp.Body.Close()
	}()

	// run invokes hook on the response, treating an unset hook as success.
	run := func(hook func(*http.Response) error) error {
		if hook == nil {
			return nil
		}
		return hook(resp)
	}

	if hookErr := run(r.respFn); hookErr != nil {
		return hookErr
	}
	if hookErr := run(r.statusFn); hookErr != nil {
		return hookErr
	}
	if decodeErr := run(r.decodeFn); decodeErr != nil {
		return &influxdb.Error{
			Code: influxdb.EInvalid,
			Err:  decodeErr,
		}
	}
	return nil
}
// StatusIn returns a status check, suitable for StatusFn, that accepts a response
// whose status code equals code or any of rest. Any other status yields an error
// containing both the response's status line and its numeric status code.
func StatusIn(code int, rest ...int) func(*http.Response) error {
	return func(resp *http.Response) error {
		if resp.StatusCode == code {
			return nil
		}
		// Scan rest directly instead of ranging over append(rest, code). That append
		// can write into the caller's backing array when rest has spare capacity, and
		// it races when the returned check runs concurrently.
		for _, allowed := range rest {
			if allowed == resp.StatusCode {
				return nil
			}
		}
		return fmt.Errorf("received unexpected status: %s %d", resp.Status, resp.StatusCode)
	}
}
// encodingReaders maps a lower-cased Content-Encoding token to a function that
// wraps a response body in the matching decoding reader.
var encodingReaders = map[string]func(io.Reader) io.Reader{
	"gzip": func(r io.Reader) io.Reader {
		// gzip.NewReader consumes the gzip header, and may buffer further input,
		// even when it fails. Falling back to r after such a failure would hand the
		// decoder a truncated body. Peek at the gzip magic bytes first so a body
		// labelled gzip but not actually compressed passes through intact.
		br := bufio.NewReader(r)
		if magic, err := br.Peek(2); err != nil || magic[0] != 0x1f || magic[1] != 0x8b {
			return br
		}
		gr, err := gzip.NewReader(br)
		if err != nil {
			// Corrupt gzip header: fall back to the stream, best effort, and
			// let the downstream decoder report the failure.
			return br
		}
		return gr
	},
}

// decodeReader wraps r according to the Content-Encoding header in headers.
// Content codings are case-insensitive (RFC 7231 §3.1.2.1), and "x-gzip" is
// treated as "gzip" (RFC 7230 §4.2.3). An absent or unknown encoding returns r
// unchanged.
func decodeReader(r io.Reader, headers http.Header) io.Reader {
	contentEncoding := strings.ToLower(strings.TrimSpace(headers.Get(headerContentEncoding)))
	if contentEncoding == "x-gzip" {
		contentEncoding = "gzip"
	}
	if wrap, ok := encodingReaders[contentEncoding]; ok {
		return wrap(r)
	}
	return r
}