diff --git a/services/httpd/response_writer.go b/services/httpd/response_writer.go index 55c124b15d..36191e8c93 100644 --- a/services/httpd/response_writer.go +++ b/services/httpd/response_writer.go @@ -28,15 +28,15 @@ func NewResponseWriter(w http.ResponseWriter, r *http.Request) ResponseWriter { switch r.Header.Get("Accept") { case "application/csv", "text/csv": w.Header().Add("Content-Type", "text/csv") - rw.formatter = &csvFormatter{statementID: -1, Writer: w} + rw.formatter = &csvFormatter{statementID: -1} case "application/x-msgpack": w.Header().Add("Content-Type", "application/x-msgpack") - rw.formatter = &msgpackFormatter{Writer: w} + rw.formatter = &msgpackFormatter{} case "application/json": fallthrough default: w.Header().Add("Content-Type", "application/json") - rw.formatter = &jsonFormatter{Pretty: pretty, Writer: w} + rw.formatter = &jsonFormatter{Pretty: pretty} } return rw } @@ -49,14 +49,27 @@ func WriteError(w ResponseWriter, err error) (int, error) { // responseWriter is an implementation of ResponseWriter. type responseWriter struct { formatter interface { - WriteResponse(resp Response) (int, error) + WriteResponse(w io.Writer, resp Response) error } http.ResponseWriter } +type bytesCountWriter struct { + w io.Writer + n int +} + +func (w *bytesCountWriter) Write(data []byte) (int, error) { + n, err := w.w.Write(data) + w.n += n + return n, err +} + // WriteResponse writes the response using the formatter. func (w *responseWriter) WriteResponse(resp Response) (int, error) { - return w.formatter.WriteResponse(resp) + writer := bytesCountWriter{w: w.ResponseWriter} + err := w.formatter.WriteResponse(&writer, resp) + return writer.n, err } // Flush flushes the ResponseWriter if it has a Flush() method. @@ -76,74 +89,69 @@ func (w *responseWriter) CloseNotify() <-chan bool { } type jsonFormatter struct { - io.Writer Pretty bool } -func (w *jsonFormatter) WriteResponse(resp Response) (n int, err error) { +func (f *jsonFormatter) WriteResponse(w io.Writer, resp Response) (err error) { var b []byte - if w.Pretty { + if f.Pretty { b, err = json.MarshalIndent(resp, "", " ") } else { b, err = json.Marshal(resp) } if err != nil { - n, err = io.WriteString(w, err.Error()) + _, err = io.WriteString(w, err.Error()) } else { - n, err = w.Write(b) + _, err = w.Write(b) } w.Write([]byte("\n")) - n++ - return n, err + return err } type csvFormatter struct { - io.Writer statementID int columns []string } -func (w *csvFormatter) WriteResponse(resp Response) (n int, err error) { +func (f *csvFormatter) WriteResponse(w io.Writer, resp Response) (err error) { csv := csv.NewWriter(w) if resp.Err != nil { csv.Write([]string{"error"}) csv.Write([]string{resp.Err.Error()}) csv.Flush() - return n, csv.Error() + return csv.Error() } for _, result := range resp.Results { - if result.StatementID != w.statementID { + if result.StatementID != f.statementID { // If there are no series in the result, skip past this result. if len(result.Series) == 0 { continue } // Set the statement id and print out a newline if this is not the first statement. - if w.statementID >= 0 { + if f.statementID >= 0 { // Flush the csv writer and write a newline. csv.Flush() if err := csv.Error(); err != nil { - return n, err + return err } - out, err := io.WriteString(w, "\n") - if err != nil { - return n, err + if _, err := io.WriteString(w, "\n"); err != nil { + return err } - n += out } - w.statementID = result.StatementID + f.statementID = result.StatementID // Print out the column headers from the first series. - w.columns = make([]string, 2+len(result.Series[0].Columns)) - w.columns[0] = "name" - w.columns[1] = "tags" - copy(w.columns[2:], result.Series[0].Columns) - if err := csv.Write(w.columns); err != nil { - return n, err + f.columns = make([]string, 2+len(result.Series[0].Columns)) + f.columns[0] = "name" + f.columns[1] = "tags" + copy(f.columns[2:], result.Series[0].Columns) + if err := csv.Write(f.columns); err != nil { + return err } } @@ -152,83 +160,79 @@ func (w *csvFormatter) WriteResponse(resp Response) (n int, err error) { // The columns have changed. Print a newline and reprint the header. csv.Flush() if err := csv.Error(); err != nil { - return n, err + return err } - out, err := io.WriteString(w, "\n") - if err != nil { - return n, err + if _, err := io.WriteString(w, "\n"); err != nil { + return err } - n += out - w.columns = make([]string, 2+len(row.Columns)) - w.columns[0] = "name" - w.columns[1] = "tags" - copy(w.columns[2:], row.Columns) - if err := csv.Write(w.columns); err != nil { - return n, err + f.columns = make([]string, 2+len(row.Columns)) + f.columns[0] = "name" + f.columns[1] = "tags" + copy(f.columns[2:], row.Columns) + if err := csv.Write(f.columns); err != nil { + return err } } - w.columns[0] = row.Name + f.columns[0] = row.Name if len(row.Tags) > 0 { - w.columns[1] = string(models.NewTags(row.Tags).HashKey()[1:]) + f.columns[1] = string(models.NewTags(row.Tags).HashKey()[1:]) } else { - w.columns[1] = "" + f.columns[1] = "" } for _, values := range row.Values { for i, value := range values { if value == nil { - w.columns[i+2] = "" + f.columns[i+2] = "" continue } switch v := value.(type) { case float64: - w.columns[i+2] = strconv.FormatFloat(v, 'f', -1, 64) + f.columns[i+2] = strconv.FormatFloat(v, 'f', -1, 64) case int64: - w.columns[i+2] = strconv.FormatInt(v, 10) + f.columns[i+2] = strconv.FormatInt(v, 10) case uint64: - w.columns[i+2] = strconv.FormatUint(v, 10) + f.columns[i+2] = strconv.FormatUint(v, 10) case string: - w.columns[i+2] = v + f.columns[i+2] = v case bool: if v { - w.columns[i+2] = "true" + f.columns[i+2] = "true" } else { - w.columns[i+2] = "false" + f.columns[i+2] = "false" } case time.Time: - w.columns[i+2] = strconv.FormatInt(v.UnixNano(), 10) + f.columns[i+2] = strconv.FormatInt(v.UnixNano(), 10) case *float64, *int64, *string, *bool: - w.columns[i+2] = "" + f.columns[i+2] = "" } } - csv.Write(w.columns) + csv.Write(f.columns) } } } csv.Flush() - return n, csv.Error() + return csv.Error() } -type msgpackFormatter struct { - io.Writer -} +type msgpackFormatter struct{} func (f *msgpackFormatter) ContentType() string { return "application/x-msgpack" } -func (f *msgpackFormatter) WriteResponse(resp Response) (n int, err error) { - enc := msgp.NewWriter(f.Writer) +func (f *msgpackFormatter) WriteResponse(w io.Writer, resp Response) (err error) { + enc := msgp.NewWriter(w) defer enc.Flush() enc.WriteMapHeader(1) if resp.Err != nil { enc.WriteString("error") enc.WriteString(resp.Err.Error()) - return 0, nil + return nil } else { enc.WriteString("results") enc.WriteArrayHeader(uint32(len(resp.Results))) @@ -311,7 +315,7 @@ func (f *msgpackFormatter) WriteResponse(resp Response) (n int, err error) { } } } - return 0, nil + return nil } func stringsEqual(a, b []string) bool { diff --git a/services/httpd/response_writer_test.go b/services/httpd/response_writer_test.go index eec1d20f26..ebdc93bb6d 100644 --- a/services/httpd/response_writer_test.go +++ b/services/httpd/response_writer_test.go @@ -28,7 +28,7 @@ func TestResponseWriter_CSV(t *testing.T) { w := httptest.NewRecorder() writer := httpd.NewResponseWriter(w, r) - writer.WriteResponse(httpd.Response{ + n, err := writer.WriteResponse(httpd.Response{ Results: []*query.Result{ { StatementID: 0, @@ -54,6 +54,9 @@ func TestResponseWriter_CSV(t *testing.T) { }, }, }) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } if got, want := w.Body.String(), `name,tags,time,value cpu,"host=server01,region=uswest",10,2.5 @@ -65,6 +68,8 @@ cpu,"host=server01,region=uswest",60,false cpu,"host=server01,region=uswest",70,9223372036854775808 `; got != want { t.Errorf("unexpected output:\n\ngot=%v\nwant=%s", got, want) + } else if got, want := n, len(want); got != want { + t.Errorf("unexpected output length: got=%d want=%d", got, want) } } @@ -78,7 +83,7 @@ func TestResponseWriter_MessagePack(t *testing.T) { w := httptest.NewRecorder() writer := httpd.NewResponseWriter(w, r) - writer.WriteResponse(httpd.Response{ + _, err := writer.WriteResponse(httpd.Response{ Results: []*query.Result{ { StatementID: 0, @@ -103,6 +108,9 @@ func TestResponseWriter_MessagePack(t *testing.T) { }, }, }) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } // The reader always reads times as time.Local so encode the expected response // as JSON and insert it into the expected values. @@ -125,8 +133,8 @@ func TestResponseWriter_MessagePack(t *testing.T) { t.Fatalf("unexpected error: %s", err) } want := fmt.Sprintf(`{"results":[{"statement_id":0,"series":[{"name":"cpu","tags":{"host":"server01"},"columns":["time","value"],"values":%s}]}]}`, string(values)) - if have := strings.TrimSpace(buf.String()); have != want { - t.Fatalf("unexpected output: %s != %s", have, want) + if got := strings.TrimSpace(buf.String()); got != want { + t.Fatalf("unexpected output:\n\ngot=%v\nwant=%v", got, want) } }