fix(logging): blacklist endpoints with sensitive data from logging body
parent
5a546d5827
commit
dbe0103d92
|
@ -712,7 +712,7 @@ func (m *Launcher) run(ctx context.Context) (err error) {
|
|||
h := http.NewHandlerFromRegistry("platform", m.reg)
|
||||
h.Handler = platformHandler
|
||||
if logconf.Level == zap.DebugLevel {
|
||||
h.Handler = http.HTTPLoggingMW(httpLogger)(h.Handler)
|
||||
h.Handler = http.LoggingMW(httpLogger)(h.Handler)
|
||||
}
|
||||
h.Logger = httpLogger
|
||||
|
||||
|
|
|
@ -0,0 +1,153 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func LoggingMW(logger *zap.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
srw := &statusResponseWriter{
|
||||
ResponseWriter: w,
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
r.Body = &bodyEchoer{
|
||||
rc: r.Body,
|
||||
teedR: io.TeeReader(r.Body, &buf),
|
||||
}
|
||||
|
||||
defer func(start time.Time) {
|
||||
errField := zap.Skip()
|
||||
if errStr := w.Header().Get(PlatformErrorCodeHeader); errStr != "" {
|
||||
errField = zap.Error(errors.New(errStr))
|
||||
}
|
||||
|
||||
errReferenceField := zap.Skip()
|
||||
if errReference := w.Header().Get(PlatformErrorCodeHeader); errReference != "" {
|
||||
errReferenceField = zap.String("error_code", PlatformErrorCodeHeader)
|
||||
}
|
||||
|
||||
fields := []zap.Field{
|
||||
zap.String("method", r.Method),
|
||||
zap.String("host", r.Host),
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("query", r.URL.Query().Encode()),
|
||||
zap.String("proto", r.Proto),
|
||||
zap.Int("status_code", srw.code()),
|
||||
zap.Int("response_size", srw.responseBytes),
|
||||
zap.Int64("content_length", r.ContentLength),
|
||||
zap.String("referrer", r.Referer()),
|
||||
zap.String("remote", r.RemoteAddr),
|
||||
zap.String("user_agent", r.UserAgent()),
|
||||
zap.Duration("took", time.Since(start)),
|
||||
errField,
|
||||
errReferenceField,
|
||||
}
|
||||
|
||||
invalidMethodFn, ok := mapURLPath(r.URL.Path)
|
||||
if !ok || !invalidMethodFn(r.Method) {
|
||||
fields = append(fields, zap.ByteString("body", buf.Bytes()))
|
||||
}
|
||||
|
||||
logger.Debug("Request", fields...)
|
||||
}(time.Now())
|
||||
|
||||
next.ServeHTTP(srw, r)
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
}
|
||||
|
||||
type isValidMethodFn func(method string) bool
|
||||
|
||||
func mapURLPath(rawPath string) (isValidMethodFn, bool) {
|
||||
if fn, ok := blacklistEndpoints[rawPath]; ok {
|
||||
return fn, true
|
||||
}
|
||||
|
||||
shiftPath := func(p string) (head, tail string) {
|
||||
p = path.Clean("/" + p)
|
||||
i := strings.Index(p[1:], "/") + 1
|
||||
if i <= 0 {
|
||||
return p[1:], "/"
|
||||
}
|
||||
return p[1:i], p[i:]
|
||||
}
|
||||
|
||||
// ugh, should probably make this whole operation use a trie
|
||||
partsMatch := func(raw, source string) bool {
|
||||
return raw == source || (strings.HasPrefix(source, ":") && raw != "")
|
||||
}
|
||||
|
||||
compareRawSourceURLs := func(raw, source string) bool {
|
||||
sourceHead, sourceTail := shiftPath(source)
|
||||
for rawHead, rawTail := shiftPath(rawPath); rawHead != ""; {
|
||||
if !partsMatch(rawHead, sourceHead) {
|
||||
return false
|
||||
}
|
||||
rawHead, rawTail = shiftPath(rawTail)
|
||||
sourceHead, sourceTail = shiftPath(sourceTail)
|
||||
}
|
||||
return sourceHead == ""
|
||||
}
|
||||
|
||||
for sourcePath, fn := range blacklistEndpoints {
|
||||
match := compareRawSourceURLs(rawPath, sourcePath)
|
||||
if match {
|
||||
return fn, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func ignoreMethod(ignoredMethods ...string) isValidMethodFn {
|
||||
if len(ignoredMethods) == 0 {
|
||||
return func(string) bool { return true }
|
||||
}
|
||||
|
||||
ignoreMap := make(map[string]bool)
|
||||
for _, method := range ignoredMethods {
|
||||
ignoreMap[method] = true
|
||||
}
|
||||
|
||||
return func(method string) bool {
|
||||
return ignoreMap[method]
|
||||
}
|
||||
}
|
||||
|
||||
var blacklistEndpoints = map[string]isValidMethodFn{
|
||||
"/api/v2/signin": ignoreMethod(),
|
||||
"/api/v2/signout": ignoreMethod(),
|
||||
mePath: ignoreMethod(),
|
||||
mePasswordPath: ignoreMethod(),
|
||||
usersPasswordPath: ignoreMethod(),
|
||||
writePath: ignoreMethod("POST"),
|
||||
organizationsIDSecretsPath: ignoreMethod("PATCH"),
|
||||
organizationsIDSecretsDeletePath: ignoreMethod("POST"),
|
||||
setupPath: ignoreMethod("POST"),
|
||||
notificationEndpointsPath: ignoreMethod("POST"),
|
||||
notificationEndpointsIDPath: ignoreMethod("PUT"),
|
||||
}
|
||||
|
||||
type bodyEchoer struct {
|
||||
rc io.ReadCloser
|
||||
teedR io.Reader
|
||||
}
|
||||
|
||||
func (b *bodyEchoer) Read(p []byte) (int, error) {
|
||||
return b.teedR.Read(p)
|
||||
}
|
||||
|
||||
func (b *bodyEchoer) Close() error {
|
||||
return b.rc.Close()
|
||||
}
|
|
@ -0,0 +1,266 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/influxdata/influxdb/logger"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
func TestLoggingMW(t *testing.T) {
|
||||
newDebugLogger := func(t *testing.T) (*zap.Logger, *bytes.Buffer) {
|
||||
t.Helper()
|
||||
|
||||
var buf bytes.Buffer
|
||||
log, err := (&logger.Config{
|
||||
Format: "auto",
|
||||
Level: zapcore.DebugLevel,
|
||||
}).New(&buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return log, &buf
|
||||
}
|
||||
|
||||
urlWithQueries := func(path string, queryPairs ...string) url.URL {
|
||||
u := url.URL{
|
||||
Path: path,
|
||||
}
|
||||
params := u.Query()
|
||||
for i := 0; i < len(queryPairs)/2; i++ {
|
||||
k, v := queryPairs[i*2], queryPairs[i*2+1]
|
||||
params.Add(k, v)
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
encodeBody := func(t *testing.T, k, v string) *bytes.Buffer {
|
||||
t.Helper()
|
||||
|
||||
m := map[string]string{k: v}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := json.NewEncoder(&buf).Encode(m)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return &buf
|
||||
}
|
||||
|
||||
getKVPair := func(s string) (string, string) {
|
||||
kv := strings.Split(s, "=")
|
||||
switch len(kv) {
|
||||
case 1:
|
||||
return kv[0], ""
|
||||
case 2:
|
||||
return kv[0], strings.TrimSuffix(kv[1], "\n")
|
||||
default:
|
||||
return "", ""
|
||||
}
|
||||
}
|
||||
|
||||
echoHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var m map[string]string
|
||||
err := json.NewDecoder(r.Body).Decode(&m)
|
||||
if err != nil {
|
||||
w.WriteHeader(422)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
// set a non 200 status code here
|
||||
w.WriteHeader(202)
|
||||
|
||||
_, err = w.Write([]byte("ack"))
|
||||
if err != nil {
|
||||
w.WriteHeader(500)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
teeReader := func(r *bytes.Buffer, w io.Writer) io.Reader {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return io.TeeReader(r, w)
|
||||
}
|
||||
|
||||
type testRun struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
queryPairs []string
|
||||
hasBody bool
|
||||
hideBody bool
|
||||
}
|
||||
|
||||
testEndpoint := func(tt testRun) func(t *testing.T) {
|
||||
fn := func(t *testing.T) {
|
||||
log, buf := newDebugLogger(t)
|
||||
|
||||
reqURL := urlWithQueries(tt.path, tt.queryPairs...)
|
||||
var body *bytes.Buffer
|
||||
if tt.hasBody {
|
||||
body = encodeBody(t, "bin", "shake")
|
||||
}
|
||||
|
||||
var trackerBuf bytes.Buffer
|
||||
req := httptest.NewRequest(tt.method, reqURL.String(), teeReader(body, &trackerBuf))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
LoggingMW(log)(echoHandler).ServeHTTP(rec, req)
|
||||
|
||||
expected := map[string]string{
|
||||
"method": tt.method,
|
||||
"host": "example.com",
|
||||
"path": reqURL.Path,
|
||||
"query": reqURL.RawQuery,
|
||||
"proto": "HTTP/1.1",
|
||||
"status_code": strconv.Itoa(rec.Code),
|
||||
"response_size": strconv.Itoa(rec.Body.Len()),
|
||||
"content_length": strconv.FormatInt(req.ContentLength, 10),
|
||||
}
|
||||
if tt.hasBody {
|
||||
expected["body"] = fmt.Sprintf("%q", trackerBuf.String())
|
||||
}
|
||||
|
||||
// skip first 4 pairs, is the base logger stuff
|
||||
for _, pair := range strings.Split(buf.String(), " ")[4:] {
|
||||
k, v := getKVPair(pair)
|
||||
switch k {
|
||||
case "took", "remote":
|
||||
if v == "" {
|
||||
t.Errorf("unexpected value(%q) for key(%q): expected=non empty string", v, k)
|
||||
}
|
||||
case "body":
|
||||
if tt.hideBody && v != "" {
|
||||
t.Errorf("expected body to be \"\" but got=%q", v)
|
||||
continue
|
||||
}
|
||||
fallthrough
|
||||
default:
|
||||
if expectedV := expected[k]; expectedV != v {
|
||||
t.Errorf("unexpected value(%q) for key(%q): expected=%q", v, k, expectedV)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fn
|
||||
}
|
||||
|
||||
t.Run("logs the http request", func(t *testing.T) {
|
||||
tests := []testRun{
|
||||
{
|
||||
name: "GET",
|
||||
method: "GET",
|
||||
path: "/foo",
|
||||
queryPairs: []string{"dodgers", "are", "the", "terrible"},
|
||||
},
|
||||
{
|
||||
name: "POST",
|
||||
method: "POST",
|
||||
path: "/foo",
|
||||
queryPairs: []string{"bin", "shake"},
|
||||
hasBody: true,
|
||||
},
|
||||
{
|
||||
name: "PUT",
|
||||
method: "PUT",
|
||||
path: "/foo",
|
||||
queryPairs: []string{"ninja", "turtles"},
|
||||
hasBody: true,
|
||||
},
|
||||
{
|
||||
name: "PATCH",
|
||||
method: "PATCH",
|
||||
path: "/foo",
|
||||
queryPairs: []string{"peach", "daisy", "mario", "luigi"},
|
||||
hasBody: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, testEndpoint(tt))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("does not log body for blacklisted routes", func(t *testing.T) {
|
||||
tests := []testRun{
|
||||
{
|
||||
name: "signin",
|
||||
method: "POSTT",
|
||||
path: "/api/v2/signin",
|
||||
},
|
||||
{
|
||||
name: "signout",
|
||||
method: "POST",
|
||||
path: "/api/v2/signout",
|
||||
},
|
||||
{
|
||||
name: "me path",
|
||||
method: "POST",
|
||||
path: "/api/v2/me",
|
||||
},
|
||||
{
|
||||
name: "me password path",
|
||||
method: "POST",
|
||||
path: "/api/v2/me/password",
|
||||
},
|
||||
{
|
||||
name: "user password path",
|
||||
method: "POST",
|
||||
path: "/api/v2/users/user-id/password",
|
||||
},
|
||||
{
|
||||
name: "write path",
|
||||
method: "POST",
|
||||
path: "/api/v2/write",
|
||||
},
|
||||
{
|
||||
name: "orgs id secrets path",
|
||||
method: "PATCH",
|
||||
path: "/api/v2/orgs/org-id/secrets",
|
||||
},
|
||||
{
|
||||
name: "orgs id secrets delete path",
|
||||
method: "POST",
|
||||
path: "/api/v2/orgs/org-id/secrets/delete",
|
||||
},
|
||||
{
|
||||
name: "setup path",
|
||||
method: "POST",
|
||||
path: "/api/v2/setup",
|
||||
},
|
||||
{
|
||||
name: "notifications endpoints path",
|
||||
method: "POST",
|
||||
path: "/api/v2/notificationEndpoints",
|
||||
},
|
||||
{
|
||||
name: "notifications endpoints id path",
|
||||
method: "PUT",
|
||||
path: "/api/v2/notificationEndpoints/notification-id",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt.hasBody = true
|
||||
tt.hideBody = true
|
||||
t.Run(tt.name, testEndpoint(tt))
|
||||
}
|
||||
})
|
||||
|
||||
}
|
|
@ -1,15 +1,10 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// PlatformHandler is a collection of all the service handlers.
|
||||
|
@ -27,68 +22,6 @@ func setCORSResponseHeaders(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
type bodyEchoer struct {
|
||||
rc io.ReadCloser
|
||||
teedR io.Reader
|
||||
}
|
||||
|
||||
func (b *bodyEchoer) Read(p []byte) (int, error) {
|
||||
return b.teedR.Read(p)
|
||||
}
|
||||
|
||||
func (b *bodyEchoer) Close() error {
|
||||
return b.rc.Close()
|
||||
}
|
||||
|
||||
func HTTPLoggingMW(logger *zap.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
srw := &statusResponseWriter{
|
||||
ResponseWriter: w,
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
r.Body = &bodyEchoer{
|
||||
rc: r.Body,
|
||||
teedR: io.TeeReader(r.Body, &buf),
|
||||
}
|
||||
|
||||
defer func(start time.Time) {
|
||||
errField := zap.Skip()
|
||||
if errStr := w.Header().Get(PlatformErrorCodeHeader); errStr != "" {
|
||||
errField = zap.Error(errors.New(errStr))
|
||||
}
|
||||
|
||||
errReferenceField := zap.Skip()
|
||||
if errReference := w.Header().Get(PlatformErrorCodeHeader); errReference != "" {
|
||||
errReferenceField = zap.String("error_code", PlatformErrorCodeHeader)
|
||||
}
|
||||
|
||||
logger.Debug(
|
||||
"Request",
|
||||
zap.String("method", r.Method),
|
||||
zap.String("host", r.Host),
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("query", r.URL.Query().Encode()),
|
||||
zap.String("proto", r.Proto),
|
||||
zap.Int("status_code", srw.code()),
|
||||
zap.Int("response_size", srw.responseBytes),
|
||||
zap.Int64("content_length", r.ContentLength),
|
||||
zap.String("referrer", r.Referer()),
|
||||
zap.String("remote", r.RemoteAddr),
|
||||
zap.String("user_agent", r.UserAgent()),
|
||||
zap.ByteString("body", buf.Bytes()),
|
||||
zap.Duration("took", time.Since(start)),
|
||||
errField,
|
||||
errReferenceField,
|
||||
)
|
||||
}(time.Now())
|
||||
next.ServeHTTP(srw, r)
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
}
|
||||
|
||||
// NewPlatformHandler returns a platform handler that serves the API and associated assets.
|
||||
func NewPlatformHandler(b *APIBackend) *PlatformHandler {
|
||||
h := NewAuthenticationHandler(b.HTTPErrorHandler)
|
||||
|
|
Loading…
Reference in New Issue