Merge pull request #14485 from jsign/1.8_accept

fix: services/httpd: parse correctly Accept header with extra test cases
pull/17278/head
Jonathan A. Sternberg 2020-03-12 14:26:12 -05:00 committed by GitHub
commit 409de34abf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 297 additions and 136 deletions

80
services/httpd/accept.go Normal file
View File

@ -0,0 +1,80 @@
// This file is an adaptation of https://github.com/markusthoemmes/goautoneg.
// The copyright and license header are reproduced below.
//
// Copyright [yyyy] [name of copyright owner]
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// http://www.apache.org/licenses/LICENSE-2.0
package httpd
import (
"mime"
"sort"
"strconv"
"strings"
)
// accept is a structure to represent a clause in an HTTP Accept Header.
type accept struct {
Type, SubType string
Q float64
Params map[string]string
}
// parseAccept parses the given string as an Accept header as defined in
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.1.
// Some rules are only loosely applied and might not be as strict as defined in the RFC.
func parseAccept(headers []string) []accept {
var res []accept
for _, header := range headers {
parts := strings.Split(header, ",")
for _, part := range parts {
mt, params, err := mime.ParseMediaType(part)
if err != nil {
continue
}
accept := accept{
Q: 1.0, // "[...] The default value is q=1"
Params: params,
}
// A media-type is defined as
// "*/*" | ( type "/" "*" ) | ( type "/" subtype )
types := strings.Split(mt, "/")
switch {
// This case is not defined in the spec keep it to mimic the original code.
case len(types) == 1 && types[0] == "*":
accept.Type = "*"
accept.SubType = "*"
case len(types) == 2:
accept.Type = types[0]
accept.SubType = types[1]
default:
continue
}
if qVal, ok := params["q"]; ok {
// A parsing failure will set Q to 0.
accept.Q, _ = strconv.ParseFloat(qVal, 64)
delete(params, "q")
}
res = append(res, accept)
}
}
sort.SliceStable(res, func(i, j int) bool {
return res[i].Q > res[j].Q
})
return res
}

View File

@ -20,27 +20,57 @@ type ResponseWriter interface {
http.ResponseWriter
}
type formatter interface {
WriteResponse(w io.Writer, resp Response) error
}
type supportedContentType struct {
full string
acceptType string
acceptSubType string
formatter func(pretty bool) formatter
}
var (
csvFormatFactory = func(pretty bool) formatter { return &csvFormatter{statementID: -1} }
msgpackFormatFactory = func(pretty bool) formatter { return &msgpackFormatter{} }
jsonFormatFactory = func(pretty bool) formatter { return &jsonFormatter{Pretty: pretty} }
contentTypes = []supportedContentType{
{full: "application/json", acceptType: "application", acceptSubType: "json", formatter: jsonFormatFactory},
{full: "application/csv", acceptType: "application", acceptSubType: "csv", formatter: csvFormatFactory},
{full: "text/csv", acceptType: "text", acceptSubType: "csv", formatter: csvFormatFactory},
{full: "application/x-msgpack", acceptType: "application", acceptSubType: "x-msgpack", formatter: msgpackFormatFactory},
}
defaultContentType = contentTypes[0]
)
// NewResponseWriter creates a new ResponseWriter based on the Accept header
// in the request that wraps the ResponseWriter.
func NewResponseWriter(w http.ResponseWriter, r *http.Request) ResponseWriter {
pretty := r.URL.Query().Get("pretty") == "true"
rw := &responseWriter{ResponseWriter: w}
switch r.Header.Get("Accept") {
case "application/csv", "text/csv":
w.Header().Add("Content-Type", "text/csv")
rw.formatter = &csvFormatter{statementID: -1}
case "application/x-msgpack":
w.Header().Add("Content-Type", "application/x-msgpack")
rw.formatter = &msgpackFormatter{}
case "application/json":
fallthrough
default:
w.Header().Add("Content-Type", "application/json")
rw.formatter = &jsonFormatter{Pretty: pretty}
acceptHeaders := parseAccept(r.Header["Accept"])
for _, accept := range acceptHeaders {
for _, ct := range contentTypes {
if match(accept, ct) {
w.Header().Add("Content-Type", ct.full)
rw.formatter = ct.formatter(pretty)
return rw
}
}
}
w.Header().Add("Content-Type", defaultContentType.full)
rw.formatter = defaultContentType.formatter(pretty)
return rw
}
func match(ah accept, sct supportedContentType) bool {
return (ah.Type == "*" || ah.Type == sct.acceptType) &&
(ah.SubType == "*" || ah.SubType == sct.acceptSubType)
}
// WriteError is a convenience function for writing an error response to the ResponseWriter.
func WriteError(w ResponseWriter, err error) (int, error) {
return w.WriteResponse(Response{Err: err})

View File

@ -19,63 +19,79 @@ import (
)
func TestResponseWriter_CSV(t *testing.T) {
header := make(http.Header)
header.Set("Accept", "text/csv")
r := &http.Request{
Header: header,
URL: &url.URL{},
tableTest := []struct {
header string
}{
{header: "*/csv"},
{header: "text/*"},
{header: "text/csv"},
{header: "text/csv,application/json"},
{header: "text/csv;q=1,application/json"},
{header: "text/csv;q=0.9,application/json;q=0.8"},
{header: "application/json;q=0.8,text/csv;q=0.9"},
}
w := httptest.NewRecorder()
writer := httpd.NewResponseWriter(w, r)
n, err := writer.WriteResponse(httpd.Response{
Results: []*query.Result{
{
StatementID: 0,
Series: []*models.Row{
for _, testCase := range tableTest {
testCase := testCase
t.Run(testCase.header, func(t *testing.T) {
t.Parallel()
header := make(http.Header)
header.Set("Accept", testCase.header)
r := &http.Request{
Header: header,
URL: &url.URL{},
}
w := httptest.NewRecorder()
writer := httpd.NewResponseWriter(w, r)
n, err := writer.WriteResponse(httpd.Response{
Results: []*query.Result{
{
Name: "cpu",
Tags: map[string]string{
"host": "server01",
"region": "uswest",
},
Columns: []string{"time", "value"},
Values: [][]interface{}{
{time.Unix(0, 10), float64(2.5)},
{time.Unix(0, 20), int64(5)},
{time.Unix(0, 30), nil},
{time.Unix(0, 40), "foobar"},
{time.Unix(0, 50), true},
{time.Unix(0, 60), false},
{time.Unix(0, 70), uint64(math.MaxInt64 + 1)},
},
},
{
Name: "cpu",
Tags: map[string]string{
"host": "",
"region": "",
},
Columns: []string{"time", "value"},
Values: [][]interface{}{
{time.Unix(0, 10), float64(2.5)},
{time.Unix(0, 20), int64(5)},
{time.Unix(0, 30), nil},
{time.Unix(0, 40), "foobar"},
{time.Unix(0, 50), true},
{time.Unix(0, 60), false},
{time.Unix(0, 70), uint64(math.MaxInt64 + 1)},
StatementID: 0,
Series: []*models.Row{
{
Name: "cpu",
Tags: map[string]string{
"host": "server01",
"region": "uswest",
},
Columns: []string{"time", "value"},
Values: [][]interface{}{
{time.Unix(0, 10), float64(2.5)},
{time.Unix(0, 20), int64(5)},
{time.Unix(0, 30), nil},
{time.Unix(0, 40), "foobar"},
{time.Unix(0, 50), true},
{time.Unix(0, 60), false},
{time.Unix(0, 70), uint64(math.MaxInt64 + 1)},
},
},
{
Name: "cpu",
Tags: map[string]string{
"host": "",
"region": "",
},
Columns: []string{"time", "value"},
Values: [][]interface{}{
{time.Unix(0, 10), float64(2.5)},
{time.Unix(0, 20), int64(5)},
{time.Unix(0, 30), nil},
{time.Unix(0, 40), "foobar"},
{time.Unix(0, 50), true},
{time.Unix(0, 60), false},
{time.Unix(0, 70), uint64(math.MaxInt64 + 1)},
},
},
},
},
},
},
},
})
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
})
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if got, want := w.Body.String(), `name,tags,time,value
if got, want := w.Body.String(), `name,tags,time,value
cpu,"host=server01,region=uswest",10,2.5
cpu,"host=server01,region=uswest",20,5
cpu,"host=server01,region=uswest",30,
@ -91,99 +107,134 @@ cpu,,50,true
cpu,,60,false
cpu,,70,9223372036854775808
`; got != want {
t.Errorf("unexpected output:\n\ngot=%v\nwant=%s", got, want)
} else if got, want := n, len(want); got != want {
t.Errorf("unexpected output length: got=%d want=%d", got, want)
t.Errorf("unexpected output:\n\ngot=%v\nwant=%s", got, want)
} else if got, want := n, len(want); got != want {
t.Errorf("unexpected output length: got=%d want=%d", got, want)
}
})
}
}
func TestResponseWriter_MessagePack(t *testing.T) {
header := make(http.Header)
header.Set("Accept", "application/x-msgpack")
r := &http.Request{
Header: header,
URL: &url.URL{},
tableTest := []struct {
header string
}{
{header: "*/x-msgpack"},
{header: "application/x-msgpack"},
{header: "application/x-msgpack,application/json"},
{header: "application/x-msgpack;q=1,application/json"},
{header: "application/x-msgpack;q=0.9,application/json;q=0.8"},
{header: "application/json;q=0.8,application/x-msgpack;q=0.9"},
}
w := httptest.NewRecorder()
writer := httpd.NewResponseWriter(w, r)
_, err := writer.WriteResponse(httpd.Response{
Results: []*query.Result{
{
StatementID: 0,
Series: []*models.Row{
for _, testCase := range tableTest {
testCase := testCase
t.Run(testCase.header, func(t *testing.T) {
t.Parallel()
header := make(http.Header)
header.Set("Accept", testCase.header)
r := &http.Request{
Header: header,
URL: &url.URL{},
}
w := httptest.NewRecorder()
writer := httpd.NewResponseWriter(w, r)
_, err := writer.WriteResponse(httpd.Response{
Results: []*query.Result{
{
Name: "cpu",
Tags: map[string]string{
"host": "server01",
},
Columns: []string{"time", "value"},
Values: [][]interface{}{
{time.Unix(0, 10), float64(2.5)},
{time.Unix(0, 20), int64(5)},
{time.Unix(0, 30), nil},
{time.Unix(0, 40), "foobar"},
{time.Unix(0, 50), true},
{time.Unix(0, 60), false},
{time.Unix(0, 70), uint64(math.MaxInt64 + 1)},
StatementID: 0,
Series: []*models.Row{
{
Name: "cpu",
Tags: map[string]string{
"host": "server01",
},
Columns: []string{"time", "value"},
Values: [][]interface{}{
{time.Unix(0, 10), float64(2.5)},
{time.Unix(0, 20), int64(5)},
{time.Unix(0, 30), nil},
{time.Unix(0, 40), "foobar"},
{time.Unix(0, 50), true},
{time.Unix(0, 60), false},
{time.Unix(0, 70), uint64(math.MaxInt64 + 1)},
},
},
},
},
},
},
},
})
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
})
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
// The reader always reads times as time.Local so encode the expected response
// as JSON and insert it into the expected values.
values, err := json.Marshal([][]interface{}{
{time.Unix(0, 10).Local(), float64(2.5)},
{time.Unix(0, 20).Local(), int64(5)},
{time.Unix(0, 30).Local(), nil},
{time.Unix(0, 40).Local(), "foobar"},
{time.Unix(0, 50).Local(), true},
{time.Unix(0, 60).Local(), false},
{time.Unix(0, 70).Local(), uint64(math.MaxInt64 + 1)},
})
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
// The reader always reads times as time.Local so encode the expected response
// as JSON and insert it into the expected values.
values, err := json.Marshal([][]interface{}{
{time.Unix(0, 10).Local(), float64(2.5)},
{time.Unix(0, 20).Local(), int64(5)},
{time.Unix(0, 30).Local(), nil},
{time.Unix(0, 40).Local(), "foobar"},
{time.Unix(0, 50).Local(), true},
{time.Unix(0, 60).Local(), false},
{time.Unix(0, 70).Local(), uint64(math.MaxInt64 + 1)},
})
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
reader := msgp.NewReader(w.Body)
var buf bytes.Buffer
if _, err := reader.WriteToJSON(&buf); err != nil {
t.Fatalf("unexpected error: %s", err)
}
want := fmt.Sprintf(`{"results":[{"statement_id":0,"series":[{"name":"cpu","tags":{"host":"server01"},"columns":["time","value"],"values":%s}]}]}`, string(values))
if got := strings.TrimSpace(buf.String()); got != want {
t.Fatalf("unexpected output:\n\ngot=%v\nwant=%v", got, want)
reader := msgp.NewReader(w.Body)
var buf bytes.Buffer
if _, err := reader.WriteToJSON(&buf); err != nil {
t.Fatalf("unexpected error: %s", err)
}
want := fmt.Sprintf(`{"results":[{"statement_id":0,"series":[{"name":"cpu","tags":{"host":"server01"},"columns":["time","value"],"values":%s}]}]}`, string(values))
if got := strings.TrimSpace(buf.String()); got != want {
t.Fatalf("unexpected output:\n\ngot=%v\nwant=%v", got, want)
}
})
}
}
func TestResponseWriter_MessagePack_Error(t *testing.T) {
header := make(http.Header)
header.Set("Accept", "application/x-msgpack")
r := &http.Request{
Header: header,
URL: &url.URL{},
tableTest := []struct {
header string
}{
{header: "application/x-msgpack"},
{header: "application/x-msgpack,application/json"},
{header: "application/x-msgpack;q=1,application/json"},
{header: "application/x-msgpack;q=0.9,application/json;q=0.8"},
{header: "application/json;q=0.8,application/x-msgpack;q=0.9"},
}
w := httptest.NewRecorder()
writer := httpd.NewResponseWriter(w, r)
writer.WriteResponse(httpd.Response{
Err: fmt.Errorf("test error"),
})
for _, testCase := range tableTest {
testCase := testCase
t.Run(testCase.header, func(t *testing.T) {
t.Parallel()
header := make(http.Header)
header.Set("Accept", testCase.header)
r := &http.Request{
Header: header,
URL: &url.URL{},
}
w := httptest.NewRecorder()
reader := msgp.NewReader(w.Body)
var buf bytes.Buffer
if _, err := reader.WriteToJSON(&buf); err != nil {
t.Fatalf("unexpected error: %s", err)
}
want := fmt.Sprintf(`{"error":"test error"}`)
if have := strings.TrimSpace(buf.String()); have != want {
t.Fatalf("unexpected output: %s != %s", have, want)
writer := httpd.NewResponseWriter(w, r)
writer.WriteResponse(httpd.Response{
Err: fmt.Errorf("test error"),
})
reader := msgp.NewReader(w.Body)
var buf bytes.Buffer
if _, err := reader.WriteToJSON(&buf); err != nil {
t.Fatalf("unexpected error: %s", err)
}
want := fmt.Sprintf(`{"error":"test error"}`)
if have := strings.TrimSpace(buf.String()); have != want {
t.Fatalf("unexpected output: %s != %s", have, want)
}
})
}
}