package http

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strconv"
	"strings"
	"testing"

	"github.com/influxdata/influxdb/logger"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
)

// TestLoggingMW checks the fields LoggingMW writes to a debug logger for each
// request passing through it, and that the request body is left out of the
// log line for the sensitive routes listed in the second subtest.
func TestLoggingMW(t *testing.T) {
	// newDebugLogger returns a debug-level logger that writes into the
	// returned buffer so the test can inspect the emitted log line.
	newDebugLogger := func(t *testing.T) (*zap.Logger, *bytes.Buffer) {
		t.Helper()

		var buf bytes.Buffer
		log, err := (&logger.Config{
			Format: "auto",
			Level:  zapcore.DebugLevel,
		}).New(&buf)
		if err != nil {
			t.Fatal(err)
		}
		return log, &buf
	}

	// urlWithQueries builds a URL for path whose query string is made from
	// queryPairs, read as consecutive key/value pairs. A trailing unpaired
	// element is ignored.
	urlWithQueries := func(path string, queryPairs ...string) url.URL {
		u := url.URL{
			Path: path,
		}
		params := u.Query()
		for i := 0; i+1 < len(queryPairs); i += 2 {
			params.Add(queryPairs[i], queryPairs[i+1])
		}
		// Attach the encoded pairs; otherwise they are built and discarded.
		u.RawQuery = params.Encode()
		return u
	}

	// encodeBody returns a buffer holding the JSON encoding of {k: v}.
	encodeBody := func(t *testing.T, k, v string) *bytes.Buffer {
		t.Helper()

		var buf bytes.Buffer
		if err := json.NewEncoder(&buf).Encode(map[string]string{k: v}); err != nil {
			t.Fatal(err)
		}
		return &buf
	}

	// unquote undoes Go-style double quoting of a logged value. Values that
	// are not wrapped in quotes, or do not parse as a Go string literal, are
	// returned unchanged.
	// NOTE(review): assumes the text log encoder quotes values such as the
	// JSON body (and presumably anything containing '=') with Go-compatible
	// escapes, as the earlier %q-based body expectation implied.
	unquote := func(s string) string {
		if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' {
			if u, err := strconv.Unquote(s); err == nil {
				return u
			}
		}
		return s
	}

	// getKVPair splits one "key=value" log field at its first '=', so values
	// that themselves contain '=' (such as query strings) stay intact, and
	// strips any quoting from the value. A field without '=' yields an empty
	// value.
	getKVPair := func(s string) (string, string) {
		kv := strings.SplitN(s, "=", 2)
		if len(kv) == 1 {
			return kv[0], ""
		}
		return kv[0], unquote(kv[1])
	}

	// echoHandler decodes a JSON object from the request body, replying 422
	// when that fails and 202 with the body "ack" otherwise.
	echoHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer r.Body.Close()

		var m map[string]string
		if err := json.NewDecoder(r.Body).Decode(&m); err != nil {
			w.WriteHeader(422)
			return
		}

		// set a non 200 status code here so status_code is meaningfully checked
		w.WriteHeader(202)
		// The status is already committed, so a write error cannot be
		// reported to the client; a second WriteHeader would be ignored.
		_, _ = fmt.Fprint(w, "ack")
	})

	// teeReader returns a reader over r that copies everything read into w.
	// For a nil r it returns an untyped nil io.Reader (not a nil *bytes.Buffer
	// boxed in an interface) so httptest.NewRequest sees no body at all.
	teeReader := func(r *bytes.Buffer, w io.Writer) io.Reader {
		if r == nil {
			return nil
		}
		return io.TeeReader(r, w)
	}

	// testRun describes one request: hasBody sends a JSON body, hideBody
	// requires that the body does not show up in the log line.
	type testRun struct {
		name       string
		method     string
		path       string
		queryPairs []string
		hasBody    bool
		hideBody   bool
	}

	// testEndpoint returns a subtest that sends tt through LoggingMW and
	// checks every logged field against the request and recorded response.
	testEndpoint := func(tt testRun) func(t *testing.T) {
		return func(t *testing.T) {
			log, buf := newDebugLogger(t)

			reqURL := urlWithQueries(tt.path, tt.queryPairs...)
			var body *bytes.Buffer
			if tt.hasBody {
				body = encodeBody(t, "bin", "shake")
			}

			// trackerBuf receives a copy of whatever is read from the body.
			var trackerBuf bytes.Buffer
			req := httptest.NewRequest(tt.method, reqURL.String(), teeReader(body, &trackerBuf))
			rec := httptest.NewRecorder()

			LoggingMW(log)(echoHandler).ServeHTTP(rec, req)

			expected := map[string]string{
				"method":         tt.method,
				"host":           "example.com",
				"path":           reqURL.Path,
				"query":          reqURL.RawQuery,
				"proto":          "HTTP/1.1",
				"status_code":    strconv.Itoa(rec.Code),
				"response_size":  strconv.Itoa(rec.Body.Len()),
				"content_length": strconv.FormatInt(req.ContentLength, 10),
			}
			if tt.hasBody {
				expected["body"] = trackerBuf.String()
			}

			pairs := strings.Split(strings.TrimSuffix(buf.String(), "\n"), " ")
			// skip first 4 pairs, is the base logger stuff
			if len(pairs) < 4 {
				t.Fatalf("unexpected log output: %q", buf.String())
			}
			for _, pair := range pairs[4:] {
				k, v := getKVPair(pair)
				switch k {
				case "took", "remote":
					if v == "" {
						t.Errorf("unexpected value(%q) for key(%q): expected=non empty string", v, k)
					}
				case "user_agent":
					// intentionally not asserted
				case "body":
					if tt.hideBody {
						if v != "" {
							t.Errorf("expected body to be \"\" but got=%q", v)
						}
						continue
					}
					// a visible body is compared like every other field
					fallthrough
				default:
					if expectedV := expected[k]; expectedV != v {
						t.Errorf("unexpected value(%q) for key(%q): expected=%q", v, k, expectedV)
					}
				}
			}
		}
	}

	t.Run("logs the http request", func(t *testing.T) {
		tests := []testRun{
			{
				name:       "GET",
				method:     "GET",
				path:       "/foo",
				queryPairs: []string{"dodgers", "are", "the", "terrible"},
			},
			{
				name:       "POST",
				method:     "POST",
				path:       "/foo",
				queryPairs: []string{"bin", "shake"},
				hasBody:    true,
			},
			{
				name:       "PUT",
				method:     "PUT",
				path:       "/foo",
				queryPairs: []string{"ninja", "turtles"},
				hasBody:    true,
			},
			{
				name:       "PATCH",
				method:     "PATCH",
				path:       "/foo",
				queryPairs: []string{"peach", "daisy", "mario", "luigi"},
				hasBody:    true,
			},
		}

		for _, tt := range tests {
			t.Run(tt.name, testEndpoint(tt))
		}
	})

	t.Run("does not log body for blacklisted routes", func(t *testing.T) {
		tests := []testRun{
			{
				name:   "signin",
				method: "POST",
				path:   "/api/v2/signin",
			},
			{
				name:   "signout",
				method: "POST",
				path:   "/api/v2/signout",
			},
			{
				name:   "me path",
				method: "POST",
				path:   "/api/v2/me",
			},
			{
				name:   "me password path",
				method: "POST",
				path:   "/api/v2/me/password",
			},
			{
				name:   "user password path",
				method: "POST",
				path:   "/api/v2/users/user-id/password",
			},
			{
				name:   "write path",
				method: "POST",
				path:   "/api/v2/write",
			},
			{
				name:   "orgs id secrets path",
				method: "PATCH",
				path:   "/api/v2/orgs/org-id/secrets",
			},
			{
				name:   "orgs id secrets delete path",
				method: "POST",
				path:   "/api/v2/orgs/org-id/secrets/delete",
			},
			{
				name:   "setup path",
				method: "POST",
				path:   "/api/v2/setup",
			},
			{
				name:   "notifications endpoints path",
				method: "POST",
				path:   "/api/v2/notificationEndpoints",
			},
			{
				name:   "notifications endpoints id path",
				method: "PUT",
				path:   "/api/v2/notificationEndpoints/notification-id",
			},
		}

		for _, tt := range tests {
			// every blacklisted route is sent a body that must stay out of the log
			tt.hasBody = true
			tt.hideBody = true
			t.Run(tt.name, testEndpoint(tt))
		}
	})
}