Merge pull request #14485 from jsign/1.8_accept

fix: services/httpd: parse correctly Accept header with extra test cases
pull/17278/head
Jonathan A. Sternberg 2020-03-12 14:26:12 -05:00 committed by GitHub
commit 409de34abf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 297 additions and 136 deletions

80
services/httpd/accept.go Normal file
View File

@ -0,0 +1,80 @@
// This file is an adaptation of https://github.com/markusthoemmes/goautoneg.
// The copyright and license header are reproduced below.
//
// Copyright [yyyy] [name of copyright owner]
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// http://www.apache.org/licenses/LICENSE-2.0
package httpd
import (
"mime"
"sort"
"strconv"
"strings"
)
// accept is a structure to represent a clause in an HTTP Accept Header.
type accept struct {
Type, SubType string
Q float64
Params map[string]string
}
// parseAccept parses the given string as an Accept header as defined in
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.1.
// Some rules are only loosely applied and might not be as strict as defined in the RFC.
func parseAccept(headers []string) []accept {
var res []accept
for _, header := range headers {
parts := strings.Split(header, ",")
for _, part := range parts {
mt, params, err := mime.ParseMediaType(part)
if err != nil {
continue
}
accept := accept{
Q: 1.0, // "[...] The default value is q=1"
Params: params,
}
// A media-type is defined as
// "*/*" | ( type "/" "*" ) | ( type "/" subtype )
types := strings.Split(mt, "/")
switch {
// This case is not defined in the spec keep it to mimic the original code.
case len(types) == 1 && types[0] == "*":
accept.Type = "*"
accept.SubType = "*"
case len(types) == 2:
accept.Type = types[0]
accept.SubType = types[1]
default:
continue
}
if qVal, ok := params["q"]; ok {
// A parsing failure will set Q to 0.
accept.Q, _ = strconv.ParseFloat(qVal, 64)
delete(params, "q")
}
res = append(res, accept)
}
}
sort.SliceStable(res, func(i, j int) bool {
return res[i].Q > res[j].Q
})
return res
}

View File

@ -20,27 +20,57 @@ type ResponseWriter interface {
http.ResponseWriter http.ResponseWriter
} }
type formatter interface {
WriteResponse(w io.Writer, resp Response) error
}
type supportedContentType struct {
full string
acceptType string
acceptSubType string
formatter func(pretty bool) formatter
}
var (
csvFormatFactory = func(pretty bool) formatter { return &csvFormatter{statementID: -1} }
msgpackFormatFactory = func(pretty bool) formatter { return &msgpackFormatter{} }
jsonFormatFactory = func(pretty bool) formatter { return &jsonFormatter{Pretty: pretty} }
contentTypes = []supportedContentType{
{full: "application/json", acceptType: "application", acceptSubType: "json", formatter: jsonFormatFactory},
{full: "application/csv", acceptType: "application", acceptSubType: "csv", formatter: csvFormatFactory},
{full: "text/csv", acceptType: "text", acceptSubType: "csv", formatter: csvFormatFactory},
{full: "application/x-msgpack", acceptType: "application", acceptSubType: "x-msgpack", formatter: msgpackFormatFactory},
}
defaultContentType = contentTypes[0]
)
// NewResponseWriter creates a new ResponseWriter based on the Accept header // NewResponseWriter creates a new ResponseWriter based on the Accept header
// in the request that wraps the ResponseWriter. // in the request that wraps the ResponseWriter.
func NewResponseWriter(w http.ResponseWriter, r *http.Request) ResponseWriter { func NewResponseWriter(w http.ResponseWriter, r *http.Request) ResponseWriter {
pretty := r.URL.Query().Get("pretty") == "true" pretty := r.URL.Query().Get("pretty") == "true"
rw := &responseWriter{ResponseWriter: w} rw := &responseWriter{ResponseWriter: w}
switch r.Header.Get("Accept") {
case "application/csv", "text/csv": acceptHeaders := parseAccept(r.Header["Accept"])
w.Header().Add("Content-Type", "text/csv") for _, accept := range acceptHeaders {
rw.formatter = &csvFormatter{statementID: -1} for _, ct := range contentTypes {
case "application/x-msgpack": if match(accept, ct) {
w.Header().Add("Content-Type", "application/x-msgpack") w.Header().Add("Content-Type", ct.full)
rw.formatter = &msgpackFormatter{} rw.formatter = ct.formatter(pretty)
case "application/json": return rw
fallthrough }
default: }
w.Header().Add("Content-Type", "application/json")
rw.formatter = &jsonFormatter{Pretty: pretty}
} }
w.Header().Add("Content-Type", defaultContentType.full)
rw.formatter = defaultContentType.formatter(pretty)
return rw return rw
} }
func match(ah accept, sct supportedContentType) bool {
return (ah.Type == "*" || ah.Type == sct.acceptType) &&
(ah.SubType == "*" || ah.SubType == sct.acceptSubType)
}
// WriteError is a convenience function for writing an error response to the ResponseWriter. // WriteError is a convenience function for writing an error response to the ResponseWriter.
func WriteError(w ResponseWriter, err error) (int, error) { func WriteError(w ResponseWriter, err error) (int, error) {
return w.WriteResponse(Response{Err: err}) return w.WriteResponse(Response{Err: err})

View File

@ -19,63 +19,79 @@ import (
) )
func TestResponseWriter_CSV(t *testing.T) { func TestResponseWriter_CSV(t *testing.T) {
header := make(http.Header) tableTest := []struct {
header.Set("Accept", "text/csv") header string
r := &http.Request{ }{
Header: header, {header: "*/csv"},
URL: &url.URL{}, {header: "text/*"},
{header: "text/csv"},
{header: "text/csv,application/json"},
{header: "text/csv;q=1,application/json"},
{header: "text/csv;q=0.9,application/json;q=0.8"},
{header: "application/json;q=0.8,text/csv;q=0.9"},
} }
w := httptest.NewRecorder()
writer := httpd.NewResponseWriter(w, r) for _, testCase := range tableTest {
n, err := writer.WriteResponse(httpd.Response{ testCase := testCase
Results: []*query.Result{ t.Run(testCase.header, func(t *testing.T) {
{ t.Parallel()
StatementID: 0, header := make(http.Header)
Series: []*models.Row{ header.Set("Accept", testCase.header)
r := &http.Request{
Header: header,
URL: &url.URL{},
}
w := httptest.NewRecorder()
writer := httpd.NewResponseWriter(w, r)
n, err := writer.WriteResponse(httpd.Response{
Results: []*query.Result{
{ {
Name: "cpu", StatementID: 0,
Tags: map[string]string{ Series: []*models.Row{
"host": "server01", {
"region": "uswest", Name: "cpu",
}, Tags: map[string]string{
Columns: []string{"time", "value"}, "host": "server01",
Values: [][]interface{}{ "region": "uswest",
{time.Unix(0, 10), float64(2.5)}, },
{time.Unix(0, 20), int64(5)}, Columns: []string{"time", "value"},
{time.Unix(0, 30), nil}, Values: [][]interface{}{
{time.Unix(0, 40), "foobar"}, {time.Unix(0, 10), float64(2.5)},
{time.Unix(0, 50), true}, {time.Unix(0, 20), int64(5)},
{time.Unix(0, 60), false}, {time.Unix(0, 30), nil},
{time.Unix(0, 70), uint64(math.MaxInt64 + 1)}, {time.Unix(0, 40), "foobar"},
}, {time.Unix(0, 50), true},
}, {time.Unix(0, 60), false},
{ {time.Unix(0, 70), uint64(math.MaxInt64 + 1)},
Name: "cpu", },
Tags: map[string]string{ },
"host": "", {
"region": "", Name: "cpu",
}, Tags: map[string]string{
Columns: []string{"time", "value"}, "host": "",
Values: [][]interface{}{ "region": "",
{time.Unix(0, 10), float64(2.5)}, },
{time.Unix(0, 20), int64(5)}, Columns: []string{"time", "value"},
{time.Unix(0, 30), nil}, Values: [][]interface{}{
{time.Unix(0, 40), "foobar"}, {time.Unix(0, 10), float64(2.5)},
{time.Unix(0, 50), true}, {time.Unix(0, 20), int64(5)},
{time.Unix(0, 60), false}, {time.Unix(0, 30), nil},
{time.Unix(0, 70), uint64(math.MaxInt64 + 1)}, {time.Unix(0, 40), "foobar"},
{time.Unix(0, 50), true},
{time.Unix(0, 60), false},
{time.Unix(0, 70), uint64(math.MaxInt64 + 1)},
},
},
}, },
}, },
}, },
}, })
}, if err != nil {
}) t.Fatalf("unexpected error: %s", err)
if err != nil { }
t.Fatalf("unexpected error: %s", err)
}
if got, want := w.Body.String(), `name,tags,time,value if got, want := w.Body.String(), `name,tags,time,value
cpu,"host=server01,region=uswest",10,2.5 cpu,"host=server01,region=uswest",10,2.5
cpu,"host=server01,region=uswest",20,5 cpu,"host=server01,region=uswest",20,5
cpu,"host=server01,region=uswest",30, cpu,"host=server01,region=uswest",30,
@ -91,99 +107,134 @@ cpu,,50,true
cpu,,60,false cpu,,60,false
cpu,,70,9223372036854775808 cpu,,70,9223372036854775808
`; got != want { `; got != want {
t.Errorf("unexpected output:\n\ngot=%v\nwant=%s", got, want) t.Errorf("unexpected output:\n\ngot=%v\nwant=%s", got, want)
} else if got, want := n, len(want); got != want { } else if got, want := n, len(want); got != want {
t.Errorf("unexpected output length: got=%d want=%d", got, want) t.Errorf("unexpected output length: got=%d want=%d", got, want)
}
})
} }
} }
func TestResponseWriter_MessagePack(t *testing.T) { func TestResponseWriter_MessagePack(t *testing.T) {
header := make(http.Header) tableTest := []struct {
header.Set("Accept", "application/x-msgpack") header string
r := &http.Request{ }{
Header: header, {header: "*/x-msgpack"},
URL: &url.URL{}, {header: "application/x-msgpack"},
{header: "application/x-msgpack,application/json"},
{header: "application/x-msgpack;q=1,application/json"},
{header: "application/x-msgpack;q=0.9,application/json;q=0.8"},
{header: "application/json;q=0.8,application/x-msgpack;q=0.9"},
} }
w := httptest.NewRecorder()
writer := httpd.NewResponseWriter(w, r) for _, testCase := range tableTest {
_, err := writer.WriteResponse(httpd.Response{ testCase := testCase
Results: []*query.Result{ t.Run(testCase.header, func(t *testing.T) {
{ t.Parallel()
StatementID: 0, header := make(http.Header)
Series: []*models.Row{ header.Set("Accept", testCase.header)
r := &http.Request{
Header: header,
URL: &url.URL{},
}
w := httptest.NewRecorder()
writer := httpd.NewResponseWriter(w, r)
_, err := writer.WriteResponse(httpd.Response{
Results: []*query.Result{
{ {
Name: "cpu", StatementID: 0,
Tags: map[string]string{ Series: []*models.Row{
"host": "server01", {
}, Name: "cpu",
Columns: []string{"time", "value"}, Tags: map[string]string{
Values: [][]interface{}{ "host": "server01",
{time.Unix(0, 10), float64(2.5)}, },
{time.Unix(0, 20), int64(5)}, Columns: []string{"time", "value"},
{time.Unix(0, 30), nil}, Values: [][]interface{}{
{time.Unix(0, 40), "foobar"}, {time.Unix(0, 10), float64(2.5)},
{time.Unix(0, 50), true}, {time.Unix(0, 20), int64(5)},
{time.Unix(0, 60), false}, {time.Unix(0, 30), nil},
{time.Unix(0, 70), uint64(math.MaxInt64 + 1)}, {time.Unix(0, 40), "foobar"},
{time.Unix(0, 50), true},
{time.Unix(0, 60), false},
{time.Unix(0, 70), uint64(math.MaxInt64 + 1)},
},
},
}, },
}, },
}, },
}, })
}, if err != nil {
}) t.Fatalf("unexpected error: %s", err)
if err != nil { }
t.Fatalf("unexpected error: %s", err)
}
// The reader always reads times as time.Local so encode the expected response // The reader always reads times as time.Local so encode the expected response
// as JSON and insert it into the expected values. // as JSON and insert it into the expected values.
values, err := json.Marshal([][]interface{}{ values, err := json.Marshal([][]interface{}{
{time.Unix(0, 10).Local(), float64(2.5)}, {time.Unix(0, 10).Local(), float64(2.5)},
{time.Unix(0, 20).Local(), int64(5)}, {time.Unix(0, 20).Local(), int64(5)},
{time.Unix(0, 30).Local(), nil}, {time.Unix(0, 30).Local(), nil},
{time.Unix(0, 40).Local(), "foobar"}, {time.Unix(0, 40).Local(), "foobar"},
{time.Unix(0, 50).Local(), true}, {time.Unix(0, 50).Local(), true},
{time.Unix(0, 60).Local(), false}, {time.Unix(0, 60).Local(), false},
{time.Unix(0, 70).Local(), uint64(math.MaxInt64 + 1)}, {time.Unix(0, 70).Local(), uint64(math.MaxInt64 + 1)},
}) })
if err != nil { if err != nil {
t.Fatalf("unexpected error: %s", err) t.Fatalf("unexpected error: %s", err)
} }
reader := msgp.NewReader(w.Body) reader := msgp.NewReader(w.Body)
var buf bytes.Buffer var buf bytes.Buffer
if _, err := reader.WriteToJSON(&buf); err != nil { if _, err := reader.WriteToJSON(&buf); err != nil {
t.Fatalf("unexpected error: %s", err) t.Fatalf("unexpected error: %s", err)
} }
want := fmt.Sprintf(`{"results":[{"statement_id":0,"series":[{"name":"cpu","tags":{"host":"server01"},"columns":["time","value"],"values":%s}]}]}`, string(values)) want := fmt.Sprintf(`{"results":[{"statement_id":0,"series":[{"name":"cpu","tags":{"host":"server01"},"columns":["time","value"],"values":%s}]}]}`, string(values))
if got := strings.TrimSpace(buf.String()); got != want { if got := strings.TrimSpace(buf.String()); got != want {
t.Fatalf("unexpected output:\n\ngot=%v\nwant=%v", got, want) t.Fatalf("unexpected output:\n\ngot=%v\nwant=%v", got, want)
}
})
} }
} }
func TestResponseWriter_MessagePack_Error(t *testing.T) { func TestResponseWriter_MessagePack_Error(t *testing.T) {
header := make(http.Header) tableTest := []struct {
header.Set("Accept", "application/x-msgpack") header string
r := &http.Request{ }{
Header: header, {header: "application/x-msgpack"},
URL: &url.URL{}, {header: "application/x-msgpack,application/json"},
{header: "application/x-msgpack;q=1,application/json"},
{header: "application/x-msgpack;q=0.9,application/json;q=0.8"},
{header: "application/json;q=0.8,application/x-msgpack;q=0.9"},
} }
w := httptest.NewRecorder()
writer := httpd.NewResponseWriter(w, r) for _, testCase := range tableTest {
writer.WriteResponse(httpd.Response{ testCase := testCase
Err: fmt.Errorf("test error"), t.Run(testCase.header, func(t *testing.T) {
}) t.Parallel()
header := make(http.Header)
header.Set("Accept", testCase.header)
r := &http.Request{
Header: header,
URL: &url.URL{},
}
w := httptest.NewRecorder()
reader := msgp.NewReader(w.Body) writer := httpd.NewResponseWriter(w, r)
var buf bytes.Buffer writer.WriteResponse(httpd.Response{
if _, err := reader.WriteToJSON(&buf); err != nil { Err: fmt.Errorf("test error"),
t.Fatalf("unexpected error: %s", err) })
}
want := fmt.Sprintf(`{"error":"test error"}`) reader := msgp.NewReader(w.Body)
if have := strings.TrimSpace(buf.String()); have != want { var buf bytes.Buffer
t.Fatalf("unexpected output: %s != %s", have, want) if _, err := reader.WriteToJSON(&buf); err != nil {
t.Fatalf("unexpected error: %s", err)
}
want := fmt.Sprintf(`{"error":"test error"}`)
if have := strings.TrimSpace(buf.String()); have != want {
t.Fatalf("unexpected output: %s != %s", have, want)
}
})
} }
} }