From 78724e5c20e6189f86f97bbb110c9ca050101624 Mon Sep 17 00:00:00 2001 From: Sam Arnold Date: Tue, 30 Mar 2021 14:09:04 -0300 Subject: [PATCH] chore: Add kit (#21086) * chore: pull in unchanged kit from v2 * chore: remove v2 from import paths * chore: update module paths and go.mod for kit * chore: remove kit/cli again, not needed in 1.x --- go.mod | 6 + go.sum | 16 +- kit/check/check.go | 238 +++++++++ kit/check/check_test.go | 535 +++++++++++++++++++ kit/check/helpers.go | 73 +++ kit/check/response.go | 39 ++ kit/errors/errors.go | 107 ++++ kit/errors/list.go | 59 ++ kit/feature/_codegen/main.go | 271 ++++++++++ kit/feature/doc.go | 75 +++ kit/feature/feature.go | 140 +++++ kit/feature/feature_test.go | 185 +++++++ kit/feature/flag.go | 216 ++++++++ kit/feature/http_proxy.go | 73 +++ kit/feature/http_proxy_test.go | 137 +++++ kit/feature/list.go | 249 +++++++++ kit/feature/middleware.go | 70 +++ kit/feature/middleware_test.go | 47 ++ kit/feature/override/override.go | 83 +++ kit/feature/override/override_test.go | 183 +++++++ kit/io/limited_read_closer.go | 59 ++ kit/io/limited_read_closer_test.go | 87 +++ kit/metric/client.go | 152 ++++++ kit/metric/metrics_options.go | 84 +++ kit/platform/errors/error.go | 9 + kit/platform/errors/errors.go | 264 +++++++++ kit/platform/errors/errors.md | 86 +++ kit/platform/errors/errors_test.go | 300 +++++++++++ kit/platform/id.go | 145 +++++ kit/platform/id_test.go | 259 +++++++++ kit/prom/example_test.go | 87 +++ kit/prom/promtest/promtest.go | 146 +++++ kit/prom/promtest/promtest_test.go | 210 ++++++++ kit/prom/registry.go | 59 ++ kit/prom/registry_test.go | 60 +++ kit/signals/context.go | 31 ++ kit/signals/context_test.go | 79 +++ kit/tracing/testing/testing.go | 24 + kit/tracing/tracing.go | 225 ++++++++ kit/tracing/tracing_test.go | 333 ++++++++++++ kit/transport/http/api.go | 267 +++++++++ kit/transport/http/api_test.go | 309 +++++++++++ kit/transport/http/error_handler.go | 188 +++++++ kit/transport/http/error_handler_test.go | 56 ++ kit/transport/http/feature_controller.go | 39 ++ kit/transport/http/handler.go | 11 + kit/transport/http/middleware.go | 182 +++++++ kit/transport/http/middleware_test.go | 95 ++++ kit/transport/http/status_response_writer.go | 58 ++ pkg/testttp/http.go | 195 +++++++ pkg/testttp/http_test.go | 174 ++++++ 51 files changed, 7073 insertions(+), 2 deletions(-) create mode 100644 kit/check/check.go create mode 100644 kit/check/check_test.go create mode 100644 kit/check/helpers.go create mode 100644 kit/check/response.go create mode 100644 kit/errors/errors.go create mode 100644 kit/errors/list.go create mode 100644 kit/feature/_codegen/main.go create mode 100644 kit/feature/doc.go create mode 100644 kit/feature/feature.go create mode 100644 kit/feature/feature_test.go create mode 100644 kit/feature/flag.go create mode 100644 kit/feature/http_proxy.go create mode 100644 kit/feature/http_proxy_test.go create mode 100644 kit/feature/list.go create mode 100644 kit/feature/middleware.go create mode 100644 kit/feature/middleware_test.go create mode 100644 kit/feature/override/override.go create mode 100644 kit/feature/override/override_test.go create mode 100644 kit/io/limited_read_closer.go create mode 100644 kit/io/limited_read_closer_test.go create mode 100644 kit/metric/client.go create mode 100644 kit/metric/metrics_options.go create mode 100644 kit/platform/errors/error.go create mode 100644 kit/platform/errors/errors.go create mode 100644 kit/platform/errors/errors.md create mode 100644 kit/platform/errors/errors_test.go create mode 100644 kit/platform/id.go create mode 100644 kit/platform/id_test.go create mode 100644 kit/prom/example_test.go create mode 100644 kit/prom/promtest/promtest.go create mode 100644 kit/prom/promtest/promtest_test.go create mode 100644 kit/prom/registry.go create mode 100644 kit/prom/registry_test.go create mode 100644 kit/signals/context.go create mode 100644 kit/signals/context_test.go create mode 100644 kit/tracing/testing/testing.go create mode 100644 kit/tracing/tracing.go create mode 100644 kit/tracing/tracing_test.go create mode 100644 kit/transport/http/api.go create mode 100644 kit/transport/http/api_test.go create mode 100644 kit/transport/http/error_handler.go create mode 100644 kit/transport/http/error_handler_test.go create mode 100644 kit/transport/http/feature_controller.go create mode 100644 kit/transport/http/handler.go create mode 100644 kit/transport/http/middleware.go create mode 100644 kit/transport/http/middleware_test.go create mode 100644 kit/transport/http/status_response_writer.go create mode 100644 pkg/testttp/http.go create mode 100644 pkg/testttp/http_test.go diff --git a/go.mod b/go.mod index 360b96f304..d37d63f537 100644 --- a/go.mod +++ b/go.mod @@ -12,10 +12,12 @@ require ( github.com/davecgh/go-spew v1.1.1 github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1 github.com/dgryski/go-bitstream v0.0.0-20180413035011-3522498ce2c8 + github.com/go-chi/chi v4.1.0+incompatible github.com/gogo/protobuf v1.3.1 github.com/golang/snappy v0.0.1 github.com/google/go-cmp v0.5.0 github.com/influxdata/flux v0.111.0 + github.com/influxdata/httprouter v1.3.1-0.20191122104820-ee83e2772f69 github.com/influxdata/influxql v1.1.1-0.20210223160523-b6ab99450c93 github.com/influxdata/pkg-config v0.2.7 github.com/influxdata/roaring v0.4.13-0.20180809181101-fc520f41fab6 @@ -24,16 +26,20 @@ require ( github.com/jwilder/encoding v0.0.0-20170811194829-b4e1701a28ef github.com/klauspost/pgzip v1.0.2-0.20170402124221-0bf5dcad4ada github.com/mattn/go-isatty v0.0.12 + github.com/mileusna/useragent v0.0.0-20190129205925-3e331f0949a5 github.com/opentracing/opentracing-go v1.1.0 github.com/peterh/liner v1.0.1-0.20180619022028-8c1271fcf47f github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.5.1 + github.com/prometheus/client_model v0.2.0 + github.com/prometheus/common v0.9.1 github.com/prometheus/prometheus v0.0.0-20200609090129-a6600f564e3c github.com/retailnext/hllpp v1.0.1-0.20180308014038-101a6d2f8b52 github.com/spf13/cast v1.3.0 github.com/spf13/cobra v0.0.3 github.com/stretchr/testify v1.5.1 github.com/tinylib/msgp v1.1.0 + github.com/uber/jaeger-client-go v2.23.0+incompatible github.com/willf/bitset v1.1.9 // indirect github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6 go.uber.org/zap v1.14.1 diff --git a/go.sum b/go.sum index 00840eebc2..cf0c246f73 100644 --- a/go.sum +++ b/go.sum @@ -152,6 +152,7 @@ github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec/go.mod h1:jMjuTZXRI4 github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= +github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd h1:qMd81Ts1T2OTKmB4acZcyKaMtRnY5Y44NuXGX2GFJ1w= github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -199,6 +200,7 @@ github.com/foxcpp/go-mockdns v0.0.0-20201212160233-ede2f9158d15 h1:nLPjjvpUAODOR github.com/foxcpp/go-mockdns v0.0.0-20201212160233-ede2f9158d15/go.mod h1:tPg4cp4nseejPd+UKxtCVQ2hUxNTZ7qQZJa7CLriIeo= github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= +github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= @@ -208,6 +210,8 @@ github.com/glycerine/go-unsnap-stream v0.0.0-20180323001048-9f0cb55181dd h1:r04M github.com/glycerine/go-unsnap-stream v0.0.0-20180323001048-9f0cb55181dd/go.mod h1:/20jfyN9Y5QPEAprSgKAUr+glWDY39ZiUEAYOEv5dsE= github.com/glycerine/goconvey v0.0.0-20190410193231-58a59202ab31 h1:gclg6gY70GLy3PbkQ1AERPfmLMMagS60DKF78eWwLn8= github.com/glycerine/goconvey v0.0.0-20190410193231-58a59202ab31/go.mod h1:Ogl1Tioa0aV7gstGFO7KhffUsb9M4ydbEbbxpcEDc24= +github.com/go-chi/chi v4.1.0+incompatible h1:ETj3cggsVIY2Xao5ExCu6YhEh5MD6JTfcBzS37R260w= +github.com/go-chi/chi v4.1.0+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -434,6 +438,8 @@ github.com/influxdata/flux v0.65.0 h1:57tk1Oo4gpGIDbV12vUAPCMtLtThhaXzub1XRIuqv6 github.com/influxdata/flux v0.65.0/go.mod h1:BwN2XG2lMszOoquQaFdPET8FRQfrXiZsWmcMO9rkaVY= github.com/influxdata/flux v0.111.0 h1:27CNz0SbEofD9NzdwcdxRwGmuVSDSisVq4dOceB/KF0= github.com/influxdata/flux v0.111.0/go.mod h1:3TJtvbm/Kwuo5/PEo5P6HUzwVg4bXWkb2wPQHPtQdlU= +github.com/influxdata/httprouter v1.3.1-0.20191122104820-ee83e2772f69 h1:WQsmW0fXO4ZE/lFGIE84G6rIV5SJN3P3sjIXAP1a8eU= +github.com/influxdata/httprouter v1.3.1-0.20191122104820-ee83e2772f69/go.mod h1:pwymjR6SrP3gD3pRj9RJwdl1j5s3doEEV8gS4X9qSzA= github.com/influxdata/influxdb v1.8.0/go.mod h1:SIzcnsjaHRFpmlxpJ4S3NT64qtEKYweNTUMb/vh0OMQ= github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= github.com/influxdata/influxql v1.1.0/go.mod h1:KpVI7okXjK6PRi3Z5B+mtKZli+R1DnZgb3N+tzevNgo= @@ -447,7 +453,6 @@ github.com/influxdata/pkg-config v0.2.7/go.mod h1:EMS7Ll0S4qkzDk53XS3Z72/egBsPIn github.com/influxdata/promql/v2 v2.12.0/go.mod h1:fxOPu+DY0bqCTCECchSRtWfc+0X19ybifQhZoQNF5D8= github.com/influxdata/roaring v0.4.13-0.20180809181101-fc520f41fab6 h1:UzJnB7VRL4PSkUJHwsyzseGOmrO/r4yA+AuxGJxiZmA= github.com/influxdata/roaring v0.4.13-0.20180809181101-fc520f41fab6/go.mod h1:bSgUQ7q5ZLSO+bKBGqJiCBGAl+9DxyW63zLTujjUlOE= -github.com/influxdata/tdigest v0.0.0-20181121200506-bf2b5ad3c0a9 h1:MHTrDWmQpHq/hkq+7cw9oYAt2PqUw52TZazRA0N7PGE= github.com/influxdata/tdigest v0.0.0-20181121200506-bf2b5ad3c0a9/go.mod h1:Js0mqiSBE6Ffsg94weZZ2c+v/ciT8QRHFOap7EKDrR0= github.com/influxdata/tdigest v0.0.2-0.20210216194612-fc98d27c9e8b h1:i44CesU68ZBRvtCjBi3QSosCIKrjmMbYlQMFAwVLds4= github.com/influxdata/tdigest v0.0.2-0.20210216194612-fc98d27c9e8b/go.mod h1:Z0kXnxzbTC2qrx4NaIzYkE1k66+6oEDQTvL95hQFh5Y= @@ -540,6 +545,8 @@ github.com/miekg/dns v1.1.22/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKju github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= github.com/miekg/dns v1.1.29 h1:xHBEhR+t5RzcFJjBLJlax2daXOrTYtr9z4WdKEfWFzg= github.com/miekg/dns v1.1.29/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM= +github.com/mileusna/useragent v0.0.0-20190129205925-3e331f0949a5 h1:pXqZHmHOz6LN+zbbUgqyGgAWRnnZEI40IzG3tMsXcSI= +github.com/mileusna/useragent v0.0.0-20190129205925-3e331f0949a5/go.mod h1:JWhYAp2EXqUtsxTKdeGlY8Wp44M7VxThC9FEoNGi2IE= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= @@ -550,6 +557,7 @@ github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS4 github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/mitchellh/mapstructure v1.2.2 h1:dxe5oCinTXiTIcfgmZecdCzPmAJKd46KsCWc35r0TV4= github.com/mitchellh/mapstructure v1.2.2/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -596,9 +604,9 @@ github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnh github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= -github.com/paulbellamy/ratecounter v0.2.0 h1:2L/RhJq+HA8gBQImDXtLPrDXK5qAj6ozWVK/zFXVJGs= github.com/paulbellamy/ratecounter v0.2.0/go.mod h1:Hfx1hDpSGoqxkVVpBi/IlYD7kChlfo5C6hzIHwPqfFE= github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= +github.com/pelletier/go-toml v1.4.0 h1:u3Z1r+oOXJIkxqw34zVhyPgjBsm6X2wn21NWs/HfSeg= github.com/pelletier/go-toml v1.4.0/go.mod h1:PN7xzY2wHTK0K9p34ErDQMlFxa51Fk0OUruD3k1mMwo= github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= @@ -690,6 +698,7 @@ github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4k github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/spf13/afero v1.2.2 h1:5jhuqJyZCZf2JRofRvN/nIFgIWNzPa3/Vz8mYylgbWc= github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= github.com/spf13/cast v1.3.0 h1:oget//CVOEoFewqQxwr0Ej5yjygnqGkvggSE/gB35Q8= github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= @@ -722,7 +731,9 @@ github.com/uber-go/tally v3.3.15+incompatible h1:9hLSgNBP28CjIaDmAuRTq9qV+UZY+9P github.com/uber-go/tally v3.3.15+incompatible/go.mod h1:YDTIBxdXyOU/sCWilKB4bgyufu1cEi0jdVnRdxvjnmU= github.com/uber/athenadriver v1.1.4 h1:k6k0RBeXjR7oZ8NO557MsRw3eX1cc/9B0GNx+W9eHiQ= github.com/uber/athenadriver v1.1.4/go.mod h1:tQjho4NzXw55LGfSZEcETuYydpY1vtmixUabHkC1K/E= +github.com/uber/jaeger-client-go v2.23.0+incompatible h1:o2g11IUBdEsSZVzF3k7+bahLmxRP/dbOoW4zQ30UlKE= github.com/uber/jaeger-client-go v2.23.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= +github.com/uber/jaeger-lib v2.2.0+incompatible h1:MxZXOiR2JuoANZ3J6DE/U0kSFv/eJ/GfSYVCjK7dyaw= github.com/uber/jaeger-lib v2.2.0+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= @@ -1101,6 +1112,7 @@ gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/kit/check/check.go b/kit/check/check.go new file mode 100644 index 0000000000..0e75c5f0ad --- /dev/null +++ b/kit/check/check.go @@ -0,0 +1,238 @@ +// Package check standardizes /health and /ready endpoints. +// This allows you to easily know when your server is ready and healthy. +package check + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "sort" + "sync" +) + +// Status string to indicate the overall status of the check. +type Status string + +const ( + // StatusFail indicates a specific check has failed. + StatusFail Status = "fail" + // StatusPass indicates a specific check has passed. + StatusPass Status = "pass" + + // DefaultCheckName is the name of the default checker. + DefaultCheckName = "internal" +) + +// Check wraps a map of service names to status checkers. +type Check struct { + healthChecks []Checker + readyChecks []Checker + healthOverride override + readyOverride override + + passthroughHandler http.Handler +} + +// Checker indicates a service whose health can be checked. +type Checker interface { + Check(ctx context.Context) Response +} + +// NewCheck returns a Health with a default checker. +func NewCheck() *Check { + ch := &Check{} + ch.healthOverride.disable() + ch.readyOverride.disable() + return ch +} + +// AddHealthCheck adds the check to the list of ready checks. +// If c is a NamedChecker, the name will be added. +func (c *Check) AddHealthCheck(check Checker) { + if nc, ok := check.(NamedChecker); ok { + c.healthChecks = append(c.healthChecks, Named(nc.CheckName(), nc)) + } else { + c.healthChecks = append(c.healthChecks, check) + } +} + +// AddReadyCheck adds the check to the list of ready checks. +// If c is a NamedChecker, the name will be added. +func (c *Check) AddReadyCheck(check Checker) { + if nc, ok := check.(NamedChecker); ok { + c.readyChecks = append(c.readyChecks, Named(nc.CheckName(), nc)) + } else { + c.readyChecks = append(c.readyChecks, check) + } +} + +// CheckHealth evaluates c's set of health checks and returns a populated Response. +func (c *Check) CheckHealth(ctx context.Context) Response { + response := Response{ + Name: "Health", + Status: StatusPass, + Checks: make(Responses, len(c.healthChecks)), + } + + status, overriding := c.healthOverride.get() + if overriding { + response.Status = status + overrideResponse := Response{ + Name: "manual-override", + Message: "health manually overridden", + } + response.Checks = append(response.Checks, overrideResponse) + } + for i, ch := range c.healthChecks { + resp := ch.Check(ctx) + if resp.Status != StatusPass && !overriding { + response.Status = resp.Status + } + response.Checks[i] = resp + } + sort.Sort(response.Checks) + return response +} + +// CheckReady evaluates c's set of ready checks and returns a populated Response. +func (c *Check) CheckReady(ctx context.Context) Response { + response := Response{ + Name: "Ready", + Status: StatusPass, + Checks: make(Responses, len(c.readyChecks)), + } + + status, overriding := c.readyOverride.get() + if overriding { + response.Status = status + overrideResponse := Response{ + Name: "manual-override", + Message: "ready manually overridden", + } + response.Checks = append(response.Checks, overrideResponse) + } + for i, c := range c.readyChecks { + resp := c.Check(ctx) + if resp.Status != StatusPass && !overriding { + response.Status = resp.Status + } + response.Checks[i] = resp + } + sort.Sort(response.Checks) + return response +} + +// SetPassthrough allows you to set a handler to use if the request is not a ready or health check. +// This can be useful if you intend to use this as a middleware. +func (c *Check) SetPassthrough(h http.Handler) { + c.passthroughHandler = h +} + +// ServeHTTP serves /ready and /health requests with the respective checks. +func (c *Check) ServeHTTP(w http.ResponseWriter, r *http.Request) { + const ( + pathReady = "/ready" + pathHealth = "/health" + queryForce = "force" + ) + + path := r.URL.Path + + // Allow requests not intended for checks to pass through. + if path != pathReady && path != pathHealth { + if c.passthroughHandler != nil { + c.passthroughHandler.ServeHTTP(w, r) + return + } + + // We can't handle this request. + w.WriteHeader(http.StatusNotFound) + return + } + + ctx := r.Context() + query := r.URL.Query() + + switch path { + case pathReady: + switch query.Get(queryForce) { + case "true": + switch query.Get("ready") { + case "true": + c.readyOverride.enable(StatusPass) + case "false": + c.readyOverride.enable(StatusFail) + } + case "false": + c.readyOverride.disable() + } + writeResponse(w, c.CheckReady(ctx)) + case pathHealth: + switch query.Get(queryForce) { + case "true": + switch query.Get("healthy") { + case "true": + c.healthOverride.enable(StatusPass) + case "false": + c.healthOverride.enable(StatusFail) + } + case "false": + c.healthOverride.disable() + } + writeResponse(w, c.CheckHealth(ctx)) + } +} + +// writeResponse writes a Response to the wire as JSON. The HTTP status code +// accompanying the payload is the primary means for signaling the status of the +// checks. The possible status codes are: +// +// - 200 OK: All checks pass. +// - 503 Service Unavailable: Some checks are failing. +// - 500 Internal Server Error: There was a problem serializing the Response. +func writeResponse(w http.ResponseWriter, resp Response) { + status := http.StatusOK + if resp.Status == StatusFail { + status = http.StatusServiceUnavailable + } + + msg, err := json.MarshalIndent(resp, "", " ") + if err != nil { + msg = []byte(`{"message": "error marshaling response", "status": "fail"}`) + status = http.StatusInternalServerError + } + w.WriteHeader(status) + fmt.Fprintln(w, string(msg)) +} + +// override is a manual override for an entire group of checks. +type override struct { + mtx sync.Mutex + status Status + active bool +} + +// get returns the Status of an override as well as whether or not an override +// is currently active. +func (m *override) get() (Status, bool) { + m.mtx.Lock() + defer m.mtx.Unlock() + return m.status, m.active +} + +// disable disables the override. +func (m *override) disable() { + m.mtx.Lock() + m.active = false + m.status = StatusFail + m.mtx.Unlock() +} + +// enable turns on the override and establishes a specific Status for which to. +func (m *override) enable(s Status) { + m.mtx.Lock() + m.active = true + m.status = s + m.mtx.Unlock() +} diff --git a/kit/check/check_test.go b/kit/check/check_test.go new file mode 100644 index 0000000000..90b15073df --- /dev/null +++ b/kit/check/check_test.go @@ -0,0 +1,535 @@ +package check + +import ( + "context" + "encoding/json" + "errors" + "io" + "net" + "net/http" + "net/http/httptest" + "reflect" + "testing" + "time" +) + +func TestEmptyCheck(t *testing.T) { + c := NewCheck() + resp := c.CheckReady(context.Background()) + if len(resp.Checks) > 0 { + t.Errorf("no checks added but %d returned", len(resp.Checks)) + } + + if resp.Name != "Ready" { + t.Errorf("expected: \"Ready\", got: %q", resp.Name) + } + + if resp.Status != StatusPass { + t.Errorf("expected: %q, got: %q", StatusPass, resp.Status) + } +} + +func TestAddHealthCheck(t *testing.T) { + h := NewCheck() + h.AddHealthCheck(Named("awesome", ErrCheck(func() error { + return nil + }))) + r := h.CheckHealth(context.Background()) + if r.Status != StatusPass { + t.Error("Health should fail because one of the check is unhealthy") + } + + if len(r.Checks) != 1 { + t.Fatalf("check not in results: %+v", r.Checks) + } + + v := r.Checks[0] + if v.Status != StatusPass { + t.Errorf("the added check should be pass not %q.", v.Status) + } +} + +func TestAddUnHealthyCheck(t *testing.T) { + h := NewCheck() + h.AddHealthCheck(Named("failure", ErrCheck(func() error { + return errors.New("Oops! I am sorry") + }))) + r := h.CheckHealth(context.Background()) + if r.Status != StatusFail { + t.Error("Health should fail because one of the check is unhealthy") + } + + if len(r.Checks) != 1 { + t.Fatal("check not in results") + } + + v := r.Checks[0] + if v.Status != StatusFail { + t.Errorf("the added check should be fail not %s.", v.Status) + } + if v.Message != "Oops! I am sorry" { + t.Errorf( + "the error should be 'Oops! I am sorry' not %s.", + v.Message, + ) + } +} + +func buildCheckWithServer() (*Check, *httptest.Server) { + c := NewCheck() + return c, httptest.NewServer(c) +} + +type mockCheck struct { + status Status + name string +} + +func (m mockCheck) Check(_ context.Context) Response { + return Response{ + Name: m.name, + Status: m.status, + } +} + +func mockPass(name string) Checker { + return mockCheck{status: StatusPass, name: name} +} + +func mockFail(name string) Checker { + return mockCheck{status: StatusFail, name: name} +} + +func respBuilder(body io.ReadCloser) (*Response, error) { + defer body.Close() + d := json.NewDecoder(body) + r := &Response{} + return r, d.Decode(r) +} + +func TestBasicHTTPHandler(t *testing.T) { + _, ts := buildCheckWithServer() + defer ts.Close() + + resp, err := http.Get(ts.URL + "/ready") + if err != nil { + t.Fatal(err) + } + + actual, err := respBuilder(resp.Body) + if err != nil { + t.Fatal(err) + } + expected := &Response{ + Name: "Ready", + Status: StatusPass, + } + if !reflect.DeepEqual(expected, actual) { + t.Errorf("unexpected response. expected %v, actual %v", expected, actual) + } +} + +func TestHealthSorting(t *testing.T) { + c, ts := buildCheckWithServer() + defer ts.Close() + + c.AddHealthCheck(mockPass("a")) + c.AddHealthCheck(mockPass("c")) + c.AddHealthCheck(mockPass("b")) + c.AddHealthCheck(mockFail("k")) + c.AddHealthCheck(mockFail("b")) + + resp, err := http.Get(ts.URL + "/health") + if err != nil { + t.Fatal(err) + } + actual, err := respBuilder(resp.Body) + if err != nil { + t.Fatal(err) + } + + expected := &Response{ + Name: "Health", + Status: "fail", + Checks: Responses{ + Response{Name: "b", Status: "fail"}, + Response{Name: "k", Status: "fail"}, + Response{Name: "a", Status: "pass"}, + Response{Name: "b", Status: "pass"}, + Response{Name: "c", Status: "pass"}, + }, + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("unexpected response. expected %v, actual %v", expected, actual) + } +} + +func TestForceHealthy(t *testing.T) { + c, ts := buildCheckWithServer() + defer ts.Close() + + c.AddHealthCheck(mockFail("a")) + + _, err := http.Get(ts.URL + "/health?force=true&healthy=true") + if err != nil { + t.Fatal(err) + } + + resp, err := http.Get(ts.URL + "/health") + if err != nil { + t.Fatal(err) + } + actual, err := respBuilder(resp.Body) + if err != nil { + t.Fatal(err) + } + + expected := &Response{ + Name: "Health", + Status: "pass", + Checks: Responses{ + Response{Name: "manual-override", Message: "health manually overridden"}, + Response{Name: "a", Status: "fail"}, + }, + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("unexpected response. expected %v, actual %v", expected, actual) + } + + _, err = http.Get(ts.URL + "/health?force=false") + if err != nil { + t.Fatal(err) + } + + expected = &Response{ + Name: "Health", + Status: "fail", + Checks: Responses{ + Response{Name: "a", Status: "fail"}, + }, + } + + resp, err = http.Get(ts.URL + "/health") + if err != nil { + t.Fatal(err) + } + actual, err = respBuilder(resp.Body) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("unexpected response. expected %v, actual %v", expected, actual) + } +} + +func TestForceUnhealthy(t *testing.T) { + c, ts := buildCheckWithServer() + defer ts.Close() + + c.AddHealthCheck(mockPass("a")) + + _, err := http.Get(ts.URL + "/health?force=true&healthy=false") + if err != nil { + t.Fatal(err) + } + + resp, err := http.Get(ts.URL + "/health") + if err != nil { + t.Fatal(err) + } + actual, err := respBuilder(resp.Body) + if err != nil { + t.Fatal(err) + } + + expected := &Response{ + Name: "Health", + Status: "fail", + Checks: Responses{ + Response{Name: "manual-override", Message: "health manually overridden"}, + Response{Name: "a", Status: "pass"}, + }, + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("unexpected response. expected %v, actual %v", expected, actual) + } + + _, err = http.Get(ts.URL + "/health?force=false") + if err != nil { + t.Fatal(err) + } + + expected = &Response{ + Name: "Health", + Status: "pass", + Checks: Responses{ + Response{Name: "a", Status: "pass"}, + }, + } + + resp, err = http.Get(ts.URL + "/health") + if err != nil { + t.Fatal(err) + } + actual, err = respBuilder(resp.Body) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("unexpected response. expected %v, actual %v", expected, actual) + } +} + +func TestForceReady(t *testing.T) { + c, ts := buildCheckWithServer() + defer ts.Close() + + c.AddReadyCheck(mockFail("a")) + + _, err := http.Get(ts.URL + "/ready?force=true&ready=true") + if err != nil { + t.Fatal(err) + } + + resp, err := http.Get(ts.URL + "/ready") + if err != nil { + t.Fatal(err) + } + actual, err := respBuilder(resp.Body) + if err != nil { + t.Fatal(err) + } + + expected := &Response{ + Name: "Ready", + Status: "pass", + Checks: Responses{ + Response{Name: "manual-override", Message: "ready manually overridden"}, + Response{Name: "a", Status: "fail"}, + }, + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("unexpected response. expected %v, actual %v", expected, actual) + } + + _, err = http.Get(ts.URL + "/ready?force=false") + if err != nil { + t.Fatal(err) + } + + expected = &Response{ + Name: "Ready", + Status: "fail", + Checks: Responses{ + Response{Name: "a", Status: "fail"}, + }, + } + + resp, err = http.Get(ts.URL + "/ready") + if err != nil { + t.Fatal(err) + } + actual, err = respBuilder(resp.Body) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("unexpected response. expected %v, actual %v", expected, actual) + } +} + +func TestForceNotReady(t *testing.T) { + c, ts := buildCheckWithServer() + defer ts.Close() + + c.AddReadyCheck(mockPass("a")) + + _, err := http.Get(ts.URL + "/ready?force=true&ready=false") + if err != nil { + t.Fatal(err) + } + + resp, err := http.Get(ts.URL + "/ready") + if err != nil { + t.Fatal(err) + } + actual, err := respBuilder(resp.Body) + if err != nil { + t.Fatal(err) + } + + expected := &Response{ + Name: "Ready", + Status: "fail", + Checks: Responses{ + Response{Name: "manual-override", Message: "ready manually overridden"}, + Response{Name: "a", Status: "pass"}, + }, + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("unexpected response. expected %v, actual %v", expected, actual) + } + + _, err = http.Get(ts.URL + "/ready?force=false") + if err != nil { + t.Fatal(err) + } + + expected = &Response{ + Name: "Ready", + Status: "pass", + Checks: Responses{ + Response{Name: "a", Status: "pass"}, + }, + } + + resp, err = http.Get(ts.URL + "/ready") + if err != nil { + t.Fatal(err) + } + actual, err = respBuilder(resp.Body) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("unexpected response. expected %v, actual %v", expected, actual) + } +} + +func TestNoCrossOver(t *testing.T) { + c, ts := buildCheckWithServer() + defer ts.Close() + + c.AddHealthCheck(mockPass("a")) + c.AddHealthCheck(mockPass("c")) + c.AddReadyCheck(mockPass("b")) + c.AddReadyCheck(mockFail("k")) + c.AddHealthCheck(mockFail("b")) + + resp, err := http.Get(ts.URL + "/health") + if err != nil { + t.Fatal(err) + } + actual, err := respBuilder(resp.Body) + if err != nil { + t.Fatal(err) + } + + expected := &Response{ + Name: "Health", + Status: "fail", + Checks: Responses{ + Response{Name: "b", Status: "fail"}, + Response{Name: "a", Status: "pass"}, + Response{Name: "c", Status: "pass"}, + }, + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("unexpected response. expected %v, actual %v", expected, actual) + } + + resp, err = http.Get(ts.URL + "/ready") + if err != nil { + t.Fatal(err) + } + actual, err = respBuilder(resp.Body) + if err != nil { + t.Fatal(err) + } + + expected = &Response{ + Name: "Ready", + Status: "fail", + Checks: Responses{ + Response{Name: "k", Status: "fail"}, + Response{Name: "b", Status: "pass"}, + }, + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("unexpected response. expected %v, actual %v", expected, actual) + } +} + +func TestPassthrough(t *testing.T) { + c, ts := buildCheckWithServer() + defer ts.Close() + + resp, err := http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != 404 { + t.Fatalf("failed to error when no passthrough is present, status: %d", resp.StatusCode) + } + + used := false + s := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + used = true + w.Write([]byte("hi")) + }) + + c.SetPassthrough(s) + + resp, err = http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != 200 { + t.Fatalf("bad response code from passthrough, status: %d", resp.StatusCode) + } + + if !used { + t.Fatal("passthrough server not used") + } +} + +func ExampleNewCheck() { + // Run the default healthcheck. it always return 200. It is good if you + // have a service without any dependency + h := NewCheck() + h.CheckHealth(context.Background()) +} + +func ExampleCheck_CheckHealth() { + h := NewCheck() + h.AddHealthCheck(Named("google", CheckerFunc(func(ctx context.Context) Response { + var r net.Resolver + _, err := r.LookupHost(ctx, "google.com") + if err != nil { + return Error(err) + } + return Pass() + }))) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + h.CheckHealth(ctx) +} + +func ExampleCheck_ServeHTTP() { + c := NewCheck() + http.ListenAndServe(":6060", c) +} + +func ExampleCheck_SetPassthrough() { + c := NewCheck() + + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello friends!")) + }) + + c.SetPassthrough(http.DefaultServeMux) + http.ListenAndServe(":6060", c) +} diff --git a/kit/check/helpers.go b/kit/check/helpers.go new file mode 100644 index 0000000000..87e9288ef5 --- /dev/null +++ b/kit/check/helpers.go @@ -0,0 +1,73 @@ +package check + +import ( + "context" + "fmt" +) + +// NamedChecker is a superset of Checker that also indicates the name of the service. +// Prefer to implement NamedChecker if your service has a fixed name, +// as opposed to calling *Health.AddNamed. +type NamedChecker interface { + Checker + CheckName() string +} + +// CheckerFunc is an adapter of a plain func() error to the Checker interface. +type CheckerFunc func(ctx context.Context) Response + +// Check implements Checker. +func (f CheckerFunc) Check(ctx context.Context) Response { + return f(ctx) +} + +// Named returns a Checker that will attach a name to the Response from the check. +// This way, it is possible to augment a Response with a human-readable name, but not have to encode +// that logic in the actual check itself. +func Named(name string, checker Checker) Checker { + return CheckerFunc(func(ctx context.Context) Response { + resp := checker.Check(ctx) + resp.Name = name + return resp + }) +} + +// NamedFunc is the same as Named except it takes a CheckerFunc. +func NamedFunc(name string, fn CheckerFunc) Checker { + return Named(name, fn) +} + +// ErrCheck will create a health checker that executes a function. If the function returns an error, +// it will return an unhealthy response. Otherwise, it will be as if the Ok function was called. +// Note: it is better to use CheckFunc, because with Check, the context is ignored. +func ErrCheck(fn func() error) Checker { + return CheckerFunc(func(_ context.Context) Response { + if err := fn(); err != nil { + return Error(err) + } + return Pass() + }) +} + +// Pass is a utility function to generate a passing status response with the default parameters. +func Pass() Response { + return Response{ + Status: StatusPass, + } +} + +// Info is a utility function to generate a healthy status with a printf message. +func Info(msg string, args ...interface{}) Response { + return Response{ + Status: StatusPass, + Message: fmt.Sprintf(msg, args...), + } +} + +// Error is a utility function for creating a response from an error message. +func Error(err error) Response { + return Response{ + Status: StatusFail, + Message: err.Error(), + } +} diff --git a/kit/check/response.go b/kit/check/response.go new file mode 100644 index 0000000000..9eb56f0d4a --- /dev/null +++ b/kit/check/response.go @@ -0,0 +1,39 @@ +package check + +// Response is a result of a collection of health checks. +type Response struct { + Name string `json:"name"` + Status Status `json:"status"` + Message string `json:"message,omitempty"` + Checks Responses `json:"checks,omitempty"` +} + +// HasCheck verifies whether the receiving Response has a check with the given name or not. +func (r *Response) HasCheck(name string) bool { + found := false + for _, check := range r.Checks { + if check.Name == name { + found = true + break + } + } + return found +} + +// Responses is a sortable collection of Response objects. +type Responses []Response + +func (r Responses) Len() int { return len(r) } + +// Less defines the order in which responses are sorted. +// +// Failing responses are always sorted before passing responses. Responses with +// the same status are then sorted according to the name of the check. +func (r Responses) Less(i, j int) bool { + if r[i].Status == r[j].Status { + return r[i].Name < r[j].Name + } + return r[i].Status < r[j].Status +} + +func (r Responses) Swap(i, j int) { r[i], r[j] = r[j], r[i] } diff --git a/kit/errors/errors.go b/kit/errors/errors.go new file mode 100644 index 0000000000..8c12720cd1 --- /dev/null +++ b/kit/errors/errors.go @@ -0,0 +1,107 @@ +package errors + +import ( + "fmt" + "net/http" +) + +// TODO: move to base directory + +const ( + // InternalError indicates an unexpected error condition. + InternalError = 1 + // MalformedData indicates malformed input, such as unparsable JSON. + MalformedData = 2 + // InvalidData indicates that data is well-formed, but invalid. + InvalidData = 3 + // Forbidden indicates a forbidden operation. + Forbidden = 4 + // NotFound indicates a resource was not found. + NotFound = 5 +) + +// Error indicates an error with a reference code and an HTTP status code. +type Error struct { + Reference int `json:"referenceCode"` + Code int `json:"statusCode"` + Err string `json:"err"` +} + +// Error implements the error interface. +func (e Error) Error() string { + return e.Err +} + +// Errorf constructs an Error with the given reference code and format. +func Errorf(ref int, format string, i ...interface{}) error { + return Error{ + Reference: ref, + Err: fmt.Sprintf(format, i...), + } +} + +// New creates a new error with a message and error code. +func New(msg string, ref ...int) error { + refCode := InternalError + if len(ref) == 1 { + refCode = ref[0] + } + return Error{ + Reference: refCode, + Err: msg, + } +} + +func Wrap(err error, msg string, ref ...int) error { + if err == nil { + return nil + } + e, ok := err.(Error) + if ok { + refCode := e.Reference + if len(ref) == 1 { + refCode = ref[0] + } + return Error{ + Reference: refCode, + Code: e.Code, + Err: fmt.Sprintf("%s: %s", msg, e.Err), + } + } + refCode := InternalError + if len(ref) == 1 { + refCode = ref[0] + } + return Error{ + Reference: refCode, + Err: fmt.Sprintf("%s: %s", msg, err.Error()), + } +} + +// InternalErrorf constructs an InternalError with the given format. +func InternalErrorf(format string, i ...interface{}) error { + return Errorf(InternalError, format, i...) +} + +// MalformedDataf constructs a MalformedData error with the given format. +func MalformedDataf(format string, i ...interface{}) error { + return Errorf(MalformedData, format, i...) +} + +// InvalidDataf constructs an InvalidData error with the given format. +func InvalidDataf(format string, i ...interface{}) error { + return Errorf(InvalidData, format, i...) +} + +// Forbiddenf constructs a Forbidden error with the given format. +func Forbiddenf(format string, i ...interface{}) error { + return Errorf(Forbidden, format, i...) +} + +func BadRequestError(msg string) error { + return Error{ + Reference: InvalidData, + Code: http.StatusBadRequest, + Err: msg, + } +} diff --git a/kit/errors/list.go b/kit/errors/list.go new file mode 100644 index 0000000000..c4636135c7 --- /dev/null +++ b/kit/errors/list.go @@ -0,0 +1,59 @@ +package errors + +import ( + "errors" + "strings" +) + +// List represents a list of errors. +type List struct { + errs []error + err error // cached error +} + +// Append adds err to the errors list. +func (l *List) Append(err error) { + l.errs = append(l.errs, err) + l.err = nil +} + +// AppendString adds a new error that formats as the given text. +func (l *List) AppendString(text string) { + l.errs = append(l.errs, errors.New(text)) + l.err = nil +} + +// Clear removes all the previously appended errors from the list. +func (l *List) Clear() { + for i := range l.errs { + l.errs[i] = nil + } + l.errs = l.errs[:0] + l.err = nil +} + +// Err returns an error composed of the list of errors, separated by a new line, or nil if no errors +// were appended. +func (l *List) Err() error { + if len(l.errs) == 0 { + return nil + } + + if l.err != nil { + switch len(l.errs) { + case 1: + l.err = l.errs[0] + + default: + var sb strings.Builder + sb.WriteString(l.errs[0].Error()) + for _, err := range l.errs[1:] { + sb.WriteRune('\n') + sb.WriteString(err.Error()) + } + l.err = errors.New(sb.String()) + } + } + + return l.err +} diff --git a/kit/feature/_codegen/main.go b/kit/feature/_codegen/main.go new file mode 100644 index 0000000000..33aab7b2ab --- /dev/null +++ b/kit/feature/_codegen/main.go @@ -0,0 +1,271 @@ +package main + +import ( + "bytes" + "flag" + "fmt" + "go/format" + "io/ioutil" + "os" + "strings" + "text/template" + + "github.com/Masterminds/sprig" + "github.com/influxdata/influxdb/kit/feature" + yaml "gopkg.in/yaml.v2" +) + +const tmpl = `// Code generated by the feature package; DO NOT EDIT. + +package feature + +{{ .Qualify | import }} + +{{ range $_, $flag := .Flags }} +var {{ $flag.Key }} = {{ $.Qualify | package }}{{ $flag.Default | maker }}( + {{ $flag.Name | quote }}, + {{ $flag.Key | quote }}, + {{ $flag.Contact | quote }}, + {{ $flag.Default | conditionalQuote }}, + {{ $.Qualify | package }}{{ $flag.Lifetime | lifetime }}, + {{ $flag.Expose }}, +) + +// {{ $flag.Name | replace " " "_" | camelcase }} - {{ $flag.Description }} +func {{ $flag.Name | replace " " "_" | camelcase }}() {{ $.Qualify | package }}{{ $flag.Default | flagType }} { + return {{ $flag.Key }} +} +{{ end }} + +var all = []{{ .Qualify | package }}Flag{ +{{ range $_, $flag := .Flags }} {{ $flag.Key }}, +{{ end }}} + +var byKey = map[string]{{ $.Qualify | package }}Flag{ +{{ range $_, $flag := .Flags }} {{ $flag.Key | quote }}: {{ $flag.Key }}, +{{ end }}} +` + +type flagConfig struct { + Name string + Description string + Key string + Default interface{} + Contact string + Lifetime feature.Lifetime + Expose bool +} + +func (f flagConfig) Valid() error { + var problems []string + if f.Key == "" { + problems = append(problems, "missing key") + } + if f.Contact == "" { + problems = append(problems, "missing contact") + } + if f.Default == nil { + problems = append(problems, "missing default") + } + if f.Description == "" { + problems = append(problems, "missing description") + } + + if len(problems) > 0 { + name := f.Name + if name == "" { + if f.Key != "" { + name = f.Key + } else { + name = "anonymous" + } + } + // e.g. "my flag: missing key; missing default" + return fmt.Errorf("%s: %s\n", name, strings.Join(problems, "; ")) + } + return nil +} + +type flagValidationError struct { + errs []error +} + +func newFlagValidationError(errs []error) *flagValidationError { + if len(errs) == 0 { + return nil + } + return &flagValidationError{errs} +} + +func (e *flagValidationError) Error() string { + var s strings.Builder + s.WriteString("flag validation error: \n") + for _, err := range e.errs { + s.WriteString(err.Error()) + } + return s.String() +} + +func validate(flags []flagConfig) error { + var ( + errs []error + seen = make(map[string]bool, len(flags)) + ) + for _, flag := range flags { + if err := flag.Valid(); err != nil { + errs = append(errs, err) + } else if _, repeated := seen[flag.Key]; repeated { + errs = append(errs, fmt.Errorf("duplicate flag key '%s'\n", flag.Key)) + } + seen[flag.Key] = true + } + if len(errs) != 0 { + return newFlagValidationError(errs) + } + + return nil +} + +var argv = struct { + in, out *string + qualify *bool +}{ + in: flag.String("in", "", "flag configuration path"), + out: flag.String("out", "", "flag generation destination path"), + qualify: flag.Bool("qualify", false, "qualify types with imported package name"), +} + +func main() { + if err := run(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + os.Exit(0) +} + +func run() error { + flag.Parse() + + in, err := os.Open(*argv.in) + if err != nil { + return err + } + defer in.Close() + + configuration, err := ioutil.ReadAll(in) + if err != nil { + return err + } + + var flags []flagConfig + err = yaml.Unmarshal(configuration, &flags) + if err != nil { + return err + } + err = validate(flags) + if err != nil { + return err + } + + t, err := template.New("flags").Funcs(templateFunctions()).Parse(tmpl) + if err != nil { + return err + } + + out, err := os.Create(*argv.out) + if err != nil { + return err + } + defer out.Close() + + var ( + buf = new(bytes.Buffer) + vars = struct { + Qualify bool + Flags []flagConfig + }{ + Qualify: *argv.qualify, + Flags: flags, + } + ) + if err := t.Execute(buf, vars); err != nil { + return err + } + + raw, err := ioutil.ReadAll(buf) + if err != nil { + return err + } + + formatted, err := format.Source(raw) + if err != nil { + return err + } + + _, err = out.Write(formatted) + return err +} + +func templateFunctions() template.FuncMap { + functions := sprig.TxtFuncMap() + + functions["lifetime"] = func(t interface{}) string { + switch t { + case feature.Permanent: + return "Permanent" + default: + return "Temporary" + } + } + + functions["conditionalQuote"] = func(t interface{}) string { + switch t.(type) { + case string: + return fmt.Sprintf("%q", t) + default: + return fmt.Sprintf("%v", t) + } + } + + functions["flagType"] = func(t interface{}) string { + switch t.(type) { + case bool: + return "BoolFlag" + case float64: + return "FloatFlag" + case int: + return "IntFlag" + default: + return "StringFlag" + } + } + + functions["maker"] = func(t interface{}) string { + switch t.(type) { + case bool: + return "MakeBoolFlag" + case float64: + return "MakeFloatFlag" + case int: + return "MakeIntFlag" + default: + return "MakeStringFlag" + } + } + + functions["package"] = func(t interface{}) string { + if t.(bool) { + return "feature." + } + return "" + } + + functions["import"] = func(t interface{}) string { + if t.(bool) { + return "import \"github.com/influxdata/influxdb/kit/feature\"" + } + return "" + } + + return functions +} diff --git a/kit/feature/doc.go b/kit/feature/doc.go new file mode 100644 index 0000000000..3ad8f59719 --- /dev/null +++ b/kit/feature/doc.go @@ -0,0 +1,75 @@ +// Package feature provides feature flagging capabilities for InfluxDB servers. +// This document describes this package and how it is used to control +// experimental features in `influxd`. +// +// Flags are configured in `flags.yml` at the top of this repository. +// Running `make flags` generates Go code based on this configuration +// to programmatically test flag values in a given request context. +// Boolean flags are the most common case, but integers, floats and +// strings are supported for more complicated experiments. +// +// The `Flagger` interface is the crux of this package. +// It computes a map of feature flag values for a given request context. +// The default implementation always returns the flag default configured +// in `flags.yml`. The override implementation allows an operator to +// override feature flag defaults at startup. Changing these overrides +// requires a restart. +// +// In `influxd`, a `Flagger` instance is provided to a `Handler` middleware +// configured to intercept all API requests and annotate their request context +// with a map of feature flags. +// +// A flag can opt in to be exposed externally in `flags.yml`. If exposed, +// this flag will be included in the response from the `/api/v2/flags` +// endpoint. This allows the UI and other API clients to control their +// behavior according to the flag in addition to the server itself. +// +// A concrete example to illustrate the above: +// +// I have a feature called "My Feature" that will involve turning on new code +// in both the UI and the server. +// +// First, I add an entry to `flags.yml`. +// +// ```yaml +// - name: My Feature +// description: My feature is awesome +// key: myFeature +// default: false +// expose: true +// contact: My Name +// ``` +// +// My flag type is inferred to be boolean by my default of `false` when I run +// `make flags` and the `feature` package now includes `func MyFeature() BoolFlag`. +// +// I use this to control my backend code with +// +// ```go +// if feature.MyFeature.Enabled(ctx) { +// // new code... +// } else { +// // new code... +// } +// ``` +// +// and the `/api/v2/flags` response provides the same information to the frontend. +// +// ```json +// { +// "myFeature": false +// } +// ``` +// +// While `false` by default, I can turn on my experimental feature by starting +// my server with a flag override. +// +// ``` +// env INFLUXD_FEATURE_FLAGS="{\"flag1\":\value1\",\"key2\":\"value2\"}" influxd +// ``` +// +// ``` +// influxd --feature-flags flag1=value1,flag2=value2 +// ``` +// +package feature diff --git a/kit/feature/feature.go b/kit/feature/feature.go new file mode 100644 index 0000000000..7a84c3c23c --- /dev/null +++ b/kit/feature/feature.go @@ -0,0 +1,140 @@ +package feature + +import ( + "context" + "strings" + + "github.com/opentracing/opentracing-go" +) + +type contextKey string + +const featureContextKey contextKey = "influx/feature/v1" + +// Flagger returns flag values. +type Flagger interface { + // Flags returns a map of flag keys to flag values. + // + // If an authorization is present on the context, it may be used to compute flag + // values according to the affiliated user ID and its organization and other mappings. + // Otherwise, they should be computed generally or return a default. + // + // One or more flags may be provided to restrict the results. + // Otherwise, all flags should be computed. + Flags(context.Context, ...Flag) (map[string]interface{}, error) +} + +// Annotate the context with a map computed of computed flags. +func Annotate(ctx context.Context, f Flagger, flags ...Flag) (context.Context, error) { + computed, err := f.Flags(ctx, flags...) + if err != nil { + return nil, err + } + + span := opentracing.SpanFromContext(ctx) + if span != nil { + for k, v := range computed { + span.LogKV(k, v) + } + } + + return context.WithValue(ctx, featureContextKey, computed), nil +} + +// FlagsFromContext returns the map of flags attached to the context +// by Annotate, or nil if none is found. +func FlagsFromContext(ctx context.Context) map[string]interface{} { + v, ok := ctx.Value(featureContextKey).(map[string]interface{}) + if !ok { + return nil + } + return v +} + +type ByKeyFn func(string) (Flag, bool) + +// ExposedFlagsFromContext returns the filtered map of exposed flags attached +// to the context by Annotate, or nil if none is found. +func ExposedFlagsFromContext(ctx context.Context, byKey ByKeyFn) map[string]interface{} { + m := FlagsFromContext(ctx) + if m == nil { + return nil + } + + filtered := make(map[string]interface{}) + for k, v := range m { + if flag, found := byKey(k); found && flag.Expose() { + filtered[k] = v + } + } + + return filtered +} + +// Lifetime represents the intended lifetime of the feature flag. +// +// The zero value is Temporary, the most common case, but Permanent +// is included to mark special cases where a flag is not intended +// to be removed, e.g. enabling debug tracing for an organization. +// +// TODO(gavincabbage): This may become a stale date, which can then +// be used to trigger a notification to the contact when the flag +// has become stale, to encourage flag cleanup. +type Lifetime int + +const ( + // Temporary indicates a flag is intended to be removed after a feature is no longer in development. + Temporary Lifetime = iota + // Permanent indicates a flag is not intended to be removed. + Permanent +) + +// UnmarshalYAML implements yaml.Unmarshaler and interprets a case-insensitive text +// representation as a lifetime constant. +func (l *Lifetime) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + if err := unmarshal(&s); err != nil { + return err + } + + switch strings.ToLower(s) { + case "permanent": + *l = Permanent + default: + *l = Temporary + } + + return nil +} + +type defaultFlagger struct{} + +// DefaultFlagger returns a flagger that always returns default values. +func DefaultFlagger() Flagger { + return &defaultFlagger{} +} + +// Flags returns a map of default values. It never returns an error. +func (*defaultFlagger) Flags(_ context.Context, flags ...Flag) (map[string]interface{}, error) { + if len(flags) == 0 { + flags = Flags() + } + + m := make(map[string]interface{}, len(flags)) + for _, flag := range flags { + m[flag.Key()] = flag.Default() + } + + return m, nil +} + +// Flags returns all feature flags. +func Flags() []Flag { + return all +} + +// ByKey returns the Flag corresponding to the given key. +func ByKey(k string) (Flag, bool) { + v, found := byKey[k] + return v, found +} diff --git a/kit/feature/feature_test.go b/kit/feature/feature_test.go new file mode 100644 index 0000000000..4f5493d43e --- /dev/null +++ b/kit/feature/feature_test.go @@ -0,0 +1,185 @@ +package feature_test + +import ( + "context" + "testing" + + "github.com/influxdata/influxdb/kit/feature" +) + +func Test_feature(t *testing.T) { + + cases := []struct { + name string + flag feature.Flag + err error + values map[string]interface{} + ctx context.Context + expected interface{} + }{ + { + name: "bool happy path", + flag: newFlag("test", false), + values: map[string]interface{}{ + "test": true, + }, + expected: true, + }, + { + name: "int happy path", + flag: newFlag("test", 0), + values: map[string]interface{}{ + "test": int32(42), + }, + expected: int32(42), + }, + { + name: "float happy path", + flag: newFlag("test", 0.0), + values: map[string]interface{}{ + "test": 42.42, + }, + expected: 42.42, + }, + { + name: "string happy path", + flag: newFlag("test", ""), + values: map[string]interface{}{ + "test": "restaurantattheendoftheuniverse", + }, + expected: "restaurantattheendoftheuniverse", + }, + { + name: "bool missing use default", + flag: newFlag("test", false), + expected: false, + }, + { + name: "bool missing use default true", + flag: newFlag("test", true), + expected: true, + }, + { + name: "int missing use default", + flag: newFlag("test", 65), + expected: int32(65), + }, + { + name: "float missing use default", + flag: newFlag("test", 65.65), + expected: 65.65, + }, + { + name: "string missing use default", + flag: newFlag("test", "mydefault"), + expected: "mydefault", + }, + + { + name: "bool invalid use default", + flag: newFlag("test", true), + values: map[string]interface{}{ + "test": "notabool", + }, + expected: true, + }, + { + name: "int invalid use default", + flag: newFlag("test", 42), + values: map[string]interface{}{ + "test": 99.99, + }, + expected: int32(42), + }, + { + name: "float invalid use default", + flag: newFlag("test", 42.42), + values: map[string]interface{}{ + "test": 99, + }, + expected: 42.42, + }, + { + name: "string invalid use default", + flag: newFlag("test", "restaurantattheendoftheuniverse"), + values: map[string]interface{}{ + "test": true, + }, + expected: "restaurantattheendoftheuniverse", + }, + } + + for _, test := range cases { + t.Run("flagger "+test.name, func(t *testing.T) { + flagger := testFlagsFlagger{ + m: test.values, + err: test.err, + } + + var actual interface{} + switch flag := test.flag.(type) { + case feature.BoolFlag: + actual = flag.Enabled(test.ctx, flagger) + case feature.FloatFlag: + actual = flag.Float(test.ctx, flagger) + case feature.IntFlag: + actual = flag.Int(test.ctx, flagger) + case feature.StringFlag: + actual = flag.String(test.ctx, flagger) + default: + t.Errorf("unknown flag type %T (%#v)", flag, flag) + } + + if actual != test.expected { + t.Errorf("unexpected flag value: got %v, want %v", actual, test.expected) + } + }) + + t.Run("annotate "+test.name, func(t *testing.T) { + flagger := testFlagsFlagger{ + m: test.values, + err: test.err, + } + + ctx, err := feature.Annotate(context.Background(), flagger) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + var actual interface{} + switch flag := test.flag.(type) { + case feature.BoolFlag: + actual = flag.Enabled(ctx) + case feature.FloatFlag: + actual = flag.Float(ctx) + case feature.IntFlag: + actual = flag.Int(ctx) + case feature.StringFlag: + actual = flag.String(ctx) + default: + t.Errorf("unknown flag type %T (%#v)", flag, flag) + } + + if actual != test.expected { + t.Errorf("unexpected flag value: got %v, want %v", actual, test.expected) + } + }) + } +} + +type testFlagsFlagger struct { + m map[string]interface{} + err error +} + +func (f testFlagsFlagger) Flags(ctx context.Context, flags ...feature.Flag) (map[string]interface{}, error) { + if f.err != nil { + return nil, f.err + } + + return f.m, nil +} + +func newFlag(key string, defaultValue interface{}) feature.Flag { + return feature.MakeFlag(key, key, "", defaultValue, feature.Temporary, false) +} diff --git a/kit/feature/flag.go b/kit/feature/flag.go new file mode 100644 index 0000000000..045e8a53a4 --- /dev/null +++ b/kit/feature/flag.go @@ -0,0 +1,216 @@ +//go:generate go run ./_codegen/main.go --in ../../flags.yml --out ./list.go + +package feature + +import ( + "context" + "fmt" +) + +// Flag represents a generic feature flag with a key and a default. +type Flag interface { + // Key returns the programmatic backend identifier for the flag. + Key() string + // Default returns the type-agnostic zero value for the flag. + // Type-specific flag implementations may expose a typed default + // (e.g. BoolFlag includes a boolean Default field). + Default() interface{} + // Expose the flag. + Expose() bool +} + +// MakeFlag constructs a Flag. The concrete implementation is inferred from the provided default. +func MakeFlag(name, key, owner string, defaultValue interface{}, lifetime Lifetime, expose bool) Flag { + b := MakeBase(name, key, owner, defaultValue, lifetime, expose) + switch v := defaultValue.(type) { + case bool: + return BoolFlag{b, v} + case float64: + return FloatFlag{b, v} + case int32: + return IntFlag{b, v} + case int: + return IntFlag{b, int32(v)} + case string: + return StringFlag{b, v} + default: + return StringFlag{b, fmt.Sprintf("%v", v)} + } +} + +// flag base type. +type Base struct { + // name of the flag. + name string + // key is the programmatic backend identifier for the flag. + key string + // defaultValue for the flag. + defaultValue interface{} + // owner is an individual or team responsible for the flag. + owner string + // lifetime of the feature flag. + lifetime Lifetime + // expose the flag. + expose bool +} + +var _ Flag = Base{} + +// MakeBase constructs a flag flag. +func MakeBase(name, key, owner string, defaultValue interface{}, lifetime Lifetime, expose bool) Base { + return Base{ + name: name, + key: key, + owner: owner, + defaultValue: defaultValue, + lifetime: lifetime, + expose: expose, + } +} + +// Key returns the programmatic backend identifier for the flag. +func (f Base) Key() string { + return f.key +} + +// Default returns the type-agnostic zero value for the flag. +func (f Base) Default() interface{} { + return f.defaultValue +} + +// Expose the flag. +func (f Base) Expose() bool { + return f.expose +} + +func (f Base) value(ctx context.Context, flagger ...Flagger) (interface{}, bool) { + var ( + m map[string]interface{} + ok bool + ) + if len(flagger) < 1 { + m, ok = ctx.Value(featureContextKey).(map[string]interface{}) + } else { + var err error + m, err = flagger[0].Flags(ctx, f) + ok = err == nil + } + if !ok { + return nil, false + } + + v, ok := m[f.Key()] + if !ok { + return nil, false + } + + return v, true +} + +// StringFlag implements Flag for string values. +type StringFlag struct { + Base + defaultString string +} + +var _ Flag = StringFlag{} + +// MakeStringFlag returns a string flag with the given Base and default. +func MakeStringFlag(name, key, owner string, defaultValue string, lifetime Lifetime, expose bool) StringFlag { + b := MakeBase(name, key, owner, defaultValue, lifetime, expose) + return StringFlag{b, defaultValue} +} + +// String value of the flag on the request context. +func (f StringFlag) String(ctx context.Context, flagger ...Flagger) string { + i, ok := f.value(ctx, flagger...) + if !ok { + return f.defaultString + } + s, ok := i.(string) + if !ok { + return f.defaultString + } + return s +} + +// FloatFlag implements Flag for float values. +type FloatFlag struct { + Base + defaultFloat float64 +} + +var _ Flag = FloatFlag{} + +// MakeFloatFlag returns a string flag with the given Base and default. +func MakeFloatFlag(name, key, owner string, defaultValue float64, lifetime Lifetime, expose bool) FloatFlag { + b := MakeBase(name, key, owner, defaultValue, lifetime, expose) + return FloatFlag{b, defaultValue} +} + +// Float value of the flag on the request context. +func (f FloatFlag) Float(ctx context.Context, flagger ...Flagger) float64 { + i, ok := f.value(ctx, flagger...) + if !ok { + return f.defaultFloat + } + v, ok := i.(float64) + if !ok { + return f.defaultFloat + } + return v +} + +// IntFlag implements Flag for integer values. +type IntFlag struct { + Base + defaultInt int32 +} + +var _ Flag = IntFlag{} + +// MakeIntFlag returns a string flag with the given Base and default. +func MakeIntFlag(name, key, owner string, defaultValue int32, lifetime Lifetime, expose bool) IntFlag { + b := MakeBase(name, key, owner, defaultValue, lifetime, expose) + return IntFlag{b, defaultValue} +} + +// Int value of the flag on the request context. +func (f IntFlag) Int(ctx context.Context, flagger ...Flagger) int32 { + i, ok := f.value(ctx, flagger...) + if !ok { + return f.defaultInt + } + v, ok := i.(int32) + if !ok { + return f.defaultInt + } + return v +} + +// BoolFlag implements Flag for boolean values. +type BoolFlag struct { + Base + defaultBool bool +} + +var _ Flag = BoolFlag{} + +// MakeBoolFlag returns a string flag with the given Base and default. +func MakeBoolFlag(name, key, owner string, defaultValue bool, lifetime Lifetime, expose bool) BoolFlag { + b := MakeBase(name, key, owner, defaultValue, lifetime, expose) + return BoolFlag{b, defaultValue} +} + +// Enabled indicates whether flag is true or false on the request context. +func (f BoolFlag) Enabled(ctx context.Context, flagger ...Flagger) bool { + i, ok := f.value(ctx, flagger...) + if !ok { + return f.defaultBool + } + v, ok := i.(bool) + if !ok { + return f.defaultBool + } + return v +} diff --git a/kit/feature/http_proxy.go b/kit/feature/http_proxy.go new file mode 100644 index 0000000000..ba87752664 --- /dev/null +++ b/kit/feature/http_proxy.go @@ -0,0 +1,73 @@ +package feature + +import ( + "context" + "net/http" + "net/http/httputil" + "net/url" + + "go.uber.org/zap" +) + +// HTTPProxy is an HTTP proxy that's guided by a feature flag. If the feature flag +// presented to it is enabled, it will perform the proxying behavior. Otherwise +// it will be a no-op. +type HTTPProxy struct { + proxy *httputil.ReverseProxy + logger *zap.Logger + enabler ProxyEnabler +} + +// NewHTTPProxy returns a new Proxy. +func NewHTTPProxy(dest *url.URL, logger *zap.Logger, enabler ProxyEnabler) *HTTPProxy { + return &HTTPProxy{ + proxy: newReverseProxy(dest, enabler.Key()), + logger: logger, + enabler: enabler, + } +} + +// Do performs the proxying. It returns whether or not the request was proxied. +func (p *HTTPProxy) Do(w http.ResponseWriter, r *http.Request) bool { + if p.enabler.Enabled(r.Context()) { + p.proxy.ServeHTTP(w, r) + return true + } + return false +} + +const ( + // headerProxyFlag is the HTTP header for enriching the request and response + // with the feature flag key that precipitated the proxying behavior. + headerProxyFlag = "X-Platform-Proxy-Flag" +) + +// newReverseProxy creates a new single-host reverse proxy. +func newReverseProxy(dest *url.URL, enablerKey string) *httputil.ReverseProxy { + proxy := httputil.NewSingleHostReverseProxy(dest) + + defaultDirector := proxy.Director + proxy.Director = func(r *http.Request) { + defaultDirector(r) + + r.Header.Set(headerProxyFlag, enablerKey) + + // Override r.Host to prevent us sending this request back to ourselves. + // A bug in the stdlib causes this value to be preferred over the + // r.URL.Host (which is set in the default Director) if r.Host isn't + // empty (which it isn't). + // https://github.com/golang/go/issues/28168 + r.Host = dest.Host + } + proxy.ModifyResponse = func(r *http.Response) error { + r.Header.Set(headerProxyFlag, enablerKey) + return nil + } + return proxy +} + +// ProxyEnabler is a boolean feature flag. +type ProxyEnabler interface { + Key() string + Enabled(ctx context.Context, fs ...Flagger) bool +} diff --git a/kit/feature/http_proxy_test.go b/kit/feature/http_proxy_test.go new file mode 100644 index 0000000000..a0fd57cf33 --- /dev/null +++ b/kit/feature/http_proxy_test.go @@ -0,0 +1,137 @@ +package feature + +import ( + "context" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "go.uber.org/zap" + "go.uber.org/zap/zaptest" +) + +const ( + destBody = "hello from destination" + srcBody = "hello from source" + flagKey = "fancy-feature" +) + +func TestHTTPProxy_Proxying(t *testing.T) { + en := enabler{key: flagKey, state: true} + logger := zaptest.NewLogger(t) + resp, err := testHTTPProxy(logger, en) + if err != nil { + t.Error(err) + } + + proxyFlag := resp.Header.Get("X-Platform-Proxy-Flag") + if proxyFlag != flagKey { + t.Error("X-Platform-Proxy-Flag header not populated") + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Error(err) + } + + bodyStr := string(body) + if bodyStr != destBody { + t.Errorf("expected body of destination handler, but got: %q", bodyStr) + } +} + +func TestHTTPProxy_DefaultBehavior(t *testing.T) { + en := enabler{key: flagKey, state: false} + logger := zaptest.NewLogger(t) + resp, err := testHTTPProxy(logger, en) + if err != nil { + t.Error(err) + } + + proxyFlag := resp.Header.Get("X-Platform-Proxy-Flag") + if proxyFlag != "" { + t.Error("X-Platform-Proxy-Flag header populated") + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Error(err) + } + + bodyStr := string(body) + if bodyStr != srcBody { + t.Errorf("expected body of source handler, but got: %q", bodyStr) + } +} + +func TestHTTPProxy_RequestHeader(t *testing.T) { + h := func(w http.ResponseWriter, r *http.Request) { + proxyFlag := r.Header.Get("X-Platform-Proxy-Flag") + if proxyFlag != flagKey { + t.Error("expected X-Proxy-Flag to contain feature flag key") + } + } + + s := httptest.NewServer(http.HandlerFunc(h)) + defer s.Close() + + sURL, err := url.Parse(s.URL) + if err != nil { + t.Error(err) + } + + logger := zaptest.NewLogger(t) + en := enabler{key: flagKey, state: true} + proxy := NewHTTPProxy(sURL, logger, en) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "http://example.com/foo", nil) + srcHandler(proxy)(w, r) +} + +func testHTTPProxy(logger *zap.Logger, enabler ProxyEnabler) (*http.Response, error) { + s := httptest.NewServer(http.HandlerFunc(destHandler)) + defer s.Close() + + sURL, err := url.Parse(s.URL) + if err != nil { + return nil, err + } + + proxy := NewHTTPProxy(sURL, logger, enabler) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "http://example.com/foo", nil) + srcHandler(proxy)(w, r) + + return w.Result(), nil +} + +func destHandler(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, destBody) +} + +func srcHandler(proxy *HTTPProxy) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if proxy.Do(w, r) { + return + } + fmt.Fprint(w, srcBody) + } +} + +type enabler struct { + key string + state bool +} + +func (e enabler) Key() string { + return e.key +} + +func (e enabler) Enabled(context.Context, ...Flagger) bool { + return e.state +} diff --git a/kit/feature/list.go b/kit/feature/list.go new file mode 100644 index 0000000000..088f0b23d4 --- /dev/null +++ b/kit/feature/list.go @@ -0,0 +1,249 @@ +// Code generated by the feature package; DO NOT EDIT. + +package feature + +var appMetrics = MakeBoolFlag( + "App Metrics", + "appMetrics", + "Bucky, Monitoring Team", + false, + Permanent, + true, +) + +// AppMetrics - Send UI Telementry to Tools cluster - should always be false in OSS +func AppMetrics() BoolFlag { + return appMetrics +} + +var backendExample = MakeBoolFlag( + "Backend Example", + "backendExample", + "Gavin Cabbage", + false, + Permanent, + false, +) + +// BackendExample - A permanent backend example boolean flag +func BackendExample() BoolFlag { + return backendExample +} + +var communityTemplates = MakeBoolFlag( + "Community Templates", + "communityTemplates", + "Bucky", + true, + Permanent, + true, +) + +// CommunityTemplates - Replace current template uploading functionality with community driven templates +func CommunityTemplates() BoolFlag { + return communityTemplates +} + +var frontendExample = MakeIntFlag( + "Frontend Example", + "frontendExample", + "Gavin Cabbage", + 42, + Temporary, + true, +) + +// FrontendExample - A temporary frontend example integer flag +func FrontendExample() IntFlag { + return frontendExample +} + +var groupWindowAggregateTranspose = MakeBoolFlag( + "Group Window Aggregate Transpose", + "groupWindowAggregateTranspose", + "Query Team", + false, + Temporary, + false, +) + +// GroupWindowAggregateTranspose - Enables the GroupWindowAggregateTransposeRule for all enabled window aggregates +func GroupWindowAggregateTranspose() BoolFlag { + return groupWindowAggregateTranspose +} + +var newLabels = MakeBoolFlag( + "New Label Package", + "newLabels", + "Alirie Gray", + false, + Temporary, + false, +) + +// NewLabelPackage - Enables the refactored labels api +func NewLabelPackage() BoolFlag { + return newLabels +} + +var memoryOptimizedFill = MakeBoolFlag( + "Memory Optimized Fill", + "memoryOptimizedFill", + "Query Team", + false, + Temporary, + false, +) + +// MemoryOptimizedFill - Enable the memory optimized fill() +func MemoryOptimizedFill() BoolFlag { + return memoryOptimizedFill +} + +var memoryOptimizedSchemaMutation = MakeBoolFlag( + "Memory Optimized Schema Mutation", + "memoryOptimizedSchemaMutation", + "Query Team", + false, + Temporary, + false, +) + +// MemoryOptimizedSchemaMutation - Enable the memory optimized schema mutation functions +func MemoryOptimizedSchemaMutation() BoolFlag { + return memoryOptimizedSchemaMutation +} + +var queryTracing = MakeBoolFlag( + "Query Tracing", + "queryTracing", + "Query Team", + false, + Permanent, + false, +) + +// QueryTracing - Turn on query tracing for queries that are sampled +func QueryTracing() BoolFlag { + return queryTracing +} + +var bandPlotType = MakeBoolFlag( + "Band Plot Type", + "bandPlotType", + "Monitoring Team", + false, + Temporary, + true, +) + +// BandPlotType - Enables the creation of a band plot in Dashboards +func BandPlotType() BoolFlag { + return bandPlotType +} + +var mosaicGraphType = MakeBoolFlag( + "Mosaic Graph Type", + "mosaicGraphType", + "Monitoring Team", + false, + Temporary, + true, +) + +// MosaicGraphType - Enables the creation of a mosaic graph in Dashboards +func MosaicGraphType() BoolFlag { + return mosaicGraphType +} + +var notebooks = MakeBoolFlag( + "Notebooks", + "notebooks", + "Monitoring Team", + false, + Temporary, + true, +) + +// Notebooks - Determine if the notebook feature's route and navbar icon are visible to the user +func Notebooks() BoolFlag { + return notebooks +} + +var injectLatestSuccessTime = MakeBoolFlag( + "Inject Latest Success Time", + "injectLatestSuccessTime", + "Compute Team", + false, + Temporary, + false, +) + +// InjectLatestSuccessTime - Inject the latest successful task run timestamp into a Task query extern when executing. +func InjectLatestSuccessTime() BoolFlag { + return injectLatestSuccessTime +} + +var enforceOrgDashboardLimits = MakeBoolFlag( + "Enforce Organization Dashboard Limits", + "enforceOrgDashboardLimits", + "Compute Team", + false, + Temporary, + false, +) + +// EnforceOrganizationDashboardLimits - Enforces the default limit params for the dashboards api when orgs are set +func EnforceOrganizationDashboardLimits() BoolFlag { + return enforceOrgDashboardLimits +} + +var timeFilterFlags = MakeBoolFlag( + "Time Filter Flags", + "timeFilterFlags", + "Compute Team", + false, + Temporary, + true, +) + +// TimeFilterFlags - Filter task run list based on before and after flags +func TimeFilterFlags() BoolFlag { + return timeFilterFlags +} + +var all = []Flag{ + appMetrics, + backendExample, + communityTemplates, + frontendExample, + groupWindowAggregateTranspose, + newLabels, + memoryOptimizedFill, + memoryOptimizedSchemaMutation, + queryTracing, + bandPlotType, + mosaicGraphType, + notebooks, + injectLatestSuccessTime, + enforceOrgDashboardLimits, + timeFilterFlags, +} + +var byKey = map[string]Flag{ + "appMetrics": appMetrics, + "backendExample": backendExample, + "communityTemplates": communityTemplates, + "frontendExample": frontendExample, + "groupWindowAggregateTranspose": groupWindowAggregateTranspose, + "newLabels": newLabels, + "memoryOptimizedFill": memoryOptimizedFill, + "memoryOptimizedSchemaMutation": memoryOptimizedSchemaMutation, + "queryTracing": queryTracing, + "bandPlotType": bandPlotType, + "mosaicGraphType": mosaicGraphType, + "notebooks": notebooks, + "injectLatestSuccessTime": injectLatestSuccessTime, + "enforceOrgDashboardLimits": enforceOrgDashboardLimits, + "timeFilterFlags": timeFilterFlags, +} diff --git a/kit/feature/middleware.go b/kit/feature/middleware.go new file mode 100644 index 0000000000..20d374c746 --- /dev/null +++ b/kit/feature/middleware.go @@ -0,0 +1,70 @@ +package feature + +import ( + "context" + "encoding/json" + "net/http" + + "go.uber.org/zap" +) + +// Handler is a middleware that annotates the context with a map of computed feature flags. +// To accurately compute identity-scoped flags, this middleware should be executed after any +// authorization middleware has annotated the request context with an authorizer. +type Handler struct { + log *zap.Logger + next http.Handler + flagger Flagger + flags []Flag +} + +// NewHandler returns a configured feature flag middleware that will annotate request context +// with a computed map of the given flags using the provided Flagger. +func NewHandler(log *zap.Logger, flagger Flagger, flags []Flag, next http.Handler) http.Handler { + return &Handler{ + log: log, + next: next, + flagger: flagger, + flags: flags, + } +} + +// ServeHTTP annotates the request context with a map of computed feature flags before +// continuing to serve the request. +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx, err := Annotate(r.Context(), h.flagger, h.flags...) + if err != nil { + h.log.Warn("Unable to annotate context with feature flags", zap.Error(err)) + } else { + r = r.WithContext(ctx) + } + + if h.next != nil { + h.next.ServeHTTP(w, r) + } +} + +// HTTPErrorHandler is an influxdb.HTTPErrorHandler. It's defined here instead +// of referencing the other interface type, because we want to try our best to +// avoid cyclical dependencies when feature package is used throughout the +// codebase. +type HTTPErrorHandler interface { + HandleHTTPError(ctx context.Context, err error, w http.ResponseWriter) +} + +// NewFlagsHandler returns a handler that returns the map of computed feature flags on the request context. +func NewFlagsHandler(errorHandler HTTPErrorHandler, byKey ByKeyFn) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + + var ( + ctx = r.Context() + flags = ExposedFlagsFromContext(ctx, byKey) + ) + if err := json.NewEncoder(w).Encode(flags); err != nil { + errorHandler.HandleHTTPError(ctx, err, w) + } + } + return http.HandlerFunc(fn) +} diff --git a/kit/feature/middleware_test.go b/kit/feature/middleware_test.go new file mode 100644 index 0000000000..286cb1a70a --- /dev/null +++ b/kit/feature/middleware_test.go @@ -0,0 +1,47 @@ +package feature_test + +import ( + "bytes" + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/influxdata/influxdb/kit/feature" + "go.uber.org/zap/zaptest" +) + +func Test_Handler(t *testing.T) { + var ( + w = &httptest.ResponseRecorder{} + r = httptest.NewRequest(http.MethodGet, "http://nowhere.test", new(bytes.Buffer)). + WithContext(context.Background()) + + original = r.Context() + ) + + handler := &checkHandler{t: t, f: func(t *testing.T, r *http.Request) { + if r.Context() == original { + t.Error("expected annotated context") + } + }} + + subject := feature.NewHandler(zaptest.NewLogger(t), feature.DefaultFlagger(), feature.Flags(), handler) + + subject.ServeHTTP(w, r) + + if !handler.called { + t.Error("expected handler to be called") + } +} + +type checkHandler struct { + t *testing.T + f func(t *testing.T, r *http.Request) + called bool +} + +func (h *checkHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) { + h.called = true + h.f(h.t, r) +} diff --git a/kit/feature/override/override.go b/kit/feature/override/override.go new file mode 100644 index 0000000000..6a41c22e92 --- /dev/null +++ b/kit/feature/override/override.go @@ -0,0 +1,83 @@ +package override + +import ( + "context" + "fmt" + "strconv" + "strings" + + "github.com/influxdata/influxdb/kit/feature" +) + +// Flagger can override default flag values. +type Flagger struct { + overrides map[string]string + byKey feature.ByKeyFn +} + +// Make a Flagger that returns defaults with any overrides parsed from the string. +func Make(overrides map[string]string, byKey feature.ByKeyFn) (Flagger, error) { + if byKey == nil { + byKey = feature.ByKey + } + + // Check all provided override keys correspond to an existing Flag. + var missing []string + for k := range overrides { + if _, found := byKey(k); !found { + missing = append(missing, k) + } + } + if len(missing) > 0 { + return Flagger{}, fmt.Errorf("configured overrides for non-existent flags: %s", strings.Join(missing, ",")) + } + + return Flagger{ + overrides: overrides, + byKey: byKey, + }, nil +} + +// Flags returns a map of default values with overrides applied. It never returns an error. +func (f Flagger) Flags(_ context.Context, flags ...feature.Flag) (map[string]interface{}, error) { + if len(flags) == 0 { + flags = feature.Flags() + } + + m := make(map[string]interface{}, len(flags)) + for _, flag := range flags { + if s, overridden := f.overrides[flag.Key()]; overridden { + iface, err := f.coerce(s, flag) + if err != nil { + return nil, err + } + m[flag.Key()] = iface + } else { + m[flag.Key()] = flag.Default() + } + } + + return m, nil +} + +func (f Flagger) coerce(s string, flag feature.Flag) (iface interface{}, err error) { + if base, ok := flag.(feature.Base); ok { + flag, _ = f.byKey(base.Key()) + } + + switch flag.(type) { + case feature.BoolFlag: + iface, err = strconv.ParseBool(s) + case feature.IntFlag: + iface, err = strconv.Atoi(s) + case feature.FloatFlag: + iface, err = strconv.ParseFloat(s, 64) + default: + iface = s + } + + if err != nil { + return nil, fmt.Errorf("coercing string %q based on flag type %T: %v", s, flag, err) + } + return +} diff --git a/kit/feature/override/override_test.go b/kit/feature/override/override_test.go new file mode 100644 index 0000000000..cc5aa0ad86 --- /dev/null +++ b/kit/feature/override/override_test.go @@ -0,0 +1,183 @@ +package override + +import ( + "context" + "testing" + + "github.com/influxdata/influxdb/kit/feature" +) + +func TestFlagger(t *testing.T) { + + cases := []struct { + name string + env map[string]string + defaults []feature.Flag + expected map[string]interface{} + expectMakeErr bool + expectFlagsErr bool + byKey feature.ByKeyFn + }{ + { + name: "enabled happy path filtering", + env: map[string]string{ + "flag1": "new1", + "flag3": "new3", + }, + defaults: []feature.Flag{ + newFlag("flag0", "original0"), + newFlag("flag1", "original1"), + newFlag("flag2", "original2"), + newFlag("flag3", "original3"), + newFlag("flag4", "original4"), + }, + byKey: newByKey(map[string]feature.Flag{ + "flag0": newFlag("flag0", "original0"), + "flag1": newFlag("flag1", "original1"), + "flag2": newFlag("flag2", "original2"), + "flag3": newFlag("flag3", "original3"), + "flag4": newFlag("flag4", "original4"), + }), + expected: map[string]interface{}{ + "flag0": "original0", + "flag1": "new1", + "flag2": "original2", + "flag3": "new3", + "flag4": "original4", + }, + }, + { + name: "enabled happy path types", + env: map[string]string{ + "intflag": "43", + "floatflag": "43.43", + "boolflag": "true", + }, + defaults: []feature.Flag{ + newFlag("intflag", 42), + newFlag("floatflag", 42.42), + newFlag("boolflag", false), + }, + byKey: newByKey(map[string]feature.Flag{ + "intflag": newFlag("intflag", 42), + "floatflag": newFlag("floatflag", 43.43), + "boolflag": newFlag("boolflag", false), + }), + expected: map[string]interface{}{ + "intflag": 43, + "floatflag": 43.43, + "boolflag": true, + }, + }, + { + name: "type coerce error", + env: map[string]string{ + "key": "not_an_int", + }, + defaults: []feature.Flag{ + newFlag("key", 42), + }, + byKey: newByKey(map[string]feature.Flag{ + "key": newFlag("key", 42), + }), + expectFlagsErr: true, + }, + { + name: "typed base flags", + env: map[string]string{ + "flag1": "411", + "flag2": "new2", + "flag3": "true", + }, + defaults: []feature.Flag{ + newBaseFlag("flag0", "original0"), + newBaseFlag("flag1", 41), + newBaseFlag("flag2", "original2"), + newBaseFlag("flag3", false), + newBaseFlag("flag4", "original4"), + }, + byKey: newByKey(map[string]feature.Flag{ + "flag0": newFlag("flag0", "original0"), + "flag1": newFlag("flag1", 41), + "flag2": newFlag("flag2", "original2"), + "flag3": newFlag("flag3", false), + "flag4": newFlag("flag4", "original4"), + }), + expected: map[string]interface{}{ + "flag0": "original0", + "flag1": 411, + "flag2": "new2", + "flag3": true, + "flag4": "original4", + }, + }, + { + name: "override for non-existent flag", + env: map[string]string{ + "dne": "foobar", + }, + defaults: []feature.Flag{ + newBaseFlag("key", "value"), + }, + byKey: newByKey(map[string]feature.Flag{ + "key": newFlag("key", "value"), + }), + expectMakeErr: true, + }, + } + + for _, test := range cases { + t.Run(test.name, func(t *testing.T) { + subject, err := Make(test.env, test.byKey) + if err != nil { + if test.expectMakeErr { + return + } + t.Fatalf("unexpected error making Flagger: %v", err) + } + + computed, err := subject.Flags(context.Background(), test.defaults...) + if err != nil { + if test.expectFlagsErr { + return + } + t.Fatalf("unexpected error calling Flags: %v", err) + } + + if len(computed) != len(test.expected) { + t.Fatalf("incorrect number of flags computed: expected %d, got %d", len(test.expected), len(computed)) + } + + // check for extra or incorrect keys + for k, v := range computed { + if xv, found := test.expected[k]; !found { + t.Errorf("unexpected key %s", k) + } else if v != xv { + t.Errorf("incorrect value for key %s: expected %v [%T], got %v [%T]", k, xv, xv, v, v) + } + } + + // check for missing keys + for k := range test.expected { + if _, found := computed[k]; !found { + t.Errorf("missing expected key %s", k) + } + } + }) + } +} + +func newFlag(key string, defaultValue interface{}) feature.Flag { + return feature.MakeFlag(key, key, "", defaultValue, feature.Temporary, false) +} + +func newBaseFlag(key string, defaultValue interface{}) feature.Base { + return feature.MakeBase(key, key, "", defaultValue, feature.Temporary, false) +} + +func newByKey(m map[string]feature.Flag) feature.ByKeyFn { + return func(k string) (feature.Flag, bool) { + v, found := m[k] + return v, found + } +} diff --git a/kit/io/limited_read_closer.go b/kit/io/limited_read_closer.go new file mode 100644 index 0000000000..71f1ff14f7 --- /dev/null +++ b/kit/io/limited_read_closer.go @@ -0,0 +1,59 @@ +package io + +import ( + "errors" + "io" +) + +var ErrReadLimitExceeded = errors.New("read limit exceeded") + +// LimitedReadCloser wraps an io.ReadCloser in limiting behavior using +// io.LimitedReader. It allows us to obtain the limit error at the time of close +// instead of just when writing. +type LimitedReadCloser struct { + R io.ReadCloser // underlying reader + N int64 // max bytes remaining + err error + closed bool + limitExceeded bool +} + +// NewLimitedReadCloser returns a new LimitedReadCloser. +func NewLimitedReadCloser(r io.ReadCloser, n int64) *LimitedReadCloser { + return &LimitedReadCloser{ + R: r, + N: n, + } +} + +func (l *LimitedReadCloser) Read(p []byte) (n int, err error) { + if l.N <= 0 { + l.limitExceeded = true + return 0, io.EOF + } + if int64(len(p)) > l.N { + p = p[0:l.N] + } + n, err = l.R.Read(p) + l.N -= int64(n) + return +} + +// Close returns an ErrReadLimitExceeded when the wrapped reader exceeds the set +// limit for number of bytes. This is safe to call more than once but not +// concurrently. +func (l *LimitedReadCloser) Close() (err error) { + if l.limitExceeded { + l.err = ErrReadLimitExceeded + } + if l.closed { + // Close has already been called. + return l.err + } + if err := l.R.Close(); err != nil && l.err == nil { + l.err = err + } + // Prevent l.closer.Close from being called again. + l.closed = true + return l.err +} diff --git a/kit/io/limited_read_closer_test.go b/kit/io/limited_read_closer_test.go new file mode 100644 index 0000000000..357da74646 --- /dev/null +++ b/kit/io/limited_read_closer_test.go @@ -0,0 +1,87 @@ +package io + +import ( + "bytes" + "errors" + "io" + "io/ioutil" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLimitedReadCloser_Exceeded(t *testing.T) { + b := &closer{Reader: bytes.NewBufferString("howdy")} + rc := NewLimitedReadCloser(b, 3) + + out, err := ioutil.ReadAll(rc) + require.NoError(t, err) + assert.Equal(t, []byte("how"), out) + assert.Equal(t, ErrReadLimitExceeded, rc.Close()) +} + +func TestLimitedReadCloser_Happy(t *testing.T) { + b := &closer{Reader: bytes.NewBufferString("ho")} + rc := NewLimitedReadCloser(b, 2) + + out, err := ioutil.ReadAll(rc) + require.NoError(t, err) + assert.Equal(t, []byte("ho"), out) + assert.Nil(t, err) +} + +func TestLimitedReadCloseWithErrorAndLimitExceeded(t *testing.T) { + b := &closer{ + Reader: bytes.NewBufferString("howdy"), + err: errors.New("some error"), + } + rc := NewLimitedReadCloser(b, 3) + + out, err := ioutil.ReadAll(rc) + require.NoError(t, err) + assert.Equal(t, []byte("how"), out) + // LimitExceeded error trumps the close error. + assert.Equal(t, ErrReadLimitExceeded, rc.Close()) +} + +func TestLimitedReadCloseWithError(t *testing.T) { + closeErr := errors.New("some error") + b := &closer{ + Reader: bytes.NewBufferString("howdy"), + err: closeErr, + } + rc := NewLimitedReadCloser(b, 10) + + out, err := ioutil.ReadAll(rc) + require.NoError(t, err) + assert.Equal(t, []byte("howdy"), out) + assert.Equal(t, closeErr, rc.Close()) +} + +func TestMultipleCloseOnlyClosesOnce(t *testing.T) { + closeErr := errors.New("some error") + b := &closer{ + Reader: bytes.NewBufferString("howdy"), + err: closeErr, + } + rc := NewLimitedReadCloser(b, 10) + + out, err := ioutil.ReadAll(rc) + require.NoError(t, err) + assert.Equal(t, []byte("howdy"), out) + assert.Equal(t, closeErr, rc.Close()) + assert.Equal(t, closeErr, rc.Close()) + assert.Equal(t, 1, b.closeCount) +} + +type closer struct { + io.Reader + err error + closeCount int +} + +func (c *closer) Close() error { + c.closeCount++ + return c.err +} diff --git a/kit/metric/client.go b/kit/metric/client.go new file mode 100644 index 0000000000..7b57094ef2 --- /dev/null +++ b/kit/metric/client.go @@ -0,0 +1,152 @@ +package metric + +import ( + "time" + + "github.com/influxdata/influxdb/kit/platform/errors" + + "github.com/prometheus/client_golang/prometheus" +) + +// REDClient is a metrics client for collection RED metrics. +type REDClient struct { + metrics []metricCollector +} + +// New creates a new REDClient. +func New(reg prometheus.Registerer, service string, opts ...ClientOptFn) *REDClient { + opt := metricOpts{ + namespace: "service", + service: service, + counterMetrics: map[string]VecOpts{ + "call_total": { + Help: "Number of calls", + LabelNames: []string{"method"}, + CounterFn: func(vec *prometheus.CounterVec, o CollectFnOpts) { + vec.With(prometheus.Labels{"method": o.Method}).Inc() + }, + }, + "error_total": { + Help: "Number of errors encountered", + LabelNames: []string{"method", "code"}, + CounterFn: func(vec *prometheus.CounterVec, o CollectFnOpts) { + if o.Err != nil { + vec.With(prometheus.Labels{ + "method": o.Method, + "code": errors.ErrorCode(o.Err), + }).Inc() + } + }, + }, + }, + histogramMetrics: map[string]VecOpts{ + "duration": { + Help: "Duration of calls", + LabelNames: []string{"method"}, + HistogramFn: func(vec *prometheus.HistogramVec, o CollectFnOpts) { + vec. + With(prometheus.Labels{"method": o.Method}). + Observe(time.Since(o.Start).Seconds()) + }, + }, + }, + } + for _, o := range opts { + o(&opt) + } + + client := new(REDClient) + for metricName, vecOpts := range opt.counterMetrics { + client.metrics = append(client.metrics, &counter{ + fn: vecOpts.CounterFn, + CounterVec: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: opt.namespace, + Subsystem: opt.serviceName(), + Name: metricName, + Help: vecOpts.Help, + }, vecOpts.LabelNames), + }) + } + + for metricName, vecOpts := range opt.histogramMetrics { + client.metrics = append(client.metrics, &histogram{ + fn: vecOpts.HistogramFn, + HistogramVec: prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: opt.namespace, + Subsystem: opt.serviceName(), + Name: metricName, + Help: vecOpts.Help, + }, vecOpts.LabelNames), + }) + } + + reg.MustRegister(client.collectors()...) + + return client +} + +type RecordFn func(err error, opts ...func(opts *CollectFnOpts)) error + +// RecordAdditional provides an extension to the base method, err data provided +// to the metrics. +func RecordAdditional(props map[string]interface{}) func(opts *CollectFnOpts) { + return func(opts *CollectFnOpts) { + opts.AdditionalProps = props + } +} + +// Record returns a record fn that is called on any given return err. If an error is encountered +// it will register the err metric. The err is never altered. +func (c *REDClient) Record(method string) RecordFn { + start := time.Now() + return func(err error, opts ...func(opts *CollectFnOpts)) error { + opt := CollectFnOpts{ + Method: method, + Start: start, + Err: err, + } + for _, o := range opts { + o(&opt) + } + + for _, metric := range c.metrics { + metric.collect(opt) + } + + return err + } +} + +func (c *REDClient) collectors() []prometheus.Collector { + var collectors []prometheus.Collector + for _, metric := range c.metrics { + collectors = append(collectors, metric) + } + return collectors +} + +type metricCollector interface { + prometheus.Collector + + collect(o CollectFnOpts) +} + +type counter struct { + *prometheus.CounterVec + + fn CounterFn +} + +func (c *counter) collect(o CollectFnOpts) { + c.fn(c.CounterVec, o) +} + +type histogram struct { + *prometheus.HistogramVec + + fn HistogramFn +} + +func (h *histogram) collect(o CollectFnOpts) { + h.fn(h.HistogramVec, o) +} diff --git a/kit/metric/metrics_options.go b/kit/metric/metrics_options.go new file mode 100644 index 0000000000..b9bbd7d883 --- /dev/null +++ b/kit/metric/metrics_options.go @@ -0,0 +1,84 @@ +package metric + +import ( + "fmt" + "time" + + "github.com/prometheus/client_golang/prometheus" +) + +type ( + // CollectFnOpts provides arguments to the collect operation of a metric. + CollectFnOpts struct { + Method string + Start time.Time + Err error + AdditionalProps map[string]interface{} + } + + CounterFn func(vec *prometheus.CounterVec, o CollectFnOpts) + + HistogramFn func(vec *prometheus.HistogramVec, o CollectFnOpts) + + // VecOpts expands on the + VecOpts struct { + Name string + Help string + LabelNames []string + + CounterFn CounterFn + HistogramFn HistogramFn + } +) + +type metricOpts struct { + namespace string + service string + serviceSuffix string + counterMetrics map[string]VecOpts + histogramMetrics map[string]VecOpts +} + +func (o metricOpts) serviceName() string { + if o.serviceSuffix != "" { + return fmt.Sprintf("%s_%s", o.service, o.serviceSuffix) + } + return o.service +} + +// ClientOptFn is an option used by a metric middleware. +type ClientOptFn func(*metricOpts) + +// WithVec sets a new counter vector to be collected. +func WithVec(opts VecOpts) ClientOptFn { + return func(o *metricOpts) { + if opts.CounterFn != nil { + if o.counterMetrics == nil { + o.counterMetrics = make(map[string]VecOpts) + } + o.counterMetrics[opts.Name] = opts + } + } +} + +// WithSuffix returns a metric option that applies a suffix to the service name of the metric. +func WithSuffix(suffix string) ClientOptFn { + return func(opts *metricOpts) { + opts.serviceSuffix = suffix + } +} + +func ApplyMetricOpts(opts ...ClientOptFn) *metricOpts { + o := metricOpts{} + for _, opt := range opts { + opt(&o) + } + return &o +} + +func (o *metricOpts) ApplySuffix(prefix string) string { + if o.serviceSuffix != "" { + return fmt.Sprintf("%s_%s", prefix, o.serviceSuffix) + } + return prefix +} diff --git a/kit/platform/errors/error.go b/kit/platform/errors/error.go new file mode 100644 index 0000000000..fb2f5eec7a --- /dev/null +++ b/kit/platform/errors/error.go @@ -0,0 +1,9 @@ +package errors + +// ChronografError is a domain error encountered while processing chronograf requests. +type ChronografError string + +// ChronografError returns the string of an error. +func (e ChronografError) Error() string { + return string(e) +} diff --git a/kit/platform/errors/errors.go b/kit/platform/errors/errors.go new file mode 100644 index 0000000000..8a3a1a3d98 --- /dev/null +++ b/kit/platform/errors/errors.go @@ -0,0 +1,264 @@ +package errors + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" +) + +// Some error code constant, ideally we want define common platform codes here +// projects on use platform's error, should have their own central place like this. +// Any time this set of constants changes, you must also update the swagger for Error.properties.code.enum. +const ( + EInternal = "internal error" + ENotImplemented = "not implemented" + ENotFound = "not found" + EConflict = "conflict" // action cannot be performed + EInvalid = "invalid" // validation failed + EUnprocessableEntity = "unprocessable entity" // data type is correct, but out of range + EEmptyValue = "empty value" + EUnavailable = "unavailable" + EForbidden = "forbidden" + ETooManyRequests = "too many requests" + EUnauthorized = "unauthorized" + EMethodNotAllowed = "method not allowed" + ETooLarge = "request too large" +) + +// Error is the error struct of platform. +// +// Errors may have error codes, human-readable messages, +// and a logical stack trace. +// +// The Code targets automated handlers so that recovery can occur. +// Msg is used by the system operator to help diagnose and fix the problem. +// Op and Err chain errors together in a logical stack trace to +// further help operators. +// +// To create a simple error, +// &Error{ +// Code:ENotFound, +// } +// To show where the error happens, add Op. +// &Error{ +// Code: ENotFound, +// Op: "bolt.FindUserByID" +// } +// To show an error with a unpredictable value, add the value in Msg. +// &Error{ +// Code: EConflict, +// Message: fmt.Sprintf("organization with name %s already exist", aName), +// } +// To show an error wrapped with another error. +// &Error{ +// Code:EInternal, +// Err: err, +// }. +type Error struct { + Code string + Msg string + Op string + Err error +} + +// NewError returns an instance of an error. +func NewError(options ...func(*Error)) *Error { + err := &Error{} + for _, o := range options { + o(err) + } + + return err +} + +// WithErrorErr sets the err on the error. +func WithErrorErr(err error) func(*Error) { + return func(e *Error) { + e.Err = err + } +} + +// WithErrorCode sets the code on the error. +func WithErrorCode(code string) func(*Error) { + return func(e *Error) { + e.Code = code + } +} + +// WithErrorMsg sets the message on the error. +func WithErrorMsg(msg string) func(*Error) { + return func(e *Error) { + e.Msg = msg + } +} + +// WithErrorOp sets the message on the error. +func WithErrorOp(op string) func(*Error) { + return func(e *Error) { + e.Op = op + } +} + +// Error implements the error interface by writing out the recursive messages. +func (e *Error) Error() string { + if e.Msg != "" && e.Err != nil { + var b strings.Builder + b.WriteString(e.Msg) + b.WriteString(": ") + b.WriteString(e.Err.Error()) + return b.String() + } else if e.Msg != "" { + return e.Msg + } else if e.Err != nil { + return e.Err.Error() + } + return fmt.Sprintf("<%s>", e.Code) +} + +// ErrorCode returns the code of the root error, if available; otherwise returns EINTERNAL. +func ErrorCode(err error) string { + if err == nil { + return "" + } + + e, ok := err.(*Error) + if !ok { + return EInternal + } + + if e == nil { + return "" + } + + if e.Code != "" { + return e.Code + } + + if e.Err != nil { + return ErrorCode(e.Err) + } + + return EInternal +} + +// ErrorOp returns the op of the error, if available; otherwise return empty string. +func ErrorOp(err error) string { + if err == nil { + return "" + } + + e, ok := err.(*Error) + if !ok { + return "" + } + + if e == nil { + return "" + } + + if e.Op != "" { + return e.Op + } + + if e.Err != nil { + return ErrorOp(e.Err) + } + + return "" +} + +// ErrorMessage returns the human-readable message of the error, if available. +// Otherwise returns a generic error message. +func ErrorMessage(err error) string { + if err == nil { + return "" + } + + e, ok := err.(*Error) + if !ok { + return "An internal error has occurred." + } + + if e == nil { + return "" + } + + if e.Msg != "" { + return e.Msg + } + + if e.Err != nil { + return ErrorMessage(e.Err) + } + + return "An internal error has occurred." +} + +// errEncode an JSON encoding helper that is needed to handle the recursive stack of errors. +type errEncode struct { + Code string `json:"code"` // Code is the machine-readable error code. + Msg string `json:"message,omitempty"` // Msg is a human-readable message. + Op string `json:"op,omitempty"` // Op describes the logical code operation during error. + Err interface{} `json:"error,omitempty"` // Err is a stack of additional errors. +} + +// MarshalJSON recursively marshals the stack of Err. +func (e *Error) MarshalJSON() (result []byte, err error) { + ee := errEncode{ + Code: e.Code, + Msg: e.Msg, + Op: e.Op, + } + if e.Err != nil { + if _, ok := e.Err.(*Error); ok { + _, err := e.Err.(*Error).MarshalJSON() + if err != nil { + return result, err + } + ee.Err = e.Err + } else { + ee.Err = e.Err.Error() + } + } + return json.Marshal(ee) +} + +// UnmarshalJSON recursively unmarshals the error stack. +func (e *Error) UnmarshalJSON(b []byte) (err error) { + ee := new(errEncode) + err = json.Unmarshal(b, ee) + e.Code = ee.Code + e.Msg = ee.Msg + e.Op = ee.Op + e.Err = decodeInternalError(ee.Err) + return err +} + +func decodeInternalError(target interface{}) error { + if errStr, ok := target.(string); ok { + return errors.New(errStr) + } + if internalErrMap, ok := target.(map[string]interface{}); ok { + internalErr := new(Error) + if code, ok := internalErrMap["code"].(string); ok { + internalErr.Code = code + } + if msg, ok := internalErrMap["message"].(string); ok { + internalErr.Msg = msg + } + if op, ok := internalErrMap["op"].(string); ok { + internalErr.Op = op + } + internalErr.Err = decodeInternalError(internalErrMap["error"]) + return internalErr + } + return nil +} + +// HTTPErrorHandler is the interface to handle http error. +type HTTPErrorHandler interface { + HandleHTTPError(ctx context.Context, err error, w http.ResponseWriter) +} diff --git a/kit/platform/errors/errors.md b/kit/platform/errors/errors.md new file mode 100644 index 0000000000..73c2bd06c1 --- /dev/null +++ b/kit/platform/errors/errors.md @@ -0,0 +1,86 @@ +# errors.go + +This is inspired from Ben Johnson's blog post [Failure is Your Domain](https://middlemost.com/failure-is-your-domain/) + +## The Error struct + +```go + + type Error struct { + Code string + Msg string + Op string + Err error + } +``` + + * Code is the machine readable code, for reference purpose. All the codes should be a constant string. For example. `const ENotFound = "source not found"`. + + * Msg is the human readable message for end user. For example, `Your credit card is declined.` + + * Op is the logical Operator, should be a constant defined inside the function. For example: "bolt.UserCreate". + + * Err is the embed error. You may embed either a third party error or and platform.Error. + +## Use Case Example + +We implement the following interface + +```go + + type OrganizationService interface { + FindOrganizationByID(ctx context.Context, id ID) (*Organization, error) + } + + func (c *Client)FindOrganizationByID(ctx context.Context, id platform.ID) (*platform.Organization, error) { + var o *platform.Organization + const op = "bolt.FindOrganizationByID" + err := c.db.View(func(tx *bolt.Tx) error { + org, err := c.findOrganizationByID(ctx, tx, id) + if err != nil { + return err + } + o = org + return nil + }) + + if err != nil { + return nil, &platform.Error{ + Code: platform.ENotFound, + Op: op, + Err: err, + } + } + return o, nil + } +``` + +To check the error code + +```go + + if platform.ErrorCode(err) == platform.ENotFound { + ... + } +``` + +To serialize the error + +```go + + b, err := json.Marshal(err) +``` + +To deserialize the error + +```go + + e := new(platform.Error) + err := json.Unmarshal(b, e) +``` + + + + + + diff --git a/kit/platform/errors/errors_test.go b/kit/platform/errors/errors_test.go new file mode 100644 index 0000000000..8122b84235 --- /dev/null +++ b/kit/platform/errors/errors_test.go @@ -0,0 +1,300 @@ +package errors_test + +import ( + "encoding/json" + "errors" + "fmt" + "testing" + + errors2 "github.com/influxdata/influxdb/kit/platform/errors" +) + +const EFailedToGetStorageHost = "failed to get the storage host" + +func TestErrorMsg(t *testing.T) { + cases := []struct { + name string + err error + msg string + }{ + { + name: "simple error", + err: &errors2.Error{Code: errors2.ENotFound}, + msg: "", + }, + { + name: "with message", + err: &errors2.Error{ + Code: errors2.ENotFound, + Msg: "bucket not found", + }, + msg: "bucket not found", + }, + { + name: "with a third party error and no message", + err: &errors2.Error{ + Code: EFailedToGetStorageHost, + Err: errors.New("empty value"), + }, + msg: "empty value", + }, + { + name: "with a third party error and a message", + err: &errors2.Error{ + Code: EFailedToGetStorageHost, + Msg: "failed to get storage hosts", + Err: errors.New("empty value"), + }, + msg: "failed to get storage hosts: empty value", + }, + { + name: "with an internal error and no message", + err: &errors2.Error{ + Code: EFailedToGetStorageHost, + Err: &errors2.Error{ + Code: errors2.EEmptyValue, + Msg: "empty value", + }, + }, + msg: "empty value", + }, + { + name: "with an internal error and a message", + err: &errors2.Error{ + Code: EFailedToGetStorageHost, + Msg: "failed to get storage hosts", + Err: &errors2.Error{ + Code: errors2.EEmptyValue, + Msg: "empty value", + }, + }, + msg: "failed to get storage hosts: empty value", + }, + } + for _, c := range cases { + if c.msg != c.err.Error() { + t.Errorf("%s failed, want %s, got %s", c.name, c.msg, c.err.Error()) + } + } +} + +func TestErrorMessage(t *testing.T) { + cases := []struct { + name string + err error + want string + }{ + { + name: "nil error", + }, + { + name: "nil error of type *platform.Error", + err: (*errors2.Error)(nil), + }, + { + name: "simple error", + err: &errors2.Error{Msg: "simple error"}, + want: "simple error", + }, + { + name: "embedded error", + err: &errors2.Error{Err: &errors2.Error{Msg: "embedded error"}}, + want: "embedded error", + }, + { + name: "default error", + err: errors.New("s"), + want: "An internal error has occurred.", + }, + } + for _, c := range cases { + if result := errors2.ErrorMessage(c.err); c.want != result { + t.Errorf("%s failed, want %s, got %s", c.name, c.want, result) + } + } +} + +func TestErrorOp(t *testing.T) { + cases := []struct { + name string + err error + want string + }{ + { + name: "nil error", + }, + { + name: "nil error of type *platform.Error", + err: (*errors2.Error)(nil), + }, + { + name: "simple error", + err: &errors2.Error{Op: "op1"}, + want: "op1", + }, + { + name: "embedded error", + err: &errors2.Error{Op: "op1", Err: &errors2.Error{Code: errors2.EInvalid}}, + want: "op1", + }, + { + name: "embedded error without op in root level", + err: &errors2.Error{Err: &errors2.Error{Code: errors2.EInvalid, Op: "op2"}}, + want: "op2", + }, + { + name: "default error", + err: errors.New("s"), + want: "", + }, + } + for _, c := range cases { + if result := errors2.ErrorOp(c.err); c.want != result { + t.Errorf("%s failed, want %s, got %s", c.name, c.want, result) + } + } +} +func TestErrorCode(t *testing.T) { + cases := []struct { + name string + err error + want string + }{ + { + name: "nil error", + }, + { + name: "nil error of type *platform.Error", + err: (*errors2.Error)(nil), + }, + { + name: "simple error", + err: &errors2.Error{Code: errors2.ENotFound}, + want: errors2.ENotFound, + }, + { + name: "embedded error", + err: &errors2.Error{Code: errors2.ENotFound, Err: &errors2.Error{Code: errors2.EInvalid}}, + want: errors2.ENotFound, + }, + { + name: "embedded error with root level code", + err: &errors2.Error{Err: &errors2.Error{Code: errors2.EInvalid}}, + want: errors2.EInvalid, + }, + { + name: "default error", + err: errors.New("s"), + want: errors2.EInternal, + }, + } + for _, c := range cases { + if result := errors2.ErrorCode(c.err); c.want != result { + t.Errorf("%s failed, want %s, got %s", c.name, c.want, result) + } + } +} + +func TestJSON(t *testing.T) { + cases := []struct { + name string + err *errors2.Error + encoded string + }{ + { + name: "simple error", + err: &errors2.Error{Code: errors2.ENotFound}, + encoded: `{"code":"not found"}`, + }, + { + name: "with op", + err: &errors2.Error{ + Code: errors2.ENotFound, + Op: "bolt.FindAuthorizationByID", + }, + encoded: `{"code":"not found","op":"bolt.FindAuthorizationByID"}`, + }, + { + name: "with op and value", + err: &errors2.Error{ + Code: errors2.ENotFound, + Op: "bolt/FindAuthorizationByID", + Msg: fmt.Sprintf("with ID %d", 323), + }, + encoded: `{"code":"not found","message":"with ID 323","op":"bolt/FindAuthorizationByID"}`, + }, + { + name: "with a third party error", + err: &errors2.Error{ + Code: EFailedToGetStorageHost, + Op: "cmd/fluxd.injectDeps", + Err: errors.New("empty value"), + }, + encoded: `{"code":"failed to get the storage host","op":"cmd/fluxd.injectDeps","error":"empty value"}`, + }, + { + name: "with a internal error", + err: &errors2.Error{ + Code: EFailedToGetStorageHost, + Op: "cmd/fluxd.injectDeps", + Err: &errors2.Error{Code: errors2.EEmptyValue, Op: "cmd/fluxd.getStrList"}, + }, + encoded: `{"code":"failed to get the storage host","op":"cmd/fluxd.injectDeps","error":{"code":"empty value","op":"cmd/fluxd.getStrList"}}`, + }, + { + name: "with a deep internal error", + err: &errors2.Error{ + Code: EFailedToGetStorageHost, + Op: "cmd/fluxd.injectDeps", + Err: &errors2.Error{ + Code: errors2.EInvalid, + Op: "cmd/fluxd.getStrList", + Err: &errors2.Error{ + Code: errors2.EEmptyValue, + Err: errors.New("an err"), + }, + }, + }, + encoded: `{"code":"failed to get the storage host","op":"cmd/fluxd.injectDeps","error":{"code":"invalid","op":"cmd/fluxd.getStrList","error":{"code":"empty value","error":"an err"}}}`, + }, + } + for _, c := range cases { + result, err := json.Marshal(c.err) + // encode testing + if err != nil { + t.Errorf("%s encode failed, want err: %v, should be nil", c.name, err) + } + + if string(result) != c.encoded { + t.Errorf("%s encode failed, want result: %s, got %s", c.name, c.encoded, string(result)) + } + // decode testing + got := new(errors2.Error) + err = json.Unmarshal(result, got) + if err != nil { + t.Errorf("%s decode failed, want err: %v, should be nil", c.name, err) + } + decodeEqual(t, c.err, got, "decode: "+c.name) + } +} + +func decodeEqual(t *testing.T, want, result *errors2.Error, caseName string) { + if want.Code != result.Code { + t.Errorf("%s code failed, want %s, got %s", caseName, want.Code, result.Code) + } + if want.Op != result.Op { + t.Errorf("%s op failed, want %s, got %s", caseName, want.Op, result.Op) + } + if want.Msg != result.Msg { + t.Errorf("%s msg failed, want %s, got %s", caseName, want.Msg, result.Msg) + } + if want.Err != nil { + if _, ok := want.Err.(*errors2.Error); ok { + decodeEqual(t, want.Err.(*errors2.Error), result.Err.(*errors2.Error), caseName) + } else { + if want.Err.Error() != result.Err.Error() { + t.Errorf("%s Err failed, want %s, got %s", caseName, want.Err.Error(), result.Err.Error()) + } + } + } +} diff --git a/kit/platform/id.go b/kit/platform/id.go new file mode 100644 index 0000000000..bb4b63e11b --- /dev/null +++ b/kit/platform/id.go @@ -0,0 +1,145 @@ +package platform + +import ( + "encoding/binary" + "encoding/hex" + "strconv" + "unsafe" + + "github.com/influxdata/influxdb/kit/platform/errors" +) + +// IDLength is the exact length a string (or a byte slice representing it) must have in order to be decoded into a valid ID. +const IDLength = 16 + +var ( + // ErrInvalidID signifies invalid IDs. + ErrInvalidID = &errors.Error{ + Code: errors.EInvalid, + Msg: "invalid ID", + } + + // ErrInvalidIDLength is returned when an ID has the incorrect number of bytes. + ErrInvalidIDLength = &errors.Error{ + Code: errors.EInvalid, + Msg: "id must have a length of 16 bytes", + } +) + +// ErrCorruptID means the ID stored in the Store is corrupt. +func ErrCorruptID(err error) *errors.Error { + return &errors.Error{ + Code: errors.EInvalid, + Msg: "corrupt ID provided", + Err: err, + } +} + +// ID is a unique identifier. +// +// Its zero value is not a valid ID. +type ID uint64 + +// IDGenerator represents a generator for IDs. +type IDGenerator interface { + // ID creates unique byte slice ID. + ID() ID +} + +// IDFromString creates an ID from a given string. +// +// It errors if the input string does not match a valid ID. +func IDFromString(str string) (*ID, error) { + var id ID + err := id.DecodeFromString(str) + if err != nil { + return nil, err + } + return &id, nil +} + +// InvalidID returns a zero ID. +func InvalidID() ID { + return 0 +} + +// Decode parses b as a hex-encoded byte-slice-string. +// +// It errors if the input byte slice does not have the correct length +// or if it contains all zeros. +func (i *ID) Decode(b []byte) error { + if len(b) != IDLength { + return ErrInvalidIDLength + } + + res, err := strconv.ParseUint(unsafeBytesToString(b), 16, 64) + if err != nil { + return ErrInvalidID + } + + if *i = ID(res); !i.Valid() { + return ErrInvalidID + } + return nil +} + +func unsafeBytesToString(in []byte) string { + return *(*string)(unsafe.Pointer(&in)) +} + +// DecodeFromString parses s as a hex-encoded string. +func (i *ID) DecodeFromString(s string) error { + return i.Decode([]byte(s)) +} + +// Encode converts ID to a hex-encoded byte-slice-string. +// +// It errors if the receiving ID holds its zero value. +func (i ID) Encode() ([]byte, error) { + if !i.Valid() { + return nil, ErrInvalidID + } + + b := make([]byte, hex.DecodedLen(IDLength)) + binary.BigEndian.PutUint64(b, uint64(i)) + + dst := make([]byte, hex.EncodedLen(len(b))) + hex.Encode(dst, b) + return dst, nil +} + +// Valid checks whether the receiving ID is a valid one or not. +func (i ID) Valid() bool { + return i != 0 +} + +// String returns the ID as a hex encoded string. +// +// Returns an empty string in the case the ID is invalid. +func (i ID) String() string { + enc, _ := i.Encode() + return string(enc) +} + +// GoString formats the ID the same as the String method. +// Without this, when using the %#v verb, an ID would be printed as a uint64, +// so you would see e.g. 0x2def021097c6000 instead of 02def021097c6000 +// (note the leading 0x, which means the former doesn't show up in searches for the latter). +func (i ID) GoString() string { + return `"` + i.String() + `"` +} + +// MarshalText encodes i as text. +// Providing this method is a fallback for json.Marshal, +// with the added benefit that IDs encoded as map keys will be the expected string encoding, +// rather than the effective fmt.Sprintf("%d", i) that json.Marshal uses by default for integer types. +func (i ID) MarshalText() ([]byte, error) { + return i.Encode() +} + +// UnmarshalText decodes i from a byte slice. +// Providing this method is also a fallback for json.Unmarshal, +// also relevant when IDs are used as map keys. +func (i *ID) UnmarshalText(b []byte) error { + return i.Decode(b) +} diff --git a/kit/platform/id_test.go b/kit/platform/id_test.go new file mode 100644 index 0000000000..045219d2b7 --- /dev/null +++ b/kit/platform/id_test.go @@ -0,0 +1,259 @@ +package platform_test + +import ( + "bytes" + "encoding/json" + "fmt" + "reflect" + "testing" + + "github.com/influxdata/influxdb/kit/platform" +) + +func MustIDBase16(s string) platform.ID { + id, err := platform.IDFromString(s) + if err != nil { + panic(err) + } + return *id +} + +func TestIDFromString(t *testing.T) { + tests := []struct { + name string + id string + want platform.ID + wantErr bool + err string + }{ + { + name: "Should be able to decode an all zeros ID", + id: "0000000000000000", + wantErr: true, + err: platform.ErrInvalidID.Error(), + }, + { + name: "Should be able to decode an all f ID", + id: "ffffffffffffffff", + want: MustIDBase16("ffffffffffffffff"), + }, + { + name: "Should be able to decode an ID", + id: "020f755c3c082000", + want: MustIDBase16("020f755c3c082000"), + }, + { + name: "Should not be able to decode a non hex ID", + id: "gggggggggggggggg", + wantErr: true, + err: platform.ErrInvalidID.Error(), + }, + { + name: "Should not be able to decode inputs with length less than 16 bytes", + id: "abc", + wantErr: true, + err: platform.ErrInvalidIDLength.Error(), + }, + { + name: "Should not be able to decode inputs with length greater than 16 bytes", + id: "abcdabcdabcdabcd0", + wantErr: true, + err: platform.ErrInvalidIDLength.Error(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := platform.IDFromString(tt.id) + + // Check negative test cases + if (err != nil) && tt.wantErr { + if tt.err != err.Error() { + t.Errorf("IDFromString() errors out \"%s\", want \"%s\"", err, tt.err) + } + return + } + + // Check positive test cases + if !reflect.DeepEqual(*got, tt.want) && !tt.wantErr { + t.Errorf("IDFromString() outputs %v, want %v", got, tt.want) + } + }) + } +} + +func TestDecodeFromString(t *testing.T) { + var id platform.ID + err := id.DecodeFromString("020f755c3c082000") + if err != nil { + t.Errorf(err.Error()) + } + want := []byte{48, 50, 48, 102, 55, 53, 53, 99, 51, 99, 48, 56, 50, 48, 48, 48} + got, _ := id.Encode() + if !bytes.Equal(want, got) { + t.Errorf("got %s not equal to wanted %s", string(got), string(want)) + } + if id.String() != "020f755c3c082000" { + t.Errorf("expecting string representation to contain the right value") + } + if !id.Valid() { + t.Errorf("expecting ID to be a valid one") + } +} + +func TestEncode(t *testing.T) { + var id platform.ID + if _, err := id.Encode(); err == nil { + t.Errorf("encoding an invalid ID should not be possible") + } + + id.DecodeFromString("5ca1ab1eba5eba11") + want := []byte{53, 99, 97, 49, 97, 98, 49, 101, 98, 97, 53, 101, 98, 97, 49, 49} + got, _ := id.Encode() + if !bytes.Equal(want, got) { + t.Errorf("encoding error") + } + if id.String() != "5ca1ab1eba5eba11" { + t.Errorf("expecting string representation to contain the right value") + } + if !id.Valid() { + t.Errorf("expecting ID to be a valid one") + } +} + +func TestDecodeFromAllZeros(t *testing.T) { + var id platform.ID + err := id.Decode(make([]byte, platform.IDLength)) + if err == nil { + t.Errorf("expecting all zeros ID to not be a valid ID") + } +} + +func TestDecodeFromShorterString(t *testing.T) { + var id platform.ID + err := id.DecodeFromString("020f75") + if err == nil { + t.Errorf("expecting shorter inputs to error") + } + if id.String() != "" { + t.Errorf("expecting invalid ID to be serialized into empty string") + } +} + +func TestDecodeFromLongerString(t *testing.T) { + var id platform.ID + err := id.DecodeFromString("020f755c3c082000aaa") + if err == nil { + t.Errorf("expecting shorter inputs to error") + } + if id.String() != "" { + t.Errorf("expecting invalid ID to be serialized into empty string") + } +} + +func TestDecodeFromEmptyString(t *testing.T) { + var id platform.ID + err := id.DecodeFromString("") + if err == nil { + t.Errorf("expecting empty inputs to error") + } + if id.String() != "" { + t.Errorf("expecting invalid ID to be serialized into empty string") + } +} + +func TestMarshalling(t *testing.T) { + var id0 platform.ID + _, err := json.Marshal(id0) + if err == nil { + t.Errorf("expecting empty ID to not be a valid one") + } + + init := "ca55e77eca55e77e" + id1, err := platform.IDFromString(init) + if err != nil { + t.Errorf(err.Error()) + } + + serialized, err := json.Marshal(id1) + if err != nil { + t.Errorf(err.Error()) + } + + var id2 platform.ID + json.Unmarshal(serialized, &id2) + + bytes1, _ := id1.Encode() + bytes2, _ := id2.Encode() + + if !bytes.Equal(bytes1, bytes2) { + t.Errorf("error marshalling/unmarshalling ID") + } + + // When used as a map key, IDs must use their string encoding. + // If you only implement json.Marshaller, they will be encoded with Go's default integer encoding. + b, err := json.Marshal(map[platform.ID]int{0x1234: 5678}) + if err != nil { + t.Error(err) + } + const exp = `{"0000000000001234":5678}` + if string(b) != exp { + t.Errorf("expected map to json.Marshal as %s; got %s", exp, string(b)) + } + + var idMap map[platform.ID]int + if err := json.Unmarshal(b, &idMap); err != nil { + t.Error(err) + } + if len(idMap) != 1 { + t.Errorf("expected length 1, got %d", len(idMap)) + } + if idMap[0x1234] != 5678 { + t.Errorf("unmarshalled incorrectly; exp 0x1234:5678, got %v", idMap) + } +} + +func TestValid(t *testing.T) { + var id platform.ID + if id.Valid() { + t.Errorf("expecting initial ID to be invalid") + } + + if platform.InvalidID() != 0 { + t.Errorf("expecting invalid ID to return a zero ID, thus invalid") + } +} + +func TestID_GoString(t *testing.T) { + type idGoStringTester struct { + ID platform.ID + } + var x idGoStringTester + + const idString = "02def021097c6000" + if err := x.ID.DecodeFromString(idString); err != nil { + t.Fatal(err) + } + + sharpV := fmt.Sprintf("%#v", x) + want := `platform_test.idGoStringTester{ID:"` + idString + `"}` + if sharpV != want { + t.Fatalf("bad GoString: got %q, want %q", sharpV, want) + } +} + +func BenchmarkIDEncode(b *testing.B) { + var id platform.ID + id.DecodeFromString("5ca1ab1eba5eba11") + b.ResetTimer() + for i := 0; i < b.N; i++ { + b, _ := id.Encode() + _ = b + } +} + +func BenchmarkIDDecode(b *testing.B) { + for i := 0; i < b.N; i++ { + var id platform.ID + id.DecodeFromString("5ca1ab1eba5eba11") + } +} diff --git a/kit/prom/example_test.go b/kit/prom/example_test.go new file mode 100644 index 0000000000..51a6d77a13 --- /dev/null +++ b/kit/prom/example_test.go @@ -0,0 +1,87 @@ +package prom_test + +import ( + "fmt" + "io" + "math/rand" + "net/http" + "time" + + "github.com/influxdata/influxdb/kit/prom" + "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap" +) + +// RandomHandler implements an HTTP endpoint that prints a random float, +// and it tracks prometheus metrics about the numbers it returns. +type RandomHandler struct { + // Cumulative sum of values served. + valueCounter prometheus.Counter + + // Total times page served. + serveCounter prometheus.Counter +} + +var ( + _ http.Handler = (*RandomHandler)(nil) + _ prom.PrometheusCollector = (*RandomHandler)(nil) +) + +func NewRandomHandler() *RandomHandler { + return &RandomHandler{ + valueCounter: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "value_counter", + Help: "Cumulative sum of values served.", + }), + serveCounter: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "serve_counter", + Help: "Counter of times page has been served.", + }), + } +} + +// ServeHTTP serves a random float value and updates rh's internal metrics. +func (rh *RandomHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Increment serveCounter every time we serve a page. + rh.serveCounter.Inc() + + n := rand.Float64() + // Track the cumulative values served. + rh.valueCounter.Add(n) + + fmt.Fprintf(w, "%v", n) +} + +// PrometheusCollectors implements prom.PrometheusCollector. +func (rh *RandomHandler) PrometheusCollectors() []prometheus.Collector { + return []prometheus.Collector{rh.valueCounter, rh.serveCounter} +} + +func Example() { + // A collection of endpoints and http.Handlers. + handlers := map[string]http.Handler{ + "/random": NewRandomHandler(), + "/time": http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, time.Now().String()) + }), + } + + // Use a local registry, not the global registry in the prometheus package. + reg := prom.NewRegistry(zap.NewNop()) + + // Build the mux out of handlers from above. + mux := http.NewServeMux() + for path, h := range handlers { + mux.Handle(path, h) + + // Only register those handlers which implement prom.PrometheusCollector. + if pc, ok := h.(prom.PrometheusCollector); ok { + reg.MustRegister(pc.PrometheusCollectors()...) + } + } + + // Add metrics to registry. + mux.Handle("/metrics", reg.HTTPHandler()) + + http.ListenAndServe("localhost:8080", mux) +} diff --git a/kit/prom/promtest/promtest.go b/kit/prom/promtest/promtest.go new file mode 100644 index 0000000000..7ed467a15e --- /dev/null +++ b/kit/prom/promtest/promtest.go @@ -0,0 +1,146 @@ +// Package promtest provides helpers for parsing and extracting prometheus metrics. +// These functions are only intended to be called from test files, +// as there is a dependency on the standard library testing package. +package promtest + +import ( + "fmt" + "io" + "net/http" + "strings" + "testing" + + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" + "github.com/prometheus/common/expfmt" +) + +// FromHTTPResponse parses the prometheus metrics from the given *http.Response. +// It relies on properly set response headers to correctly parse. +// It will unconditionally close the response body. +// +// This is particularly helpful when testing the output of the /metrics endpoint of a service. +// However, for comprehensive testing of metrics, it usually makes more sense to +// add collectors to a registry and call Registry.Gather to get the metrics without involving HTTP. +func FromHTTPResponse(r *http.Response) ([]*dto.MetricFamily, error) { + defer r.Body.Close() + + dec := expfmt.NewDecoder(r.Body, expfmt.ResponseFormat(r.Header)) + var mfs []*dto.MetricFamily + for { + mf := new(dto.MetricFamily) + if err := dec.Decode(mf); err != nil { + if err == io.EOF { + break + } else { + return nil, err + } + } + + mfs = append(mfs, mf) + } + + return mfs, nil +} + +// FindMetric iterates through mfs to find the first metric family matching name. +// If a metric family matches, then the metrics inside the family are searched, +// and the first metric whose labels match the given labels are returned. +// If no matches are found, FindMetric returns nil. +// +// FindMetric assumes that the labels on the metric family are well formed, +// i.e. there are no duplicate label names, and the label values are not empty strings. +func FindMetric(mfs []*dto.MetricFamily, name string, labels map[string]string) *dto.Metric { + _, m := findMetric(mfs, name, labels) + return m +} + +// MustFindMetric returns the matching metric, or if no matching metric could be found, +// it calls tb.Log with helpful output of what was actually available, before calling tb.FailNow. +func MustFindMetric(tb testing.TB, mfs []*dto.MetricFamily, name string, labels map[string]string) *dto.Metric { + tb.Helper() + + fam, m := findMetric(mfs, name, labels) + if fam == nil { + tb.Logf("metric family with name %q not found", name) + tb.Log("available names:") + for _, mf := range mfs { + tb.Logf("\t%s", mf.GetName()) + } + tb.FailNow() + return nil // Need an explicit return here for test. + } + + if m == nil { + tb.Logf("found metric family with name %q, but metric with labels %v not found", name, labels) + tb.Logf("available labels on metric family %q:", name) + + for _, m := range fam.Metric { + pairs := make([]string, len(m.Label)) + for i, l := range m.Label { + pairs[i] = fmt.Sprintf("%q: %q", l.GetName(), l.GetValue()) + } + tb.Logf("\t%s", strings.Join(pairs, ", ")) + } + tb.FailNow() + return nil // Need an explicit return here for test. + } + + return m +} + +// findMetric is a helper that returns the matching family and the matching metric. +// The exported FindMetric function specifically only finds the metric, not the family, +// but for test it is more helpful to identify whether the family was matched. +func findMetric(mfs []*dto.MetricFamily, name string, labels map[string]string) (*dto.MetricFamily, *dto.Metric) { + var fam *dto.MetricFamily + + for _, mf := range mfs { + if mf.GetName() == name { + fam = mf + break + } + } + + if fam == nil { + // No family matching the name. + return nil, nil + } + + for _, m := range fam.Metric { + if len(m.Label) != len(labels) { + continue + } + + match := true + for _, l := range m.Label { + if labels[l.GetName()] != l.GetValue() { + match = false + break + } + } + + if !match { + continue + } + + // All labels matched. + return fam, m + } + + // Didn't find a metric whose labels all matched. + return fam, nil +} + +// MustGather calls g.Gather and calls tb.Fatal if there was an error. +func MustGather(tb testing.TB, g prometheus.Gatherer) []*dto.MetricFamily { + tb.Helper() + + mfs, err := g.Gather() + if err != nil { + tb.Fatalf("error while gathering metrics: %v", err) + return nil // Need an explicit return here for test. + } + + return mfs +} diff --git a/kit/prom/promtest/promtest_test.go b/kit/prom/promtest/promtest_test.go new file mode 100644 index 0000000000..7954fa410a --- /dev/null +++ b/kit/prom/promtest/promtest_test.go @@ -0,0 +1,210 @@ +package promtest_test + +import ( + "bytes" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/influxdata/influxdb/kit/prom" + "github.com/influxdata/influxdb/kit/prom/promtest" + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" + "go.uber.org/zap/zaptest" +) + +func helperCollectors() []prometheus.Collector { + myCounter := prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "my", + Subsystem: "random", + Name: "counter", + Help: "Just a random counter.", + }) + myGaugeVec := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "my", + Subsystem: "random", + Name: "gaugevec", + Help: "Just a random gauge vector.", + }, []string{"label1", "label2"}) + + myCounter.Inc() + myGaugeVec.WithLabelValues("one", "two").Set(3) + + return []prometheus.Collector{myCounter, myGaugeVec} +} + +func TestFindMetric(t *testing.T) { + reg := prom.NewRegistry(zaptest.NewLogger(t)) + reg.MustRegister(helperCollectors()...) + + mfs, err := reg.Gather() + if err != nil { + t.Fatal(err) + } + + c := promtest.MustFindMetric(t, mfs, "my_random_counter", nil) + if got := c.GetCounter().GetValue(); got != 1 { + t.Fatalf("expected counter to be 1, got %v", got) + } + + g := promtest.MustFindMetric(t, mfs, "my_random_gaugevec", map[string]string{"label1": "one", "label2": "two"}) + if got := g.GetGauge().GetValue(); got != 3 { + t.Fatalf("expected gauge to be 3, got %v", got) + } +} + +// fakeT helps us to assert that MustFindMetric calls FailNow when the metric isn't found. +type fakeT struct { + // Embed a T so we don't have to reimplement everything. + // It's fine to leave T nil - fakeT will panic if calling a method we haven't implemented. + *testing.T + + logBuf bytes.Buffer + + failed bool +} + +func (t *fakeT) Helper() {} + +func (t *fakeT) Log(args ...interface{}) { + fmt.Fprint(&t.logBuf, args...) + t.logBuf.WriteString("\n") +} + +func (t *fakeT) Logf(format string, args ...interface{}) { + fmt.Fprintf(&t.logBuf, format, args...) + t.logBuf.WriteString("\n") +} + +func (t *fakeT) FailNow() { + t.failed = true +} + +func (t *fakeT) Fatalf(format string, args ...interface{}) { + t.Logf(format, args...) + t.FailNow() +} + +func TestMustFindMetric(t *testing.T) { + reg := prom.NewRegistry(zaptest.NewLogger(t)) + reg.MustRegister(helperCollectors()...) + + mfs, err := reg.Gather() + if err != nil { + t.Fatal(err) + } + + ft := new(fakeT) + + // Doesn't log when metric is found. + _ = promtest.MustFindMetric(ft, mfs, "my_random_counter", nil) + if ft.failed { + t.Fatalf("MustFindMetric failed when it should not have. message: %s", ft.logBuf.String()) + } + + // Logs and fails when family name not found. + ft = new(fakeT) + _ = promtest.MustFindMetric(ft, mfs, "missing_name", nil) + if !ft.failed { + t.Fatal("MustFindMetric should have failed but didn't") + } + logged := ft.logBuf.String() + if !strings.Contains(logged, `name "missing_name" not found`) { + t.Fatalf("did not log the looked up name which was not found. message: %s", logged) + } + if !strings.Contains(logged, "my_random_counter") || !strings.Contains(logged, "my_random_gaugevec") { + t.Fatalf("did not log the available metric names. message: %s", logged) + } + + // Logs and fails when family name found but metric labels mismatch. + ft = new(fakeT) + _ = promtest.MustFindMetric(ft, mfs, "my_random_counter", map[string]string{"unknown": "label"}) + if !ft.failed { + t.Fatal("MustFindMetric should have failed but didn't") + } + + ft = new(fakeT) + _ = promtest.MustFindMetric(ft, mfs, "my_random_gaugevec", map[string]string{"unknown": "label"}) + if !ft.failed { + t.Fatal("MustFindMetric should have failed but didn't") + } + logged = ft.logBuf.String() + if !strings.Contains(logged, `"label1": "one"`) || !strings.Contains(logged, `"label2": "two"`) { + t.Fatalf("did not log the available label names. message: %s", logged) + } + + ft = new(fakeT) + _ = promtest.MustFindMetric(ft, mfs, "my_random_gaugevec", map[string]string{"label1": "one", "label2": "two", "label3": "imaginary"}) + if !ft.failed { + t.Fatal("MustFindMetric should have failed but didn't") + } + logged = ft.logBuf.String() + if !strings.Contains(logged, `"label1": "one"`) || !strings.Contains(logged, `"label2": "two"`) { + t.Fatalf("did not log the available label names. message: %s", logged) + } +} + +func TestMustGather(t *testing.T) { + expErr := errors.New("failed to gather") + g := prometheus.GathererFunc(func() ([]*dto.MetricFamily, error) { + return nil, expErr + }) + + ft := new(fakeT) + _ = promtest.MustGather(ft, g) + if !ft.failed { + t.Fatal("MustGather should have failed but didn't") + } + logged := ft.logBuf.String() + if !strings.HasPrefix(logged, "error while gathering metrics:") || !strings.Contains(logged, expErr.Error()) { + t.Fatalf("did not log the expected error message: %s", logged) + } + + expMF := []*dto.MetricFamily{} // Use a non-nil, zero-length slice for a simple-ish check. + g = prometheus.GathererFunc(func() ([]*dto.MetricFamily, error) { + return expMF, nil + }) + ft = new(fakeT) + gotMF := promtest.MustGather(ft, g) + if ft.failed { + t.Fatalf("MustGather should not have failed") + } + if gotMF == nil || len(gotMF) != 0 { + t.Fatalf("exp: %v, got: %v", expMF, gotMF) + } +} + +func TestFromHTTPResponse(t *testing.T) { + reg := prom.NewRegistry(zaptest.NewLogger(t)) + reg.MustRegister(helperCollectors()...) + + s := httptest.NewServer(reg.HTTPHandler()) + defer s.Close() + + resp, err := http.Get(s.URL) // Didn't specify a path for the handler, so any path should be fine. + if err != nil { + t.Fatal(err) + } + + mfs, err := promtest.FromHTTPResponse(resp) + if err != nil { + t.Fatal(err) + } + + if len(mfs) != 2 { + t.Fatalf("expected 2 metrics but got %d", len(mfs)) + } + + c := promtest.MustFindMetric(t, mfs, "my_random_counter", nil) + if got := c.GetCounter().GetValue(); got != 1 { + t.Fatalf("expected counter to be 1, got %v", got) + } + + g := promtest.MustFindMetric(t, mfs, "my_random_gaugevec", map[string]string{"label1": "one", "label2": "two"}) + if got := g.GetGauge().GetValue(); got != 3 { + t.Fatalf("expected gauge to be 3, got %v", got) + } +} diff --git a/kit/prom/registry.go b/kit/prom/registry.go new file mode 100644 index 0000000000..f734ce4c5d --- /dev/null +++ b/kit/prom/registry.go @@ -0,0 +1,59 @@ +// Package prom provides a wrapper around a prometheus metrics registry +// so that all services are unified in how they expose prometheus metrics. +package prom + +import ( + "net/http" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.uber.org/zap" +) + +// PrometheusCollector is the interface for a type to expose prometheus metrics. +// This interface is provided as a convention, so that you can optionally check +// if a type implements it and then pass its collectors to (*Registry).MustRegister. +type PrometheusCollector interface { + // PrometheusCollectors returns a slice of prometheus collectors + // containing metrics for the underlying instance. + PrometheusCollectors() []prometheus.Collector +} + +// Registry embeds a prometheus registry and adds a couple convenience methods. +type Registry struct { + *prometheus.Registry + + log *zap.Logger +} + +// NewRegistry returns a new registry. +func NewRegistry(log *zap.Logger) *Registry { + return &Registry{ + Registry: prometheus.NewRegistry(), + log: log, + } +} + +// HTTPHandler returns an http.Handler for the registry, +// so that the /metrics HTTP handler is uniformly configured across all apps in the platform. +func (r *Registry) HTTPHandler() http.Handler { + opts := promhttp.HandlerOpts{ + ErrorLog: promLogger{r: r}, + // TODO(mr): decide if we want to set MaxRequestsInFlight or Timeout. + } + return promhttp.HandlerFor(r.Registry, opts) +} + +// promLogger satisfies the promhttp.logger interface with the registry. +// Because normal usage is that WithLogger is called after HTTPHandler, +// we refer to the Registry rather than its logger. +type promLogger struct { + r *Registry +} + +var _ promhttp.Logger = (*promLogger)(nil) + +// Println implements promhttp.logger. +func (pl promLogger) Println(v ...interface{}) { + pl.r.log.Sugar().Info(v...) +} diff --git a/kit/prom/registry_test.go b/kit/prom/registry_test.go new file mode 100644 index 0000000000..b53d810f96 --- /dev/null +++ b/kit/prom/registry_test.go @@ -0,0 +1,60 @@ +package prom_test + +import ( + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/influxdata/influxdb/kit/prom" + "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" +) + +func TestRegistry_Logger(t *testing.T) { + core, logs := observer.New(zap.DebugLevel) + reg := prom.NewRegistry(zap.New(core)) + + // Normal use: HTTP handler is created immediately... + s := httptest.NewServer(reg.HTTPHandler()) + defer s.Close() + + // Force an error with a fake collector. + reg.MustRegister(errorCollector{}) + resp, err := http.Get(s.URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + foundLog := false + for _, le := range logs.All() { + if strings.Contains(le.Message, "invalid metric from errorCollector") { + foundLog = true + break + } + } + + if !foundLog { + t.Fatalf("registry logger did not log error from metric collection") + } +} + +type errorCollector struct{} + +var _ prometheus.Collector = errorCollector{} + +var ecDesc = prometheus.NewDesc("error_collector_desc", "A required description for the error collector", nil, nil) + +func (errorCollector) Describe(ch chan<- *prometheus.Desc) { + ch <- ecDesc +} + +func (errorCollector) Collect(ch chan<- prometheus.Metric) { + ch <- prometheus.NewInvalidMetric( + ecDesc, + errors.New("invalid metric from errorCollector"), + ) +} diff --git a/kit/signals/context.go b/kit/signals/context.go new file mode 100644 index 0000000000..7541ac0547 --- /dev/null +++ b/kit/signals/context.go @@ -0,0 +1,31 @@ +package signals + +import ( + "context" + "os" + "os/signal" + "syscall" +) + +// WithSignals returns a context that is canceled with any signal in sigs. +func WithSignals(ctx context.Context, sigs ...os.Signal) context.Context { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, sigs...) + + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + select { + case <-ctx.Done(): + return + case <-sigCh: + return + } + }() + return ctx +} + +// WithStandardSignals cancels the context on os.Interrupt, syscall.SIGTERM. +func WithStandardSignals(ctx context.Context) context.Context { + return WithSignals(ctx, os.Interrupt, syscall.SIGTERM) +} diff --git a/kit/signals/context_test.go b/kit/signals/context_test.go new file mode 100644 index 0000000000..222e3989f8 --- /dev/null +++ b/kit/signals/context_test.go @@ -0,0 +1,79 @@ +package signals + +import ( + "context" + "fmt" + "os" + "syscall" + "testing" + "time" +) + +func ExampleWithSignals() { + ctx := WithSignals(context.Background(), syscall.SIGUSR1) + go func() { + time.Sleep(500 * time.Millisecond) // after some time SIGUSR1 is sent + // mimicking a signal from the outside + syscall.Kill(syscall.Getpid(), syscall.SIGUSR1) + }() + + <-ctx.Done() + fmt.Println("finished") + // Output: + // finished +} + +func Example_withUnregisteredSignals() { + dctx, cancel := context.WithTimeout(context.TODO(), time.Millisecond*100) + defer cancel() + + ctx := WithSignals(dctx, syscall.SIGUSR1) + go func() { + time.Sleep(10 * time.Millisecond) // after some time SIGUSR2 is sent + // mimicking a signal from the outside, WithSignals will not handle it + syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) + }() + + <-ctx.Done() + fmt.Println("finished") + // Output: + // finished +} + +func TestWithSignals(t *testing.T) { + tests := []struct { + name string + ctx context.Context + sigs []os.Signal + wantSignal bool + }{ + { + name: "sending signal SIGUSR2 should exit context.", + ctx: context.Background(), + sigs: []os.Signal{syscall.SIGUSR2}, + wantSignal: true, + }, + { + name: "sending signal SIGUSR2 should NOT exit context.", + ctx: context.Background(), + sigs: []os.Signal{syscall.SIGUSR1}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := WithSignals(tt.ctx, tt.sigs...) + syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) + timer := time.NewTimer(500 * time.Millisecond) + select { + case <-ctx.Done(): + if !tt.wantSignal { + t.Errorf("unexpected exit with signal") + } + case <-timer.C: + if tt.wantSignal { + t.Errorf("expected to exit with signal but did not") + } + } + }) + } +} diff --git a/kit/tracing/testing/testing.go b/kit/tracing/testing/testing.go new file mode 100644 index 0000000000..af59cf0639 --- /dev/null +++ b/kit/tracing/testing/testing.go @@ -0,0 +1,24 @@ +package testing + +import ( + "github.com/opentracing/opentracing-go" + "github.com/uber/jaeger-client-go" +) + +// SetupInMemoryTracing sets the global tracer to an in memory Jaeger instance for testing. +// The returned function should be deferred by the caller to tear down this setup after testing is complete. +func SetupInMemoryTracing(name string) func() { + var ( + old = opentracing.GlobalTracer() + tracer, closer = jaeger.NewTracer(name, + jaeger.NewConstSampler(true), + jaeger.NewInMemoryReporter(), + ) + ) + + opentracing.SetGlobalTracer(tracer) + return func() { + _ = closer.Close() + opentracing.SetGlobalTracer(old) + } +} diff --git a/kit/tracing/tracing.go b/kit/tracing/tracing.go new file mode 100644 index 0000000000..45d99b8c50 --- /dev/null +++ b/kit/tracing/tracing.go @@ -0,0 +1,225 @@ +package tracing + +import ( + "context" + "errors" + "net/http" + "runtime" + "strings" + "time" + + "github.com/go-chi/chi" + "github.com/prometheus/client_golang/prometheus" + "github.com/uber/jaeger-client-go" + + "github.com/influxdata/httprouter" + "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/ext" + "github.com/opentracing/opentracing-go/log" +) + +// LogError adds a span log for an error. +// Returns unchanged error, so useful to wrap as in: +// return 0, tracing.LogError(err) +func LogError(span opentracing.Span, err error) error { + if err == nil { + return nil + } + + // Get caller frame. + var pcs [1]uintptr + n := runtime.Callers(2, pcs[:]) + if n < 1 { + span.LogFields(log.Error(err)) + span.LogFields(log.Error(errors.New("runtime.Callers failed"))) + return err + } + + file, line := runtime.FuncForPC(pcs[0]).FileLine(pcs[0]) + span.LogFields(log.String("filename", file), log.Int("line", line), log.Error(err)) + + return err +} + +// InjectToHTTPRequest adds tracing headers to an HTTP request. +// Easier than adding this boilerplate everywhere. +func InjectToHTTPRequest(span opentracing.Span, req *http.Request) { + err := opentracing.GlobalTracer().Inject(span.Context(), opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header)) + if err != nil { + LogError(span, err) + } +} + +// ExtractFromHTTPRequest gets a child span of the parent referenced in HTTP request headers. +// Returns the request with updated tracing context. +// Easier than adding this boilerplate everywhere. +func ExtractFromHTTPRequest(req *http.Request, handlerName string) (opentracing.Span, *http.Request) { + spanContext, err := opentracing.GlobalTracer().Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header)) + if err != nil { + span, ctx := opentracing.StartSpanFromContext(req.Context(), "request") + annotateSpan(span, handlerName, req) + + _ = LogError(span, err) + + return span, req.WithContext(ctx) + } + + span := opentracing.StartSpan("request", opentracing.ChildOf(spanContext), ext.RPCServerOption(spanContext)) + annotateSpan(span, handlerName, req) + + return span, req.WithContext(opentracing.ContextWithSpan(req.Context(), span)) +} + +func annotateSpan(span opentracing.Span, handlerName string, req *http.Request) { + if route := httprouter.MatchedRouteFromContext(req.Context()); route != "" { + span.SetTag("route", route) + } + span.SetTag("method", req.Method) + if ctx := chi.RouteContext(req.Context()); ctx != nil { + span.SetTag("route", ctx.RoutePath) + } + span.SetTag("handler", handlerName) + span.LogKV("path", req.URL.Path) +} + +// span is a simple wrapper around opentracing.Span in order to +// get access to the duration of the span for metrics reporting. +type Span struct { + opentracing.Span + start time.Time + Duration time.Duration + hist prometheus.Observer + gauge prometheus.Gauge +} + +func StartSpanFromContextWithPromMetrics(ctx context.Context, operationName string, hist prometheus.Observer, gauge prometheus.Gauge, opts ...opentracing.StartSpanOption) (*Span, context.Context) { + start := time.Now() + s, sctx := StartSpanFromContextWithOperationName(ctx, operationName, opentracing.StartTime(start)) + gauge.Inc() + return &Span{s, start, 0, hist, gauge}, sctx +} + +func (s *Span) Finish() { + finish := time.Now() + s.Duration = finish.Sub(s.start) + s.Span.FinishWithOptions(opentracing.FinishOptions{ + FinishTime: finish, + }) + s.hist.Observe(s.Duration.Seconds()) + s.gauge.Dec() +} + +// StartSpanFromContext is an improved opentracing.StartSpanFromContext. +// Uses the calling function as the operation name, and logs the filename and line number. +// +// Passing nil context induces panic. +// Context without parent span reference triggers root span construction. +// This function never returns nil values. +// +// Performance +// +// This function incurs a small performance penalty, roughly 1000 ns/op, 376 B/op, 6 allocs/op. +// Jaeger timestamp and duration precision is only µs, so this is pretty negligible. +// +// Alternatives +// +// If this performance penalty is too much, try these, which are also demonstrated in benchmark tests: +// // Create a root span +// span := opentracing.StartSpan("operation name") +// ctx := opentracing.ContextWithSpan(context.Background(), span) +// +// // Create a child span +// span := opentracing.StartSpan("operation name", opentracing.ChildOf(sc)) +// ctx := opentracing.ContextWithSpan(context.Background(), span) +// +// // Sugar to create a child span +// span, ctx := opentracing.StartSpanFromContext(ctx, "operation name") +func StartSpanFromContext(ctx context.Context, opts ...opentracing.StartSpanOption) (opentracing.Span, context.Context) { + if ctx == nil { + panic("StartSpanFromContext called with nil context") + } + + // Get caller frame. + var pcs [1]uintptr + n := runtime.Callers(2, pcs[:]) + if n < 1 { + span, ctx := opentracing.StartSpanFromContext(ctx, "unknown", opts...) + span.LogFields(log.Error(errors.New("runtime.Callers failed"))) + return span, ctx + } + fn := runtime.FuncForPC(pcs[0]) + name := fn.Name() + if lastSlash := strings.LastIndexByte(name, '/'); lastSlash > 0 { + name = name[lastSlash+1:] + } + + var span opentracing.Span + if parentSpan := opentracing.SpanFromContext(ctx); parentSpan != nil { + // Create a child span. + opts = append(opts, opentracing.ChildOf(parentSpan.Context())) + span = opentracing.StartSpan(name, opts...) + } else { + // Create a root span. + span = opentracing.StartSpan(name) + } + // New context references this span, not the parent (if there was one). + ctx = opentracing.ContextWithSpan(ctx, span) + + file, line := fn.FileLine(pcs[0]) + span.LogFields(log.String("filename", file), log.Int("line", line)) + + return span, ctx +} + +// StartSpanFromContextWithOperationName is like StartSpanFromContext, but the caller determines the operation name. +func StartSpanFromContextWithOperationName(ctx context.Context, operationName string, opts ...opentracing.StartSpanOption) (opentracing.Span, context.Context) { + if ctx == nil { + panic("StartSpanFromContextWithOperationName called with nil context") + } + + // Get caller frame. + var pcs [1]uintptr + n := runtime.Callers(2, pcs[:]) + if n < 1 { + span, ctx := opentracing.StartSpanFromContext(ctx, operationName, opts...) + span.LogFields(log.Error(errors.New("runtime.Callers failed"))) + return span, ctx + } + file, line := runtime.FuncForPC(pcs[0]).FileLine(pcs[0]) + + var span opentracing.Span + if parentSpan := opentracing.SpanFromContext(ctx); parentSpan != nil { + opts = append(opts, opentracing.ChildOf(parentSpan.Context())) + // Create a child span. + span = opentracing.StartSpan(operationName, opts...) + } else { + // Create a root span. + span = opentracing.StartSpan(operationName, opts...) + } + // New context references this span, not the parent (if there was one). + ctx = opentracing.ContextWithSpan(ctx, span) + + span.LogFields(log.String("filename", file), log.Int("line", line)) + + return span, ctx +} + +// InfoFromSpan returns the traceID and if it was sampled from the span, given it is a jaeger span. +// It returns whether a span associated to the context has been found. +func InfoFromSpan(span opentracing.Span) (traceID string, sampled bool, found bool) { + if spanContext, ok := span.Context().(jaeger.SpanContext); ok { + traceID = spanContext.TraceID().String() + sampled = spanContext.IsSampled() + return traceID, sampled, true + } + return "", false, false +} + +// InfoFromContext returns the traceID and if it was sampled from the Jaeger span +// found in the given context. It returns whether a span associated to the context has been found. +func InfoFromContext(ctx context.Context) (traceID string, sampled bool, found bool) { + if span := opentracing.SpanFromContext(ctx); span != nil { + return InfoFromSpan(span) + } + return "", false, false +} diff --git a/kit/tracing/tracing_test.go b/kit/tracing/tracing_test.go new file mode 100644 index 0000000000..50ea7f3d78 --- /dev/null +++ b/kit/tracing/tracing_test.go @@ -0,0 +1,333 @@ +package tracing + +import ( + "context" + "fmt" + "net/http" + "net/url" + "runtime" + "testing" + + "github.com/go-chi/chi" + "github.com/influxdata/httprouter" + "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/mocktracer" +) + +func TestInjectAndExtractHTTPRequest(t *testing.T) { + tracer := mocktracer.New() + + oldTracer := opentracing.GlobalTracer() + opentracing.SetGlobalTracer(tracer) + defer opentracing.SetGlobalTracer(oldTracer) + + request, err := http.NewRequest(http.MethodPost, "http://localhost/", nil) + if err != nil { + t.Fatal(err) + } + + span := tracer.StartSpan("operation name") + + InjectToHTTPRequest(span, request) + gotSpan, _ := ExtractFromHTTPRequest(request, "MyStruct") + + if span.(*mocktracer.MockSpan).SpanContext.TraceID != gotSpan.(*mocktracer.MockSpan).SpanContext.TraceID { + t.Error("injected and extracted traceIDs not equal") + } + + if span.(*mocktracer.MockSpan).SpanContext.SpanID != gotSpan.(*mocktracer.MockSpan).ParentID { + t.Error("injected span ID does not match extracted span parent ID") + } +} + +func TestExtractHTTPRequest(t *testing.T) { + var ( + tracer = mocktracer.New() + oldTracer = opentracing.GlobalTracer() + ctx = context.Background() + ) + + opentracing.SetGlobalTracer(tracer) + defer opentracing.SetGlobalTracer(oldTracer) + + for _, test := range []struct { + name string + handlerName string + path string + ctx context.Context + tags map[string]interface{} + method string + }{ + { + name: "happy path", + handlerName: "WriteHandler", + ctx: context.WithValue(ctx, httprouter.MatchedRouteKey, "/api/v2/write"), + method: http.MethodGet, + path: "/api/v2/write", + tags: map[string]interface{}{ + "route": "/api/v2/write", + "handler": "WriteHandler", + }, + }, + { + name: "happy path bucket handler", + handlerName: "BucketHandler", + ctx: context.WithValue(ctx, httprouter.MatchedRouteKey, "/api/v2/buckets/:bucket_id"), + path: "/api/v2/buckets/12345", + method: http.MethodGet, + tags: map[string]interface{}{ + "route": "/api/v2/buckets/:bucket_id", + "handler": "BucketHandler", + }, + }, + { + name: "happy path bucket handler (chi)", + handlerName: "BucketHandler", + ctx: context.WithValue( + ctx, + chi.RouteCtxKey, + &chi.Context{RoutePath: "/api/v2/buckets/:bucket_id", RouteMethod: "GET"}, + ), + path: "/api/v2/buckets/12345", + method: http.MethodGet, + tags: map[string]interface{}{ + "route": "/api/v2/buckets/:bucket_id", + "method": "GET", + "handler": "BucketHandler", + }, + }, + { + name: "empty path", + handlerName: "Home", + ctx: ctx, + method: http.MethodGet, + tags: map[string]interface{}{ + "handler": "Home", + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + request, err := http.NewRequest(test.method, "http://localhost"+test.path, nil) + if err != nil { + t.Fatal(err) + } + + span := tracer.StartSpan("operation name") + + InjectToHTTPRequest(span, request) + gotSpan, _ := ExtractFromHTTPRequest(request.WithContext(test.ctx), test.handlerName) + + if op := gotSpan.(*mocktracer.MockSpan).OperationName; op != "request" { + t.Fatalf("operation name %q != request", op) + } + + tags := gotSpan.(*mocktracer.MockSpan).Tags() + for k, v := range test.tags { + found, ok := tags[k] + if !ok { + t.Errorf("tag not found in span %q", k) + continue + } + + if found != v { + t.Errorf("expected %v, found %v for tag %q", v, found, k) + } + } + }) + } +} + +func TestStartSpanFromContext(t *testing.T) { + tracer := mocktracer.New() + + oldTracer := opentracing.GlobalTracer() + opentracing.SetGlobalTracer(tracer) + defer opentracing.SetGlobalTracer(oldTracer) + + type testCase struct { + ctx context.Context + expectPanic bool + expectParent bool + } + var testCases []testCase + + testCases = append(testCases, + testCase{ + ctx: nil, + expectPanic: true, + expectParent: false, + }, + testCase{ + ctx: context.Background(), + expectPanic: false, + expectParent: false, + }) + + parentSpan := opentracing.StartSpan("parent operation name") + testCases = append(testCases, testCase{ + ctx: opentracing.ContextWithSpan(context.Background(), parentSpan), + expectPanic: false, + expectParent: true, + }) + + for i, tc := range testCases { + t.Run(fmt.Sprint(i), func(t *testing.T) { + var span opentracing.Span + var ctx context.Context + var gotPanic bool + + func(inputCtx context.Context) { + defer func() { + if recover() != nil { + gotPanic = true + } + }() + span, ctx = StartSpanFromContext(inputCtx) + }(tc.ctx) + + if tc.expectPanic != gotPanic { + t.Errorf("panic: expect %v got %v", tc.expectPanic, gotPanic) + } + if tc.expectPanic { + // No other valid checks if panic. + return + } + if ctx == nil { + t.Error("never expect non-nil ctx") + } + if span == nil { + t.Error("never expect non-nil Span") + } + foundParent := span.(*mocktracer.MockSpan).ParentID != 0 + if tc.expectParent != foundParent { + t.Errorf("parent: expect %v got %v", tc.expectParent, foundParent) + } + if ctx == tc.ctx { + t.Errorf("always expect fresh context") + } + }) + } +} + +func TestLogErrorNil(t *testing.T) { + tracer := mocktracer.New() + span := tracer.StartSpan("test").(*mocktracer.MockSpan) + + var err error + if err2 := LogError(span, err); err2 != nil { + t.Errorf("expected nil err, got '%s'", err2.Error()) + } + + if len(span.Logs()) > 0 { + t.Errorf("expected zero new span logs, got %d", len(span.Logs())) + println(span.Logs()[0].Fields[0].Key) + } +} + +/* +BenchmarkLocal_StartSpanFromContext-8 2000000 681 ns/op 224 B/op 4 allocs/op +BenchmarkLocal_StartSpanFromContext_runtimeCaller-8 3000000 534 ns/op +BenchmarkLocal_StartSpanFromContext_runtimeCallers-8 10000000 196 ns/op +BenchmarkLocal_StartSpanFromContext_runtimeFuncForPC-8 200000000 7.28 ns/op +BenchmarkLocal_StartSpanFromContext_runtimeCallersFrames-8 10000000 234 ns/op +BenchmarkLocal_StartSpanFromContext_runtimeFuncFileLine-8 20000000 103 ns/op +BenchmarkOpentracing_StartSpanFromContext-8 10000000 155 ns/op 96 B/op 3 allocs/op +BenchmarkOpentracing_StartSpan_root-8 200000000 7.68 ns/op 0 B/op 0 allocs/op +BenchmarkOpentracing_StartSpan_child-8 20000000 71.2 ns/op 48 B/op 2 allocs/op +*/ + +func BenchmarkLocal_StartSpanFromContext(b *testing.B) { + b.ReportAllocs() + + parentSpan := opentracing.StartSpan("parent operation name") + ctx := opentracing.ContextWithSpan(context.Background(), parentSpan) + + for n := 0; n < b.N; n++ { + StartSpanFromContext(ctx) + } +} + +func BenchmarkLocal_StartSpanFromContext_runtimeCaller(b *testing.B) { + for n := 0; n < b.N; n++ { + _, _, _, _ = runtime.Caller(1) + } +} + +func BenchmarkLocal_StartSpanFromContext_runtimeCallers(b *testing.B) { + var pcs [1]uintptr + + for n := 0; n < b.N; n++ { + _ = runtime.Callers(2, pcs[:]) + } +} + +func BenchmarkLocal_StartSpanFromContext_runtimeFuncForPC(b *testing.B) { + var pcs [1]uintptr + _ = runtime.Callers(2, pcs[:]) + + for n := 0; n < b.N; n++ { + _ = runtime.FuncForPC(pcs[0]) + } +} + +func BenchmarkLocal_StartSpanFromContext_runtimeCallersFrames(b *testing.B) { + pc, _, _, ok := runtime.Caller(1) + if !ok { + b.Fatal("runtime.Caller failed") + } + + for n := 0; n < b.N; n++ { + _, _ = runtime.CallersFrames([]uintptr{pc}).Next() + } +} + +func BenchmarkLocal_StartSpanFromContext_runtimeFuncFileLine(b *testing.B) { + var pcs [1]uintptr + _ = runtime.Callers(2, pcs[:]) + fn := runtime.FuncForPC(pcs[0]) + + for n := 0; n < b.N; n++ { + _, _ = fn.FileLine(pcs[0]) + } +} + +func BenchmarkOpentracing_StartSpanFromContext(b *testing.B) { + b.ReportAllocs() + + parentSpan := opentracing.StartSpan("parent operation name") + ctx := opentracing.ContextWithSpan(context.Background(), parentSpan) + + for n := 0; n < b.N; n++ { + _, _ = opentracing.StartSpanFromContext(ctx, "operation name") + } +} + +func BenchmarkOpentracing_StartSpan_root(b *testing.B) { + b.ReportAllocs() + + for n := 0; n < b.N; n++ { + _ = opentracing.StartSpan("operation name") + } +} + +func BenchmarkOpentracing_StartSpan_child(b *testing.B) { + b.ReportAllocs() + + parentSpan := opentracing.StartSpan("parent operation name") + + for n := 0; n < b.N; n++ { + _ = opentracing.StartSpan("operation name", opentracing.ChildOf(parentSpan.Context())) + } +} + +func BenchmarkOpentracing_ExtractFromHTTPRequest(b *testing.B) { + b.ReportAllocs() + + req := &http.Request{ + URL: &url.URL{Path: "/api/v2/organization/12345"}, + } + + for n := 0; n < b.N; n++ { + _, _ = ExtractFromHTTPRequest(req, "OrganizationHandler") + } +} diff --git a/kit/transport/http/api.go b/kit/transport/http/api.go new file mode 100644 index 0000000000..64adfef76f --- /dev/null +++ b/kit/transport/http/api.go @@ -0,0 +1,267 @@ +package http + +import ( + "compress/gzip" + "context" + "encoding/gob" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/influxdata/influxdb/kit/platform/errors" + + "go.uber.org/zap" +) + +// PlatformErrorCodeHeader shows the error code of platform error. +const PlatformErrorCodeHeader = "X-Platform-Error-Code" + +// API provides a consolidated means for handling API interface concerns. +// Concerns such as decoding/encoding request and response bodies as well +// as adding headers for content type and content encoding. +type API struct { + logger *zap.Logger + + prettyJSON bool + encodeGZIP bool + + unmarshalErrFn func(encoding string, err error) error + okErrFn func(err error) error + errFn func(ctx context.Context, err error) (interface{}, int, error) +} + +// APIOptFn is a functional option for setting fields on the API type. +type APIOptFn func(*API) + +// WithLog sets the logger. +func WithLog(logger *zap.Logger) APIOptFn { + return func(api *API) { + api.logger = logger + } +} + +// WithErrFn sets the err handling func for issues when writing to the response body. +func WithErrFn(fn func(ctx context.Context, err error) (interface{}, int, error)) APIOptFn { + return func(api *API) { + api.errFn = fn + } +} + +// WithOKErrFn is an error handler for failing validation for request bodies. +func WithOKErrFn(fn func(err error) error) APIOptFn { + return func(api *API) { + api.okErrFn = fn + } +} + +// WithPrettyJSON sets the json encoder to marshal indent or not. +func WithPrettyJSON(b bool) APIOptFn { + return func(api *API) { + api.prettyJSON = b + } +} + +// WithEncodeGZIP sets the encoder to gzip contents. +func WithEncodeGZIP() APIOptFn { + return func(api *API) { + api.encodeGZIP = true + } +} + +// WithUnmarshalErrFn sets the error handler for errors that occur when unmarshalling +// the request body. +func WithUnmarshalErrFn(fn func(encoding string, err error) error) APIOptFn { + return func(api *API) { + api.unmarshalErrFn = fn + } +} + +// NewAPI creates a new API type. +func NewAPI(opts ...APIOptFn) *API { + api := API{ + logger: zap.NewNop(), + prettyJSON: true, + unmarshalErrFn: func(encoding string, err error) error { + return &errors.Error{ + Code: errors.EInvalid, + Msg: fmt.Sprintf("failed to unmarshal %s: %s", encoding, err), + } + }, + errFn: func(ctx context.Context, err error) (interface{}, int, error) { + msg := err.Error() + if msg == "" { + msg = "an internal error has occurred" + } + code := errors.ErrorCode(err) + return ErrBody{ + Code: code, + Msg: msg, + }, ErrorCodeToStatusCode(ctx, code), nil + }, + } + for _, o := range opts { + o(&api) + } + return &api +} + +// DecodeJSON decodes reader with json. +func (a *API) DecodeJSON(r io.Reader, v interface{}) error { + return a.decode("json", json.NewDecoder(r), v) +} + +// DecodeGob decodes reader with gob. +func (a *API) DecodeGob(r io.Reader, v interface{}) error { + return a.decode("gob", gob.NewDecoder(r), v) +} + +type ( + decoder interface { + Decode(interface{}) error + } + + oker interface { + OK() error + } +) + +func (a *API) decode(encoding string, dec decoder, v interface{}) error { + if err := dec.Decode(v); err != nil { + if a != nil && a.unmarshalErrFn != nil { + return a.unmarshalErrFn(encoding, err) + } + return err + } + + if vv, ok := v.(oker); ok { + err := vv.OK() + if a != nil && a.okErrFn != nil { + return a.okErrFn(err) + } + return err + } + + return nil +} + +// Respond writes to the response writer, handling all errors in writing. +func (a *API) Respond(w http.ResponseWriter, r *http.Request, status int, v interface{}) { + if status == http.StatusNoContent { + w.WriteHeader(status) + return + } + + var writer io.WriteCloser = noopCloser{Writer: w} + // we'll double close to make sure its always closed even + //on issues before the write + defer writer.Close() + + if a != nil && a.encodeGZIP { + w.Header().Set("Content-Encoding", "gzip") + writer = gzip.NewWriter(w) + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + + // this marshal block is to catch failures before they hit the http writer. + // default behavior for http.ResponseWriter is when body is written and no + // status is set, it writes a 200. Or if a status is set before encoding + // and an error occurs, there is no means to write a proper status code + // (i.e. 500) when that is to occur. This brings that step out before + // and then writes the data and sets the status code after marshaling + // succeeds. + var ( + b []byte + err error + ) + if a == nil || a.prettyJSON { + b, err = json.MarshalIndent(v, "", "\t") + } else { + b, err = json.Marshal(v) + } + if err != nil { + a.Err(w, r, err) + return + } + + a.write(w, writer, status, b) +} + +// Write allows the user to write raw bytes to the response writer. This +// operation does not have a fail case, all failures here will be logged. +func (a *API) Write(w http.ResponseWriter, status int, b []byte) { + if status == http.StatusNoContent { + w.WriteHeader(status) + return + } + + var writer io.WriteCloser = noopCloser{Writer: w} + // we'll double close to make sure its always closed even + //on issues before the write + defer writer.Close() + + if a != nil && a.encodeGZIP { + w.Header().Set("Content-Encoding", "gzip") + writer = gzip.NewWriter(w) + } + + a.write(w, writer, status, b) +} + +func (a *API) write(w http.ResponseWriter, wc io.WriteCloser, status int, b []byte) { + w.WriteHeader(status) + if _, err := wc.Write(b); err != nil { + a.logErr("failed to write to response writer", zap.Error(err)) + } + + if err := wc.Close(); err != nil { + a.logErr("failed to close response writer", zap.Error(err)) + } +} + +// Err is used for writing an error to the response. +func (a *API) Err(w http.ResponseWriter, r *http.Request, err error) { + if err == nil { + return + } + + a.logErr("api error encountered", zap.Error(err)) + + v, status, err := a.errFn(r.Context(), err) + if err != nil { + a.logErr("failed to write err to response writer", zap.Error(err)) + a.Respond(w, r, http.StatusInternalServerError, ErrBody{ + Code: "internal error", + Msg: "an unexpected error occurred", + }) + return + } + + if eb, ok := v.(ErrBody); ok { + w.Header().Set(PlatformErrorCodeHeader, eb.Code) + } + + a.Respond(w, r, status, v) +} + +func (a *API) logErr(msg string, fields ...zap.Field) { + if a == nil || a.logger == nil { + return + } + a.logger.Error(msg, fields...) +} + +type noopCloser struct { + io.Writer +} + +func (n noopCloser) Close() error { + return nil +} + +// ErrBody is an err response body. +type ErrBody struct { + Code string `json:"code"` + Msg string `json:"message"` +} diff --git a/kit/transport/http/api_test.go b/kit/transport/http/api_test.go new file mode 100644 index 0000000000..5474ceed9a --- /dev/null +++ b/kit/transport/http/api_test.go @@ -0,0 +1,309 @@ +package http_test + +import ( + "bytes" + "encoding/gob" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "testing" + + "github.com/influxdata/influxdb/kit/platform/errors" + kithttp "github.com/influxdata/influxdb/kit/transport/http" + "github.com/influxdata/influxdb/pkg/testttp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_API(t *testing.T) { + t.Run("Decode", func(t *testing.T) { + t.Run("valid foo no errors", func(t *testing.T) { + expected := validatFoo{ + Foo: "valid", + Bar: 10, + } + + t.Run("json", func(t *testing.T) { + var api *kithttp.API // shows it is safe to use a nil value + + var out validatFoo + err := api.DecodeJSON(encodeJSON(t, expected), &out) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + + if expected != out { + t.Fatalf("unexpected vals:\n\texpected: %#v\n\tgot: %#v", expected, out) + } + }) + + t.Run("gob", func(t *testing.T) { + var out validatFoo + err := kithttp.NewAPI().DecodeGob(encodeGob(t, expected), &out) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + + if expected != out { + t.Fatalf("unexpected vals:\n\texpected: %#v\n\tgot: %#v", expected, out) + } + }) + }) + + t.Run("unmarshals fine with ok error", func(t *testing.T) { + badFoo := validatFoo{ + Foo: "", + Bar: 0, + } + + t.Run("json", func(t *testing.T) { + var out validatFoo + err := kithttp.NewAPI().DecodeJSON(encodeJSON(t, badFoo), &out) + if err == nil { + t.Fatal("expected an err") + } + }) + + t.Run("gob", func(t *testing.T) { + var out validatFoo + err := kithttp.NewAPI().DecodeGob(encodeGob(t, badFoo), &out) + if err == nil { + t.Fatal("expected an err") + } + }) + }) + + t.Run("unmarshal error", func(t *testing.T) { + invalidBody := []byte("[}-{]") + + var out validatFoo + err := kithttp.NewAPI().DecodeJSON(bytes.NewReader(invalidBody), &out) + if err == nil { + t.Fatal("expected an error") + } + }) + + t.Run("unmarshal err fn wraps unmarshalling error", func(t *testing.T) { + t.Run("json", func(t *testing.T) { + invalidBody := []byte("[}-{]") + + api := kithttp.NewAPI(kithttp.WithUnmarshalErrFn(unmarshalErrFn)) + + var out validatFoo + err := api.DecodeJSON(bytes.NewReader(invalidBody), &out) + expectInfluxdbError(t, errors.EInvalid, err) + }) + + t.Run("gob", func(t *testing.T) { + invalidBody := []byte("[}-{]") + + api := kithttp.NewAPI(kithttp.WithUnmarshalErrFn(unmarshalErrFn)) + + var out validatFoo + err := api.DecodeGob(bytes.NewReader(invalidBody), &out) + expectInfluxdbError(t, errors.EInvalid, err) + }) + }) + + t.Run("ok error fn wraps ok error", func(t *testing.T) { + badFoo := validatFoo{Foo: ""} + + t.Run("json", func(t *testing.T) { + api := kithttp.NewAPI(kithttp.WithOKErrFn(okErrFn)) + + var out validatFoo + err := api.DecodeJSON(encodeJSON(t, badFoo), &out) + expectInfluxdbError(t, errors.EUnprocessableEntity, err) + }) + + t.Run("gob", func(t *testing.T) { + api := kithttp.NewAPI(kithttp.WithOKErrFn(okErrFn)) + + var out validatFoo + err := api.DecodeGob(encodeGob(t, badFoo), &out) + expectInfluxdbError(t, errors.EUnprocessableEntity, err) + }) + }) + }) + + t.Run("Respond", func(t *testing.T) { + tests := []int{ + http.StatusCreated, + http.StatusOK, + http.StatusNoContent, + http.StatusForbidden, + http.StatusInternalServerError, + } + + for _, statusCode := range tests { + fn := func(t *testing.T) { + responder := kithttp.NewAPI() + + svr := func(w http.ResponseWriter, r *http.Request) { + responder.Respond(w, r, statusCode, map[string]string{ + "foo": "bar", + }) + } + + expectedBodyFn := func(body *bytes.Buffer) { + var resp map[string]string + require.NoError(t, json.NewDecoder(body).Decode(&resp)) + assert.Equal(t, "bar", resp["foo"]) + } + if statusCode == http.StatusNoContent { + expectedBodyFn = func(body *bytes.Buffer) { + require.Zero(t, body.Len()) + } + } + + testttp. + Get(t, "/foo"). + Do(http.HandlerFunc(svr)). + ExpectStatus(statusCode). + ExpectBody(expectedBodyFn) + } + t.Run(http.StatusText(statusCode), fn) + } + }) + + t.Run("Err", func(t *testing.T) { + tests := []struct { + statusCode int + expectedErr *errors.Error + }{ + { + statusCode: http.StatusBadRequest, + expectedErr: &errors.Error{ + Code: errors.EInvalid, + Msg: "failed to unmarshal", + }, + }, + { + statusCode: http.StatusForbidden, + expectedErr: &errors.Error{ + Code: errors.EForbidden, + Msg: "forbidden", + }, + }, + { + statusCode: http.StatusUnprocessableEntity, + expectedErr: &errors.Error{ + Code: errors.EUnprocessableEntity, + Msg: "failed validation", + }, + }, + { + statusCode: http.StatusInternalServerError, + expectedErr: &errors.Error{ + Code: errors.EInternal, + Msg: "internal error", + }, + }, + } + + for _, tt := range tests { + fn := func(t *testing.T) { + responder := kithttp.NewAPI() + + svr := func(w http.ResponseWriter, r *http.Request) { + responder.Err(w, r, tt.expectedErr) + } + + testttp. + Get(t, "/foo"). + Do(http.HandlerFunc(svr)). + ExpectStatus(tt.statusCode). + ExpectBody(func(body *bytes.Buffer) { + var err kithttp.ErrBody + require.NoError(t, json.NewDecoder(body).Decode(&err)) + assert.Equal(t, tt.expectedErr.Msg, err.Msg) + assert.Equal(t, tt.expectedErr.Code, err.Code) + }) + } + t.Run(http.StatusText(tt.statusCode), fn) + } + }) +} + +func expectInfluxdbError(t *testing.T, expectedCode string, err error) { + t.Helper() + + if err == nil { + t.Fatal("expected an error") + } + iErr, ok := err.(*errors.Error) + if !ok { + t.Fatalf("expected an influxdb error; got=%#v", err) + } + + if got := iErr.Code; expectedCode != got { + t.Fatalf("unexpected error code; expected=%s got=%s", expectedCode, got) + } +} + +func encodeGob(t *testing.T, v interface{}) io.Reader { + t.Helper() + + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(v); err != nil { + t.Fatal(err) + } + return &buf +} + +func encodeJSON(t *testing.T, v interface{}) io.Reader { + t.Helper() + + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(v); err != nil { + t.Fatal(err) + } + return &buf +} + +func okErrFn(err error) error { + return &errors.Error{ + Code: errors.EUnprocessableEntity, + Msg: "failed validation", + Err: err, + } +} + +func unmarshalErrFn(encoding string, err error) error { + return &errors.Error{ + Code: errors.EInvalid, + Msg: fmt.Sprintf("invalid %s request body", encoding), + Err: err, + } +} + +type validatFoo struct { + Foo string `gob:"foo"` + Bar int `gob:"bar"` +} + +func (v *validatFoo) OK() error { + var errs multiErr + if v.Foo == "" { + errs = append(errs, "foo must be at least 1 char") + } + if v.Bar < 0 { + errs = append(errs, "bar must be a positive real number") + } + return errs.toError() +} + +type multiErr []string + +func (m multiErr) toError() error { + if len(m) > 0 { + return m + } + return nil +} + +func (m multiErr) Error() string { + return strings.Join(m, "; ") +} diff --git a/kit/transport/http/error_handler.go b/kit/transport/http/error_handler.go new file mode 100644 index 0000000000..7f51ee6e61 --- /dev/null +++ b/kit/transport/http/error_handler.go @@ -0,0 +1,188 @@ +package http + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "mime" + "net/http" + "strings" + + errors2 "github.com/influxdata/influxdb/kit/platform/errors" +) + +// ErrorHandler is the error handler in http package. +type ErrorHandler int + +// HandleHTTPError encodes err with the appropriate status code and format, +// sets the X-Platform-Error-Code headers on the response. +// We're no longer using X-Influx-Error and X-Influx-Reference. +// and sets the response status to the corresponding status code. +func (h ErrorHandler) HandleHTTPError(ctx context.Context, err error, w http.ResponseWriter) { + if err == nil { + return + } + + code := errors2.ErrorCode(err) + var msg string + if err, ok := err.(*errors2.Error); ok { + msg = err.Error() + } else { + msg = "An internal error has occurred" + } + + WriteErrorResponse(ctx, w, code, msg) +} + +func WriteErrorResponse(ctx context.Context, w http.ResponseWriter, code string, msg string) { + w.Header().Set(PlatformErrorCodeHeader, code) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(ErrorCodeToStatusCode(ctx, code)) + e := struct { + Code string `json:"code"` + Message string `json:"message"` + }{ + Code: code, + Message: msg, + } + b, _ := json.Marshal(e) + _, _ = w.Write(b) +} + +// StatusCodeToErrorCode maps a http status code integer to an +// influxdb error code string. +func StatusCodeToErrorCode(statusCode int) string { + errorCode, ok := httpStatusCodeToInfluxDBError[statusCode] + if ok { + return errorCode + } + + return errors2.EInternal +} + +// ErrorCodeToStatusCode maps an influxdb error code string to a +// http status code integer. +func ErrorCodeToStatusCode(ctx context.Context, code string) int { + // If the client disconnects early or times out then return a different + // error than the passed in error code. Client timeouts return a 408 + // while disconnections return a non-standard Nginx HTTP 499 code. + if err := ctx.Err(); err == context.DeadlineExceeded { + return http.StatusRequestTimeout + } else if err == context.Canceled { + return 499 // https://httpstatuses.com/499 + } + + // Otherwise map internal error codes to HTTP status codes. + statusCode, ok := influxDBErrorToStatusCode[code] + if ok { + return statusCode + } + return http.StatusInternalServerError +} + +// influxDBErrorToStatusCode is a mapping of ErrorCode to http status code. +var influxDBErrorToStatusCode = map[string]int{ + errors2.EInternal: http.StatusInternalServerError, + errors2.ENotImplemented: http.StatusNotImplemented, + errors2.EInvalid: http.StatusBadRequest, + errors2.EUnprocessableEntity: http.StatusUnprocessableEntity, + errors2.EEmptyValue: http.StatusBadRequest, + errors2.EConflict: http.StatusUnprocessableEntity, + errors2.ENotFound: http.StatusNotFound, + errors2.EUnavailable: http.StatusServiceUnavailable, + errors2.EForbidden: http.StatusForbidden, + errors2.ETooManyRequests: http.StatusTooManyRequests, + errors2.EUnauthorized: http.StatusUnauthorized, + errors2.EMethodNotAllowed: http.StatusMethodNotAllowed, + errors2.ETooLarge: http.StatusRequestEntityTooLarge, +} + +var httpStatusCodeToInfluxDBError = map[int]string{} + +func init() { + for k, v := range influxDBErrorToStatusCode { + httpStatusCodeToInfluxDBError[v] = k + } +} + +// CheckErrorStatus for status and any error in the response. +func CheckErrorStatus(code int, res *http.Response) error { + err := CheckError(res) + if err != nil { + return err + } + + if res.StatusCode != code { + return fmt.Errorf("unexpected status code: %s", res.Status) + } + + return nil +} + +// CheckError reads the http.Response and returns an error if one exists. +// It will automatically recognize the errors returned by Influx services +// and decode the error into an internal error type. If the error cannot +// be determined in that way, it will create a generic error message. +// +// If there is no error, then this returns nil. +func CheckError(resp *http.Response) (err error) { + switch resp.StatusCode / 100 { + case 4, 5: + // We will attempt to parse this error outside of this block. + case 2: + return nil + default: + // TODO(jsternberg): Figure out what to do here? + return &errors2.Error{ + Code: errors2.EInternal, + Msg: fmt.Sprintf("unexpected status code: %d %s", resp.StatusCode, resp.Status), + } + } + + perr := &errors2.Error{ + Code: StatusCodeToErrorCode(resp.StatusCode), + } + + if resp.StatusCode == http.StatusUnsupportedMediaType { + perr.Msg = fmt.Sprintf("invalid media type: %q", resp.Header.Get("Content-Type")) + return perr + } + + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + // Assume JSON if there is no content-type. + contentType = "application/json" + } + mediatype, _, _ := mime.ParseMediaType(contentType) + + var buf bytes.Buffer + if _, err := io.Copy(&buf, resp.Body); err != nil { + perr.Msg = "failed to read error response" + perr.Err = err + return perr + } + + switch mediatype { + case "application/json": + if err := json.Unmarshal(buf.Bytes(), perr); err != nil { + perr.Msg = fmt.Sprintf("attempted to unmarshal error as JSON but failed: %q", err) + perr.Err = firstLineAsError(buf) + } + default: + perr.Err = firstLineAsError(buf) + } + + if perr.Code == "" { + // given it was unset during attempt to unmarshal as JSON + perr.Code = StatusCodeToErrorCode(resp.StatusCode) + } + + return perr +} +func firstLineAsError(buf bytes.Buffer) error { + line, _ := buf.ReadString('\n') + return errors.New(strings.TrimSuffix(line, "\n")) +} diff --git a/kit/transport/http/error_handler_test.go b/kit/transport/http/error_handler_test.go new file mode 100644 index 0000000000..9e3c1a1c3d --- /dev/null +++ b/kit/transport/http/error_handler_test.go @@ -0,0 +1,56 @@ +package http_test + +import ( + "context" + "fmt" + "net/http/httptest" + "testing" + + "github.com/influxdata/influxdb/kit/platform/errors" + kithttp "github.com/influxdata/influxdb/kit/transport/http" +) + +func TestEncodeError(t *testing.T) { + ctx := context.TODO() + + w := httptest.NewRecorder() + + kithttp.ErrorHandler(0).HandleHTTPError(ctx, nil, w) + + if w.Code != 200 { + t.Errorf("expected status code 200, got: %d", w.Code) + } +} + +func TestEncodeErrorWithError(t *testing.T) { + ctx := context.TODO() + err := &errors.Error{ + Code: errors.EInternal, + Msg: "an error occurred", + Err: fmt.Errorf("there's an error here, be aware"), + } + + w := httptest.NewRecorder() + + kithttp.ErrorHandler(0).HandleHTTPError(ctx, err, w) + + if w.Code != 500 { + t.Errorf("expected status code 500, got: %d", w.Code) + } + + errHeader := w.Header().Get("X-Platform-Error-Code") + if errHeader != errors.EInternal { + t.Errorf("expected X-Platform-Error-Code: %s, got: %s", errors.EInternal, errHeader) + } + + // The http handler will flatten the message and it will not + // have an error property, so reading the serialization results + // in a different error. + pe := kithttp.CheckError(w.Result()).(*errors.Error) + if want, got := errors.EInternal, pe.Code; want != got { + t.Errorf("unexpected code -want/+got:\n\t- %q\n\t+ %q", want, got) + } + if want, got := "an error occurred: there's an error here, be aware", pe.Msg; want != got { + t.Errorf("unexpected message -want/+got:\n\t- %q\n\t+ %q", want, got) + } +} diff --git a/kit/transport/http/feature_controller.go b/kit/transport/http/feature_controller.go new file mode 100644 index 0000000000..ab10d1243d --- /dev/null +++ b/kit/transport/http/feature_controller.go @@ -0,0 +1,39 @@ +package http + +import ( + "context" + "net/http" + + "github.com/influxdata/influxdb/kit/feature" +) + +// Enabler allows the switching between two HTTP Handlers +type Enabler interface { + Enabled(ctx context.Context, flagger ...feature.Flagger) bool +} + +// FeatureHandler is used to switch requests between an existing and a feature flagged +// HTTP Handler on a per-request basis +type FeatureHandler struct { + enabler Enabler + flagger feature.Flagger + oldHandler http.Handler + newHandler http.Handler + prefix string +} + +func NewFeatureHandler(e Enabler, f feature.Flagger, old, new http.Handler, prefix string) *FeatureHandler { + return &FeatureHandler{e, f, old, new, prefix} +} + +func (h *FeatureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if h.enabler.Enabled(r.Context(), h.flagger) { + h.newHandler.ServeHTTP(w, r) + return + } + h.oldHandler.ServeHTTP(w, r) +} + +func (h *FeatureHandler) Prefix() string { + return h.prefix +} diff --git a/kit/transport/http/handler.go b/kit/transport/http/handler.go new file mode 100644 index 0000000000..ba83a61baf --- /dev/null +++ b/kit/transport/http/handler.go @@ -0,0 +1,11 @@ +package http + +import "net/http" + +// ResourceHandler is an HTTP handler for a resource. The prefix +// describes the url path prefix that relates to the handler +// endpoints. +type ResourceHandler interface { + Prefix() string + http.Handler +} diff --git a/kit/transport/http/middleware.go b/kit/transport/http/middleware.go new file mode 100644 index 0000000000..0fcc7ee302 --- /dev/null +++ b/kit/transport/http/middleware.go @@ -0,0 +1,182 @@ +package http + +import ( + "context" + "fmt" + "net/http" + "path" + "strings" + "time" + + "github.com/influxdata/influxdb/kit/platform" + "github.com/influxdata/influxdb/kit/platform/errors" + + "github.com/go-chi/chi" + "github.com/influxdata/influxdb/kit/tracing" + ua "github.com/mileusna/useragent" + "github.com/prometheus/client_golang/prometheus" +) + +// Middleware constructor. +type Middleware func(http.Handler) http.Handler + +func SetCORS(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + if origin := r.Header.Get("Origin"); origin != "" { + // Access-Control-Allow-Origin must be present in every response + w.Header().Set("Access-Control-Allow-Origin", origin) + } + if r.Method == http.MethodOptions { + // allow and stop processing in pre-flight requests + w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH") + w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, Authorization, User-Agent") + w.WriteHeader(http.StatusNoContent) + return + } + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) +} + +func Metrics(name string, reqMetric *prometheus.CounterVec, durMetric *prometheus.HistogramVec) Middleware { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + statusW := NewStatusResponseWriter(w) + + defer func(start time.Time) { + label := prometheus.Labels{ + "handler": name, + "method": r.Method, + "path": normalizePath(r.URL.Path), + "status": statusW.StatusCodeClass(), + "response_code": fmt.Sprintf("%d", statusW.Code()), + "user_agent": UserAgent(r), + } + durMetric.With(label).Observe(time.Since(start).Seconds()) + reqMetric.With(label).Inc() + }(time.Now()) + + next.ServeHTTP(statusW, r) + } + return http.HandlerFunc(fn) + } +} + +func SkipOptions(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + // Preflight CORS requests from the browser will send an options request, + // so we need to make sure we satisfy them + if origin := r.Header.Get("Origin"); origin == "" && r.Method == http.MethodOptions { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) +} + +func Trace(name string) Middleware { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + span, r := tracing.ExtractFromHTTPRequest(r, name) + defer span.Finish() + + span.LogKV("user_agent", UserAgent(r)) + for k, v := range r.Header { + if len(v) == 0 { + continue + } + + if k == "Authorization" || k == "User-Agent" { + continue + } + + // If header has multiple values, only the first value will be logged on the trace. + span.LogKV(k, v[0]) + } + + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } +} + +func UserAgent(r *http.Request) string { + header := r.Header.Get("User-Agent") + if header == "" { + return "unknown" + } + + return ua.Parse(header).Name +} + +func normalizePath(p string) string { + var parts []string + for head, tail := shiftPath(p); ; head, tail = shiftPath(tail) { + piece := head + if len(piece) == platform.IDLength { + if _, err := platform.IDFromString(head); err == nil { + piece = ":id" + } + } + parts = append(parts, piece) + if tail == "/" { + break + } + } + return "/" + path.Join(parts...) +} + +func shiftPath(p string) (head, tail string) { + p = path.Clean("/" + p) + i := strings.Index(p[1:], "/") + 1 + if i <= 0 { + return p[1:], "/" + } + return p[1:i], p[i:] +} + +type OrgContext string + +const CtxOrgKey OrgContext = "orgID" + +// ValidResource make sure a resource exists when a sub system needs to be mounted to an api +func ValidResource(api *API, lookupOrgByResourceID func(context.Context, platform.ID) (platform.ID, error)) Middleware { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + statusW := NewStatusResponseWriter(w) + id, err := platform.IDFromString(chi.URLParam(r, "id")) + if err != nil { + api.Err(w, r, platform.ErrCorruptID(err)) + return + } + + ctx := r.Context() + + orgID, err := lookupOrgByResourceID(ctx, *id) + if err != nil { + // if this function returns an error we will squash the error message and replace it with a not found error + api.Err(w, r, &errors.Error{ + Code: errors.ENotFound, + Msg: "404 page not found", + }) + return + } + + // embed OrgID into context + next.ServeHTTP(statusW, r.WithContext(context.WithValue(ctx, CtxOrgKey, orgID))) + } + return http.HandlerFunc(fn) + } +} + +// OrgIDFromContext .... +func OrgIDFromContext(ctx context.Context) *platform.ID { + v := ctx.Value(CtxOrgKey) + if v == nil { + return nil + } + id := v.(platform.ID) + return &id +} diff --git a/kit/transport/http/middleware_test.go b/kit/transport/http/middleware_test.go new file mode 100644 index 0000000000..fefadcf207 --- /dev/null +++ b/kit/transport/http/middleware_test.go @@ -0,0 +1,95 @@ +package http + +import ( + "net/http" + "path" + "testing" + + "github.com/influxdata/influxdb/kit/platform" + + "github.com/influxdata/influxdb/pkg/testttp" + "github.com/stretchr/testify/assert" +) + +func Test_normalizePath(t *testing.T) { + tests := []struct { + name string + path string + expected string + }{ + { + name: "1", + path: path.Join("/api/v2/organizations", platform.ID(2).String()), + expected: "/api/v2/organizations/:id", + }, + { + name: "2", + path: "/api/v2/organizations", + expected: "/api/v2/organizations", + }, + { + name: "3", + path: "/", + expected: "/", + }, + { + name: "4", + path: path.Join("/api/v2/organizations", platform.ID(2).String(), "users", platform.ID(3).String()), + expected: "/api/v2/organizations/:id/users/:id", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := normalizePath(tt.path) + assert.Equal(t, tt.expected, actual) + }) + } +} + +func TestCors(t *testing.T) { + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("nextHandler")) + }) + + tests := []struct { + name string + method string + headers []string + expectedStatus int + expectedHeaders map[string]string + }{ + { + name: "OPTIONS without Origin", + method: "OPTIONS", + expectedStatus: http.StatusMethodNotAllowed, + }, + { + name: "OPTIONS with Origin", + method: "OPTIONS", + headers: []string{"Origin", "http://myapp.com"}, + expectedStatus: http.StatusNoContent, + }, + { + name: "GET with Origin", + method: "GET", + headers: []string{"Origin", "http://anotherapp.com"}, + expectedStatus: http.StatusOK, + expectedHeaders: map[string]string{ + "Access-Control-Allow-Origin": "http://anotherapp.com", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svr := SkipOptions(SetCORS(nextHandler)) + + testttp. + HTTP(t, tt.method, "/", nil). + Headers("", "", tt.headers...). + Do(svr). + ExpectStatus(tt.expectedStatus). + ExpectHeaders(tt.expectedHeaders) + }) + } +} diff --git a/kit/transport/http/status_response_writer.go b/kit/transport/http/status_response_writer.go new file mode 100644 index 0000000000..68efba4828 --- /dev/null +++ b/kit/transport/http/status_response_writer.go @@ -0,0 +1,58 @@ +package http + +import "net/http" + +type StatusResponseWriter struct { + statusCode int + responseBytes int + http.ResponseWriter +} + +func NewStatusResponseWriter(w http.ResponseWriter) *StatusResponseWriter { + return &StatusResponseWriter{ + ResponseWriter: w, + } +} + +func (w *StatusResponseWriter) Write(b []byte) (int, error) { + n, err := w.ResponseWriter.Write(b) + w.responseBytes += n + return n, err +} + +// WriteHeader writes the header and captures the status code. +func (w *StatusResponseWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode + w.ResponseWriter.WriteHeader(statusCode) +} + +func (w *StatusResponseWriter) Code() int { + code := w.statusCode + if code == 0 { + // When statusCode is 0 then WriteHeader was never called and we can assume that + // the ResponseWriter wrote an http.StatusOK. + code = http.StatusOK + } + return code +} + +func (w *StatusResponseWriter) ResponseBytes() int { + return w.responseBytes +} + +func (w *StatusResponseWriter) StatusCodeClass() string { + class := "XXX" + switch w.Code() / 100 { + case 1: + class = "1XX" + case 2: + class = "2XX" + case 3: + class = "3XX" + case 4: + class = "4XX" + case 5: + class = "5XX" + } + return class +} diff --git a/pkg/testttp/http.go b/pkg/testttp/http.go new file mode 100644 index 0000000000..05e4e408b6 --- /dev/null +++ b/pkg/testttp/http.go @@ -0,0 +1,195 @@ +package testttp + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +// Req is a request builder. +type Req struct { + t testing.TB + + req *http.Request +} + +// HTTP runs creates a request for an http call. +func HTTP(t testing.TB, method, addr string, body io.Reader) *Req { + return &Req{ + t: t, + req: httptest.NewRequest(method, addr, body), + } +} + +// Delete creates a DELETE request. +func Delete(t testing.TB, addr string) *Req { + return HTTP(t, http.MethodDelete, addr, nil) +} + +// Get creates a GET request. +func Get(t testing.TB, addr string) *Req { + return HTTP(t, http.MethodGet, addr, nil) +} + +// Patch creates a PATCH request. +func Patch(t testing.TB, addr string, body io.Reader) *Req { + return HTTP(t, http.MethodPatch, addr, body) +} + +// PatchJSON creates a PATCH request with a json encoded body. +func PatchJSON(t testing.TB, addr string, v interface{}) *Req { + return HTTP(t, http.MethodPatch, addr, mustEncodeJSON(t, v)) +} + +// Post creates a POST request. +func Post(t testing.TB, addr string, body io.Reader) *Req { + return HTTP(t, http.MethodPost, addr, body) +} + +// PostJSON returns a POST request with a json encoded body. +func PostJSON(t testing.TB, addr string, v interface{}) *Req { + return Post(t, addr, mustEncodeJSON(t, v)) +} + +// Put creates a PUT request. +func Put(t testing.TB, addr string, body io.Reader) *Req { + return HTTP(t, http.MethodPut, addr, body) +} + +// PutJSON creates a PUT request with a json encoded body. +func PutJSON(t testing.TB, addr string, v interface{}) *Req { + return HTTP(t, http.MethodPut, addr, mustEncodeJSON(t, v)) +} + +// Do runs the request against the provided handler. +func (r *Req) Do(handler http.Handler) *Resp { + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, r.req) + + return &Resp{ + t: r.t, + debug: true, + Req: r.req, + Rec: rec, + } +} + +func (r *Req) SetFormValue(k, v string) *Req { + if r.req.Form == nil { + r.req.Form = make(url.Values) + } + r.req.Form.Set(k, v) + return r +} + +// Headers allows the user to set headers on the http request. +func (r *Req) Headers(k, v string, rest ...string) *Req { + headers := append(rest, k, v) + for i := 0; i < len(headers); i += 2 { + if i+1 >= len(headers) { + break + } + k, v := headers[i], headers[i+1] + r.req.Header.Add(k, v) + } + return r +} + +// WithCtx sets the ctx on the request. +func (r *Req) WithCtx(ctx context.Context) *Req { + r.req = r.req.WithContext(ctx) + return r +} + +// WrapCtx provides means to wrap a request context. This is useful for stuffing in the +// auth stuffs that are required at times. +func (r *Req) WrapCtx(fn func(ctx context.Context) context.Context) *Req { + return r.WithCtx(fn(r.req.Context())) +} + +// Resp is a http recorder wrapper. +type Resp struct { + t testing.TB + + debug bool + + Req *http.Request + Rec *httptest.ResponseRecorder +} + +// Debug sets the debugger. If true, the debugger will print the body of the response +// when the expected status is not received. +func (r *Resp) Debug(b bool) *Resp { + r.debug = b + return r +} + +// Expect allows the assertions against the raw Resp. +func (r *Resp) Expect(fn func(*Resp)) *Resp { + fn(r) + return r +} + +// ExpectStatus compares the expected status code against the recorded status code. +func (r *Resp) ExpectStatus(code int) *Resp { + r.t.Helper() + + if r.Rec.Code != code { + r.t.Errorf("unexpected status code: expected=%d got=%d", code, r.Rec.Code) + if r.debug { + r.t.Logf("body: %v", r.Rec.Body.String()) + } + } + return r +} + +// ExpectBody provides an assertion against the recorder body. +func (r *Resp) ExpectBody(fn func(body *bytes.Buffer)) *Resp { + fn(r.Rec.Body) + return r +} + +// ExpectHeaders asserts that multiple headers with values exist in the recorder. +func (r *Resp) ExpectHeaders(h map[string]string) *Resp { + for k, v := range h { + r.ExpectHeader(k, v) + } + + return r +} + +// ExpectHeader asserts that the header is in the recorder. +func (r *Resp) ExpectHeader(k, v string) *Resp { + r.t.Helper() + + vals, ok := r.Rec.Header()[k] + if !ok { + r.t.Errorf("did not find expected header: %q", k) + return r + } + + for _, vv := range vals { + if vv == v { + return r + } + } + r.t.Errorf("did not find expected value for header %q; got: %v", k, vals) + + return r +} + +func mustEncodeJSON(t testing.TB, v interface{}) *bytes.Buffer { + t.Helper() + + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(v); err != nil { + t.Fatal(err) + } + return &buf +} diff --git a/pkg/testttp/http_test.go b/pkg/testttp/http_test.go new file mode 100644 index 0000000000..750dac0d56 --- /dev/null +++ b/pkg/testttp/http_test.go @@ -0,0 +1,174 @@ +package testttp_test + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "strings" + "testing" + + "github.com/influxdata/influxdb/pkg/testttp" +) + +func TestHTTP(t *testing.T) { + svr := newMux() + t.Run("Delete", func(t *testing.T) { + testttp. + Delete(t, "/"). + Do(svr). + ExpectStatus(http.StatusNoContent) + }) + + t.Run("Get", func(t *testing.T) { + testttp. + Get(t, "/"). + Do(svr). + ExpectStatus(http.StatusOK). + ExpectBody(assertBody(t, http.MethodGet)) + }) + + t.Run("Patch", func(t *testing.T) { + testttp. + Patch(t, "/", nil). + Do(svr). + ExpectStatus(http.StatusPartialContent). + ExpectBody(assertBody(t, http.MethodPatch)) + }) + + t.Run("PatchJSON", func(t *testing.T) { + testttp. + PatchJSON(t, "/", map[string]string{"k": "t"}). + Do(svr). + ExpectStatus(http.StatusPartialContent). + ExpectBody(assertBody(t, http.MethodPatch)) + }) + + t.Run("Post", func(t *testing.T) { + t.Run("basic", func(t *testing.T) { + testttp. + Post(t, "/", nil). + Do(svr). + ExpectStatus(http.StatusCreated). + ExpectBody(assertBody(t, http.MethodPost)) + }) + + t.Run("with form values", func(t *testing.T) { + svr := http.NewServeMux() + svr.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + w.WriteHeader(http.StatusOK) + w.Write([]byte(r.FormValue("key"))) + })) + + testttp. + Post(t, "/", nil). + SetFormValue("key", "val"). + Do(svr). + ExpectStatus(http.StatusOK). + ExpectBody(func(body *bytes.Buffer) { + if expected, got := "val", body.String(); expected != got { + t.Fatalf("did not get form value; expected=%q got=%q", expected, got) + } + }) + }) + }) + + t.Run("PostJSON", func(t *testing.T) { + testttp. + PostJSON(t, "/", map[string]string{"k": "v"}). + Do(svr). + ExpectStatus(http.StatusCreated). + ExpectBody(assertBody(t, http.MethodPost)) + }) + + t.Run("Put", func(t *testing.T) { + testttp. + Put(t, "/", nil). + Do(svr). + ExpectStatus(http.StatusAccepted). + ExpectBody(assertBody(t, http.MethodPut)) + }) + + t.Run("PutJSON", func(t *testing.T) { + testttp. + PutJSON(t, "/", map[string]string{"k": "t"}). + Do(svr). + ExpectStatus(http.StatusAccepted). + ExpectBody(assertBody(t, http.MethodPut)) + }) + + t.Run("Headers", func(t *testing.T) { + testttp. + Post(t, "/", strings.NewReader(`a: foo`)). + Headers("Content-Type", "text/yml"). + Do(svr). + Expect(func(resp *testttp.Resp) { + equals(t, "text/yml", resp.Req.Header.Get("Content-Type")) + }) + }) +} + +type foo struct { + Name, Thing, Method string +} + +func newMux() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + switch req.Method { + case http.MethodGet: + writeFn(w, req.Method, http.StatusOK) + case http.MethodPost: + writeFn(w, req.Method, http.StatusCreated) + case http.MethodPut: + writeFn(w, req.Method, http.StatusAccepted) + case http.MethodPatch: + writeFn(w, req.Method, http.StatusPartialContent) + case http.MethodDelete: + w.WriteHeader(http.StatusNoContent) + } + }) + return mux +} + +func assertBody(t *testing.T, method string) func(*bytes.Buffer) { + return func(buf *bytes.Buffer) { + var f foo + if err := json.NewDecoder(buf).Decode(&f); err != nil { + t.Fatal(err) + } + expected := foo{Name: "name", Thing: "thing", Method: method} + equals(t, expected, f) + } +} + +func writeFn(w http.ResponseWriter, method string, statusCode int) { + f := foo{Name: "name", Thing: "thing", Method: method} + r, err := encodeBuf(f) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(statusCode) + if _, err := io.Copy(w, r); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } +} + +func equals(t *testing.T, expected, actual interface{}) { + t.Helper() + if expected == actual { + return + } + t.Errorf("expected: %v\tactual: %v", expected, actual) +} + +func encodeBuf(v interface{}) (io.Reader, error) { + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(v); err != nil { + return nil, err + } + return &buf, nil +}