chore: Add kit (#21086)
* chore: pull in unchanged kit from v2 * chore: remove v2 from import paths * chore: update module paths and go.mod for kit * chore: remove kit/cli again, not needed in 1.xpull/21103/head
parent
fbfd4b4651
commit
78724e5c20
6
go.mod
6
go.mod
|
@ -12,10 +12,12 @@ require (
|
|||
github.com/davecgh/go-spew v1.1.1
|
||||
github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1
|
||||
github.com/dgryski/go-bitstream v0.0.0-20180413035011-3522498ce2c8
|
||||
github.com/go-chi/chi v4.1.0+incompatible
|
||||
github.com/gogo/protobuf v1.3.1
|
||||
github.com/golang/snappy v0.0.1
|
||||
github.com/google/go-cmp v0.5.0
|
||||
github.com/influxdata/flux v0.111.0
|
||||
github.com/influxdata/httprouter v1.3.1-0.20191122104820-ee83e2772f69
|
||||
github.com/influxdata/influxql v1.1.1-0.20210223160523-b6ab99450c93
|
||||
github.com/influxdata/pkg-config v0.2.7
|
||||
github.com/influxdata/roaring v0.4.13-0.20180809181101-fc520f41fab6
|
||||
|
@ -24,16 +26,20 @@ require (
|
|||
github.com/jwilder/encoding v0.0.0-20170811194829-b4e1701a28ef
|
||||
github.com/klauspost/pgzip v1.0.2-0.20170402124221-0bf5dcad4ada
|
||||
github.com/mattn/go-isatty v0.0.12
|
||||
github.com/mileusna/useragent v0.0.0-20190129205925-3e331f0949a5
|
||||
github.com/opentracing/opentracing-go v1.1.0
|
||||
github.com/peterh/liner v1.0.1-0.20180619022028-8c1271fcf47f
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/prometheus/client_golang v1.5.1
|
||||
github.com/prometheus/client_model v0.2.0
|
||||
github.com/prometheus/common v0.9.1
|
||||
github.com/prometheus/prometheus v0.0.0-20200609090129-a6600f564e3c
|
||||
github.com/retailnext/hllpp v1.0.1-0.20180308014038-101a6d2f8b52
|
||||
github.com/spf13/cast v1.3.0
|
||||
github.com/spf13/cobra v0.0.3
|
||||
github.com/stretchr/testify v1.5.1
|
||||
github.com/tinylib/msgp v1.1.0
|
||||
github.com/uber/jaeger-client-go v2.23.0+incompatible
|
||||
github.com/willf/bitset v1.1.9 // indirect
|
||||
github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6
|
||||
go.uber.org/zap v1.14.1
|
||||
|
|
16
go.sum
16
go.sum
|
@ -152,6 +152,7 @@ github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec/go.mod h1:jMjuTZXRI4
|
|||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||
github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8=
|
||||
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd h1:qMd81Ts1T2OTKmB4acZcyKaMtRnY5Y44NuXGX2GFJ1w=
|
||||
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI=
|
||||
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
|
||||
github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
||||
|
@ -199,6 +200,7 @@ github.com/foxcpp/go-mockdns v0.0.0-20201212160233-ede2f9158d15 h1:nLPjjvpUAODOR
|
|||
github.com/foxcpp/go-mockdns v0.0.0-20201212160233-ede2f9158d15/go.mod h1:tPg4cp4nseejPd+UKxtCVQ2hUxNTZ7qQZJa7CLriIeo=
|
||||
github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4=
|
||||
github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20=
|
||||
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
|
||||
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
|
||||
|
@ -208,6 +210,8 @@ github.com/glycerine/go-unsnap-stream v0.0.0-20180323001048-9f0cb55181dd h1:r04M
|
|||
github.com/glycerine/go-unsnap-stream v0.0.0-20180323001048-9f0cb55181dd/go.mod h1:/20jfyN9Y5QPEAprSgKAUr+glWDY39ZiUEAYOEv5dsE=
|
||||
github.com/glycerine/goconvey v0.0.0-20190410193231-58a59202ab31 h1:gclg6gY70GLy3PbkQ1AERPfmLMMagS60DKF78eWwLn8=
|
||||
github.com/glycerine/goconvey v0.0.0-20190410193231-58a59202ab31/go.mod h1:Ogl1Tioa0aV7gstGFO7KhffUsb9M4ydbEbbxpcEDc24=
|
||||
github.com/go-chi/chi v4.1.0+incompatible h1:ETj3cggsVIY2Xao5ExCu6YhEh5MD6JTfcBzS37R260w=
|
||||
github.com/go-chi/chi v4.1.0+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ=
|
||||
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
|
@ -434,6 +438,8 @@ github.com/influxdata/flux v0.65.0 h1:57tk1Oo4gpGIDbV12vUAPCMtLtThhaXzub1XRIuqv6
|
|||
github.com/influxdata/flux v0.65.0/go.mod h1:BwN2XG2lMszOoquQaFdPET8FRQfrXiZsWmcMO9rkaVY=
|
||||
github.com/influxdata/flux v0.111.0 h1:27CNz0SbEofD9NzdwcdxRwGmuVSDSisVq4dOceB/KF0=
|
||||
github.com/influxdata/flux v0.111.0/go.mod h1:3TJtvbm/Kwuo5/PEo5P6HUzwVg4bXWkb2wPQHPtQdlU=
|
||||
github.com/influxdata/httprouter v1.3.1-0.20191122104820-ee83e2772f69 h1:WQsmW0fXO4ZE/lFGIE84G6rIV5SJN3P3sjIXAP1a8eU=
|
||||
github.com/influxdata/httprouter v1.3.1-0.20191122104820-ee83e2772f69/go.mod h1:pwymjR6SrP3gD3pRj9RJwdl1j5s3doEEV8gS4X9qSzA=
|
||||
github.com/influxdata/influxdb v1.8.0/go.mod h1:SIzcnsjaHRFpmlxpJ4S3NT64qtEKYweNTUMb/vh0OMQ=
|
||||
github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo=
|
||||
github.com/influxdata/influxql v1.1.0/go.mod h1:KpVI7okXjK6PRi3Z5B+mtKZli+R1DnZgb3N+tzevNgo=
|
||||
|
@ -447,7 +453,6 @@ github.com/influxdata/pkg-config v0.2.7/go.mod h1:EMS7Ll0S4qkzDk53XS3Z72/egBsPIn
|
|||
github.com/influxdata/promql/v2 v2.12.0/go.mod h1:fxOPu+DY0bqCTCECchSRtWfc+0X19ybifQhZoQNF5D8=
|
||||
github.com/influxdata/roaring v0.4.13-0.20180809181101-fc520f41fab6 h1:UzJnB7VRL4PSkUJHwsyzseGOmrO/r4yA+AuxGJxiZmA=
|
||||
github.com/influxdata/roaring v0.4.13-0.20180809181101-fc520f41fab6/go.mod h1:bSgUQ7q5ZLSO+bKBGqJiCBGAl+9DxyW63zLTujjUlOE=
|
||||
github.com/influxdata/tdigest v0.0.0-20181121200506-bf2b5ad3c0a9 h1:MHTrDWmQpHq/hkq+7cw9oYAt2PqUw52TZazRA0N7PGE=
|
||||
github.com/influxdata/tdigest v0.0.0-20181121200506-bf2b5ad3c0a9/go.mod h1:Js0mqiSBE6Ffsg94weZZ2c+v/ciT8QRHFOap7EKDrR0=
|
||||
github.com/influxdata/tdigest v0.0.2-0.20210216194612-fc98d27c9e8b h1:i44CesU68ZBRvtCjBi3QSosCIKrjmMbYlQMFAwVLds4=
|
||||
github.com/influxdata/tdigest v0.0.2-0.20210216194612-fc98d27c9e8b/go.mod h1:Z0kXnxzbTC2qrx4NaIzYkE1k66+6oEDQTvL95hQFh5Y=
|
||||
|
@ -540,6 +545,8 @@ github.com/miekg/dns v1.1.22/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKju
|
|||
github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso=
|
||||
github.com/miekg/dns v1.1.29 h1:xHBEhR+t5RzcFJjBLJlax2daXOrTYtr9z4WdKEfWFzg=
|
||||
github.com/miekg/dns v1.1.29/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
|
||||
github.com/mileusna/useragent v0.0.0-20190129205925-3e331f0949a5 h1:pXqZHmHOz6LN+zbbUgqyGgAWRnnZEI40IzG3tMsXcSI=
|
||||
github.com/mileusna/useragent v0.0.0-20190129205925-3e331f0949a5/go.mod h1:JWhYAp2EXqUtsxTKdeGlY8Wp44M7VxThC9FEoNGi2IE=
|
||||
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
|
||||
github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
|
||||
|
@ -550,6 +557,7 @@ github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS4
|
|||
github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY=
|
||||
github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
||||
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
||||
github.com/mitchellh/mapstructure v1.2.2 h1:dxe5oCinTXiTIcfgmZecdCzPmAJKd46KsCWc35r0TV4=
|
||||
github.com/mitchellh/mapstructure v1.2.2/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
|
@ -596,9 +604,9 @@ github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnh
|
|||
github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM=
|
||||
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
|
||||
github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
|
||||
github.com/paulbellamy/ratecounter v0.2.0 h1:2L/RhJq+HA8gBQImDXtLPrDXK5qAj6ozWVK/zFXVJGs=
|
||||
github.com/paulbellamy/ratecounter v0.2.0/go.mod h1:Hfx1hDpSGoqxkVVpBi/IlYD7kChlfo5C6hzIHwPqfFE=
|
||||
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
|
||||
github.com/pelletier/go-toml v1.4.0 h1:u3Z1r+oOXJIkxqw34zVhyPgjBsm6X2wn21NWs/HfSeg=
|
||||
github.com/pelletier/go-toml v1.4.0/go.mod h1:PN7xzY2wHTK0K9p34ErDQMlFxa51Fk0OUruD3k1mMwo=
|
||||
github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac=
|
||||
github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU=
|
||||
|
@ -690,6 +698,7 @@ github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4k
|
|||
github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY=
|
||||
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ=
|
||||
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
||||
github.com/spf13/afero v1.2.2 h1:5jhuqJyZCZf2JRofRvN/nIFgIWNzPa3/Vz8mYylgbWc=
|
||||
github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk=
|
||||
github.com/spf13/cast v1.3.0 h1:oget//CVOEoFewqQxwr0Ej5yjygnqGkvggSE/gB35Q8=
|
||||
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
|
||||
|
@ -722,7 +731,9 @@ github.com/uber-go/tally v3.3.15+incompatible h1:9hLSgNBP28CjIaDmAuRTq9qV+UZY+9P
|
|||
github.com/uber-go/tally v3.3.15+incompatible/go.mod h1:YDTIBxdXyOU/sCWilKB4bgyufu1cEi0jdVnRdxvjnmU=
|
||||
github.com/uber/athenadriver v1.1.4 h1:k6k0RBeXjR7oZ8NO557MsRw3eX1cc/9B0GNx+W9eHiQ=
|
||||
github.com/uber/athenadriver v1.1.4/go.mod h1:tQjho4NzXw55LGfSZEcETuYydpY1vtmixUabHkC1K/E=
|
||||
github.com/uber/jaeger-client-go v2.23.0+incompatible h1:o2g11IUBdEsSZVzF3k7+bahLmxRP/dbOoW4zQ30UlKE=
|
||||
github.com/uber/jaeger-client-go v2.23.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk=
|
||||
github.com/uber/jaeger-lib v2.2.0+incompatible h1:MxZXOiR2JuoANZ3J6DE/U0kSFv/eJ/GfSYVCjK7dyaw=
|
||||
github.com/uber/jaeger-lib v2.2.0+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U=
|
||||
github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
|
||||
github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
|
||||
|
@ -1101,6 +1112,7 @@ gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
|||
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
|
|
|
@ -0,0 +1,238 @@
|
|||
// Package check standardizes /health and /ready endpoints.
|
||||
// This allows you to easily know when your server is ready and healthy.
|
||||
package check
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Status string to indicate the overall status of the check.
|
||||
type Status string
|
||||
|
||||
const (
|
||||
// StatusFail indicates a specific check has failed.
|
||||
StatusFail Status = "fail"
|
||||
// StatusPass indicates a specific check has passed.
|
||||
StatusPass Status = "pass"
|
||||
|
||||
// DefaultCheckName is the name of the default checker.
|
||||
DefaultCheckName = "internal"
|
||||
)
|
||||
|
||||
// Check wraps a map of service names to status checkers.
|
||||
type Check struct {
|
||||
healthChecks []Checker
|
||||
readyChecks []Checker
|
||||
healthOverride override
|
||||
readyOverride override
|
||||
|
||||
passthroughHandler http.Handler
|
||||
}
|
||||
|
||||
// Checker indicates a service whose health can be checked.
|
||||
type Checker interface {
|
||||
Check(ctx context.Context) Response
|
||||
}
|
||||
|
||||
// NewCheck returns a Health with a default checker.
|
||||
func NewCheck() *Check {
|
||||
ch := &Check{}
|
||||
ch.healthOverride.disable()
|
||||
ch.readyOverride.disable()
|
||||
return ch
|
||||
}
|
||||
|
||||
// AddHealthCheck adds the check to the list of ready checks.
|
||||
// If c is a NamedChecker, the name will be added.
|
||||
func (c *Check) AddHealthCheck(check Checker) {
|
||||
if nc, ok := check.(NamedChecker); ok {
|
||||
c.healthChecks = append(c.healthChecks, Named(nc.CheckName(), nc))
|
||||
} else {
|
||||
c.healthChecks = append(c.healthChecks, check)
|
||||
}
|
||||
}
|
||||
|
||||
// AddReadyCheck adds the check to the list of ready checks.
|
||||
// If c is a NamedChecker, the name will be added.
|
||||
func (c *Check) AddReadyCheck(check Checker) {
|
||||
if nc, ok := check.(NamedChecker); ok {
|
||||
c.readyChecks = append(c.readyChecks, Named(nc.CheckName(), nc))
|
||||
} else {
|
||||
c.readyChecks = append(c.readyChecks, check)
|
||||
}
|
||||
}
|
||||
|
||||
// CheckHealth evaluates c's set of health checks and returns a populated Response.
|
||||
func (c *Check) CheckHealth(ctx context.Context) Response {
|
||||
response := Response{
|
||||
Name: "Health",
|
||||
Status: StatusPass,
|
||||
Checks: make(Responses, len(c.healthChecks)),
|
||||
}
|
||||
|
||||
status, overriding := c.healthOverride.get()
|
||||
if overriding {
|
||||
response.Status = status
|
||||
overrideResponse := Response{
|
||||
Name: "manual-override",
|
||||
Message: "health manually overridden",
|
||||
}
|
||||
response.Checks = append(response.Checks, overrideResponse)
|
||||
}
|
||||
for i, ch := range c.healthChecks {
|
||||
resp := ch.Check(ctx)
|
||||
if resp.Status != StatusPass && !overriding {
|
||||
response.Status = resp.Status
|
||||
}
|
||||
response.Checks[i] = resp
|
||||
}
|
||||
sort.Sort(response.Checks)
|
||||
return response
|
||||
}
|
||||
|
||||
// CheckReady evaluates c's set of ready checks and returns a populated Response.
|
||||
func (c *Check) CheckReady(ctx context.Context) Response {
|
||||
response := Response{
|
||||
Name: "Ready",
|
||||
Status: StatusPass,
|
||||
Checks: make(Responses, len(c.readyChecks)),
|
||||
}
|
||||
|
||||
status, overriding := c.readyOverride.get()
|
||||
if overriding {
|
||||
response.Status = status
|
||||
overrideResponse := Response{
|
||||
Name: "manual-override",
|
||||
Message: "ready manually overridden",
|
||||
}
|
||||
response.Checks = append(response.Checks, overrideResponse)
|
||||
}
|
||||
for i, c := range c.readyChecks {
|
||||
resp := c.Check(ctx)
|
||||
if resp.Status != StatusPass && !overriding {
|
||||
response.Status = resp.Status
|
||||
}
|
||||
response.Checks[i] = resp
|
||||
}
|
||||
sort.Sort(response.Checks)
|
||||
return response
|
||||
}
|
||||
|
||||
// SetPassthrough allows you to set a handler to use if the request is not a ready or health check.
|
||||
// This can be useful if you intend to use this as a middleware.
|
||||
func (c *Check) SetPassthrough(h http.Handler) {
|
||||
c.passthroughHandler = h
|
||||
}
|
||||
|
||||
// ServeHTTP serves /ready and /health requests with the respective checks.
|
||||
func (c *Check) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
const (
|
||||
pathReady = "/ready"
|
||||
pathHealth = "/health"
|
||||
queryForce = "force"
|
||||
)
|
||||
|
||||
path := r.URL.Path
|
||||
|
||||
// Allow requests not intended for checks to pass through.
|
||||
if path != pathReady && path != pathHealth {
|
||||
if c.passthroughHandler != nil {
|
||||
c.passthroughHandler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// We can't handle this request.
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
query := r.URL.Query()
|
||||
|
||||
switch path {
|
||||
case pathReady:
|
||||
switch query.Get(queryForce) {
|
||||
case "true":
|
||||
switch query.Get("ready") {
|
||||
case "true":
|
||||
c.readyOverride.enable(StatusPass)
|
||||
case "false":
|
||||
c.readyOverride.enable(StatusFail)
|
||||
}
|
||||
case "false":
|
||||
c.readyOverride.disable()
|
||||
}
|
||||
writeResponse(w, c.CheckReady(ctx))
|
||||
case pathHealth:
|
||||
switch query.Get(queryForce) {
|
||||
case "true":
|
||||
switch query.Get("healthy") {
|
||||
case "true":
|
||||
c.healthOverride.enable(StatusPass)
|
||||
case "false":
|
||||
c.healthOverride.enable(StatusFail)
|
||||
}
|
||||
case "false":
|
||||
c.healthOverride.disable()
|
||||
}
|
||||
writeResponse(w, c.CheckHealth(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// writeResponse writes a Response to the wire as JSON. The HTTP status code
|
||||
// accompanying the payload is the primary means for signaling the status of the
|
||||
// checks. The possible status codes are:
|
||||
//
|
||||
// - 200 OK: All checks pass.
|
||||
// - 503 Service Unavailable: Some checks are failing.
|
||||
// - 500 Internal Server Error: There was a problem serializing the Response.
|
||||
func writeResponse(w http.ResponseWriter, resp Response) {
|
||||
status := http.StatusOK
|
||||
if resp.Status == StatusFail {
|
||||
status = http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
msg, err := json.MarshalIndent(resp, "", " ")
|
||||
if err != nil {
|
||||
msg = []byte(`{"message": "error marshaling response", "status": "fail"}`)
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
w.WriteHeader(status)
|
||||
fmt.Fprintln(w, string(msg))
|
||||
}
|
||||
|
||||
// override is a manual override for an entire group of checks.
|
||||
type override struct {
|
||||
mtx sync.Mutex
|
||||
status Status
|
||||
active bool
|
||||
}
|
||||
|
||||
// get returns the Status of an override as well as whether or not an override
|
||||
// is currently active.
|
||||
func (m *override) get() (Status, bool) {
|
||||
m.mtx.Lock()
|
||||
defer m.mtx.Unlock()
|
||||
return m.status, m.active
|
||||
}
|
||||
|
||||
// disable disables the override.
|
||||
func (m *override) disable() {
|
||||
m.mtx.Lock()
|
||||
m.active = false
|
||||
m.status = StatusFail
|
||||
m.mtx.Unlock()
|
||||
}
|
||||
|
||||
// enable turns on the override and establishes a specific Status for which to.
|
||||
func (m *override) enable(s Status) {
|
||||
m.mtx.Lock()
|
||||
m.active = true
|
||||
m.status = s
|
||||
m.mtx.Unlock()
|
||||
}
|
|
@ -0,0 +1,535 @@
|
|||
package check
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestEmptyCheck(t *testing.T) {
|
||||
c := NewCheck()
|
||||
resp := c.CheckReady(context.Background())
|
||||
if len(resp.Checks) > 0 {
|
||||
t.Errorf("no checks added but %d returned", len(resp.Checks))
|
||||
}
|
||||
|
||||
if resp.Name != "Ready" {
|
||||
t.Errorf("expected: \"Ready\", got: %q", resp.Name)
|
||||
}
|
||||
|
||||
if resp.Status != StatusPass {
|
||||
t.Errorf("expected: %q, got: %q", StatusPass, resp.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddHealthCheck(t *testing.T) {
|
||||
h := NewCheck()
|
||||
h.AddHealthCheck(Named("awesome", ErrCheck(func() error {
|
||||
return nil
|
||||
})))
|
||||
r := h.CheckHealth(context.Background())
|
||||
if r.Status != StatusPass {
|
||||
t.Error("Health should fail because one of the check is unhealthy")
|
||||
}
|
||||
|
||||
if len(r.Checks) != 1 {
|
||||
t.Fatalf("check not in results: %+v", r.Checks)
|
||||
}
|
||||
|
||||
v := r.Checks[0]
|
||||
if v.Status != StatusPass {
|
||||
t.Errorf("the added check should be pass not %q.", v.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddUnHealthyCheck(t *testing.T) {
|
||||
h := NewCheck()
|
||||
h.AddHealthCheck(Named("failure", ErrCheck(func() error {
|
||||
return errors.New("Oops! I am sorry")
|
||||
})))
|
||||
r := h.CheckHealth(context.Background())
|
||||
if r.Status != StatusFail {
|
||||
t.Error("Health should fail because one of the check is unhealthy")
|
||||
}
|
||||
|
||||
if len(r.Checks) != 1 {
|
||||
t.Fatal("check not in results")
|
||||
}
|
||||
|
||||
v := r.Checks[0]
|
||||
if v.Status != StatusFail {
|
||||
t.Errorf("the added check should be fail not %s.", v.Status)
|
||||
}
|
||||
if v.Message != "Oops! I am sorry" {
|
||||
t.Errorf(
|
||||
"the error should be 'Oops! I am sorry' not %s.",
|
||||
v.Message,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func buildCheckWithServer() (*Check, *httptest.Server) {
|
||||
c := NewCheck()
|
||||
return c, httptest.NewServer(c)
|
||||
}
|
||||
|
||||
type mockCheck struct {
|
||||
status Status
|
||||
name string
|
||||
}
|
||||
|
||||
func (m mockCheck) Check(_ context.Context) Response {
|
||||
return Response{
|
||||
Name: m.name,
|
||||
Status: m.status,
|
||||
}
|
||||
}
|
||||
|
||||
func mockPass(name string) Checker {
|
||||
return mockCheck{status: StatusPass, name: name}
|
||||
}
|
||||
|
||||
func mockFail(name string) Checker {
|
||||
return mockCheck{status: StatusFail, name: name}
|
||||
}
|
||||
|
||||
func respBuilder(body io.ReadCloser) (*Response, error) {
|
||||
defer body.Close()
|
||||
d := json.NewDecoder(body)
|
||||
r := &Response{}
|
||||
return r, d.Decode(r)
|
||||
}
|
||||
|
||||
func TestBasicHTTPHandler(t *testing.T) {
|
||||
_, ts := buildCheckWithServer()
|
||||
defer ts.Close()
|
||||
|
||||
resp, err := http.Get(ts.URL + "/ready")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
actual, err := respBuilder(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expected := &Response{
|
||||
Name: "Ready",
|
||||
Status: StatusPass,
|
||||
}
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthSorting(t *testing.T) {
|
||||
c, ts := buildCheckWithServer()
|
||||
defer ts.Close()
|
||||
|
||||
c.AddHealthCheck(mockPass("a"))
|
||||
c.AddHealthCheck(mockPass("c"))
|
||||
c.AddHealthCheck(mockPass("b"))
|
||||
c.AddHealthCheck(mockFail("k"))
|
||||
c.AddHealthCheck(mockFail("b"))
|
||||
|
||||
resp, err := http.Get(ts.URL + "/health")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual, err := respBuilder(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := &Response{
|
||||
Name: "Health",
|
||||
Status: "fail",
|
||||
Checks: Responses{
|
||||
Response{Name: "b", Status: "fail"},
|
||||
Response{Name: "k", Status: "fail"},
|
||||
Response{Name: "a", Status: "pass"},
|
||||
Response{Name: "b", Status: "pass"},
|
||||
Response{Name: "c", Status: "pass"},
|
||||
},
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForceHealthy(t *testing.T) {
|
||||
c, ts := buildCheckWithServer()
|
||||
defer ts.Close()
|
||||
|
||||
c.AddHealthCheck(mockFail("a"))
|
||||
|
||||
_, err := http.Get(ts.URL + "/health?force=true&healthy=true")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := http.Get(ts.URL + "/health")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual, err := respBuilder(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := &Response{
|
||||
Name: "Health",
|
||||
Status: "pass",
|
||||
Checks: Responses{
|
||||
Response{Name: "manual-override", Message: "health manually overridden"},
|
||||
Response{Name: "a", Status: "fail"},
|
||||
},
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
|
||||
}
|
||||
|
||||
_, err = http.Get(ts.URL + "/health?force=false")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected = &Response{
|
||||
Name: "Health",
|
||||
Status: "fail",
|
||||
Checks: Responses{
|
||||
Response{Name: "a", Status: "fail"},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err = http.Get(ts.URL + "/health")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual, err = respBuilder(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForceUnhealthy(t *testing.T) {
|
||||
c, ts := buildCheckWithServer()
|
||||
defer ts.Close()
|
||||
|
||||
c.AddHealthCheck(mockPass("a"))
|
||||
|
||||
_, err := http.Get(ts.URL + "/health?force=true&healthy=false")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := http.Get(ts.URL + "/health")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual, err := respBuilder(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := &Response{
|
||||
Name: "Health",
|
||||
Status: "fail",
|
||||
Checks: Responses{
|
||||
Response{Name: "manual-override", Message: "health manually overridden"},
|
||||
Response{Name: "a", Status: "pass"},
|
||||
},
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
|
||||
}
|
||||
|
||||
_, err = http.Get(ts.URL + "/health?force=false")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected = &Response{
|
||||
Name: "Health",
|
||||
Status: "pass",
|
||||
Checks: Responses{
|
||||
Response{Name: "a", Status: "pass"},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err = http.Get(ts.URL + "/health")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual, err = respBuilder(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForceReady(t *testing.T) {
|
||||
c, ts := buildCheckWithServer()
|
||||
defer ts.Close()
|
||||
|
||||
c.AddReadyCheck(mockFail("a"))
|
||||
|
||||
_, err := http.Get(ts.URL + "/ready?force=true&ready=true")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := http.Get(ts.URL + "/ready")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual, err := respBuilder(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := &Response{
|
||||
Name: "Ready",
|
||||
Status: "pass",
|
||||
Checks: Responses{
|
||||
Response{Name: "manual-override", Message: "ready manually overridden"},
|
||||
Response{Name: "a", Status: "fail"},
|
||||
},
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
|
||||
}
|
||||
|
||||
_, err = http.Get(ts.URL + "/ready?force=false")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected = &Response{
|
||||
Name: "Ready",
|
||||
Status: "fail",
|
||||
Checks: Responses{
|
||||
Response{Name: "a", Status: "fail"},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err = http.Get(ts.URL + "/ready")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual, err = respBuilder(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForceNotReady(t *testing.T) {
|
||||
c, ts := buildCheckWithServer()
|
||||
defer ts.Close()
|
||||
|
||||
c.AddReadyCheck(mockPass("a"))
|
||||
|
||||
_, err := http.Get(ts.URL + "/ready?force=true&ready=false")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := http.Get(ts.URL + "/ready")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual, err := respBuilder(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := &Response{
|
||||
Name: "Ready",
|
||||
Status: "fail",
|
||||
Checks: Responses{
|
||||
Response{Name: "manual-override", Message: "ready manually overridden"},
|
||||
Response{Name: "a", Status: "pass"},
|
||||
},
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
|
||||
}
|
||||
|
||||
_, err = http.Get(ts.URL + "/ready?force=false")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected = &Response{
|
||||
Name: "Ready",
|
||||
Status: "pass",
|
||||
Checks: Responses{
|
||||
Response{Name: "a", Status: "pass"},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err = http.Get(ts.URL + "/ready")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual, err = respBuilder(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoCrossOver(t *testing.T) {
|
||||
c, ts := buildCheckWithServer()
|
||||
defer ts.Close()
|
||||
|
||||
c.AddHealthCheck(mockPass("a"))
|
||||
c.AddHealthCheck(mockPass("c"))
|
||||
c.AddReadyCheck(mockPass("b"))
|
||||
c.AddReadyCheck(mockFail("k"))
|
||||
c.AddHealthCheck(mockFail("b"))
|
||||
|
||||
resp, err := http.Get(ts.URL + "/health")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual, err := respBuilder(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := &Response{
|
||||
Name: "Health",
|
||||
Status: "fail",
|
||||
Checks: Responses{
|
||||
Response{Name: "b", Status: "fail"},
|
||||
Response{Name: "a", Status: "pass"},
|
||||
Response{Name: "c", Status: "pass"},
|
||||
},
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
|
||||
}
|
||||
|
||||
resp, err = http.Get(ts.URL + "/ready")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual, err = respBuilder(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected = &Response{
|
||||
Name: "Ready",
|
||||
Status: "fail",
|
||||
Checks: Responses{
|
||||
Response{Name: "k", Status: "fail"},
|
||||
Response{Name: "b", Status: "pass"},
|
||||
},
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPassthrough(t *testing.T) {
|
||||
c, ts := buildCheckWithServer()
|
||||
defer ts.Close()
|
||||
|
||||
resp, err := http.Get(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 404 {
|
||||
t.Fatalf("failed to error when no passthrough is present, status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
used := false
|
||||
s := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
used = true
|
||||
w.Write([]byte("hi"))
|
||||
})
|
||||
|
||||
c.SetPassthrough(s)
|
||||
|
||||
resp, err = http.Get(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("bad response code from passthrough, status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if !used {
|
||||
t.Fatal("passthrough server not used")
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleNewCheck() {
|
||||
// Run the default healthcheck. it always return 200. It is good if you
|
||||
// have a service without any dependency
|
||||
h := NewCheck()
|
||||
h.CheckHealth(context.Background())
|
||||
}
|
||||
|
||||
func ExampleCheck_CheckHealth() {
|
||||
h := NewCheck()
|
||||
h.AddHealthCheck(Named("google", CheckerFunc(func(ctx context.Context) Response {
|
||||
var r net.Resolver
|
||||
_, err := r.LookupHost(ctx, "google.com")
|
||||
if err != nil {
|
||||
return Error(err)
|
||||
}
|
||||
return Pass()
|
||||
})))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
h.CheckHealth(ctx)
|
||||
}
|
||||
|
||||
func ExampleCheck_ServeHTTP() {
|
||||
c := NewCheck()
|
||||
http.ListenAndServe(":6060", c)
|
||||
}
|
||||
|
||||
func ExampleCheck_SetPassthrough() {
|
||||
c := NewCheck()
|
||||
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Hello friends!"))
|
||||
})
|
||||
|
||||
c.SetPassthrough(http.DefaultServeMux)
|
||||
http.ListenAndServe(":6060", c)
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
package check
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// NamedChecker is a superset of Checker that also indicates the name of the service.
|
||||
// Prefer to implement NamedChecker if your service has a fixed name,
|
||||
// as opposed to calling *Health.AddNamed.
|
||||
type NamedChecker interface {
|
||||
Checker
|
||||
CheckName() string
|
||||
}
|
||||
|
||||
// CheckerFunc is an adapter of a plain func() error to the Checker interface.
|
||||
type CheckerFunc func(ctx context.Context) Response
|
||||
|
||||
// Check implements Checker.
|
||||
func (f CheckerFunc) Check(ctx context.Context) Response {
|
||||
return f(ctx)
|
||||
}
|
||||
|
||||
// Named returns a Checker that will attach a name to the Response from the check.
|
||||
// This way, it is possible to augment a Response with a human-readable name, but not have to encode
|
||||
// that logic in the actual check itself.
|
||||
func Named(name string, checker Checker) Checker {
|
||||
return CheckerFunc(func(ctx context.Context) Response {
|
||||
resp := checker.Check(ctx)
|
||||
resp.Name = name
|
||||
return resp
|
||||
})
|
||||
}
|
||||
|
||||
// NamedFunc is the same as Named except it takes a CheckerFunc.
|
||||
func NamedFunc(name string, fn CheckerFunc) Checker {
|
||||
return Named(name, fn)
|
||||
}
|
||||
|
||||
// ErrCheck will create a health checker that executes a function. If the function returns an error,
|
||||
// it will return an unhealthy response. Otherwise, it will be as if the Ok function was called.
|
||||
// Note: it is better to use CheckFunc, because with Check, the context is ignored.
|
||||
func ErrCheck(fn func() error) Checker {
|
||||
return CheckerFunc(func(_ context.Context) Response {
|
||||
if err := fn(); err != nil {
|
||||
return Error(err)
|
||||
}
|
||||
return Pass()
|
||||
})
|
||||
}
|
||||
|
||||
// Pass is a utility function to generate a passing status response with the default parameters.
|
||||
func Pass() Response {
|
||||
return Response{
|
||||
Status: StatusPass,
|
||||
}
|
||||
}
|
||||
|
||||
// Info is a utility function to generate a healthy status with a printf message.
|
||||
func Info(msg string, args ...interface{}) Response {
|
||||
return Response{
|
||||
Status: StatusPass,
|
||||
Message: fmt.Sprintf(msg, args...),
|
||||
}
|
||||
}
|
||||
|
||||
// Error is a utility function for creating a response from an error message.
|
||||
func Error(err error) Response {
|
||||
return Response{
|
||||
Status: StatusFail,
|
||||
Message: err.Error(),
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
package check
|
||||
|
||||
// Response is a result of a collection of health checks.
|
||||
type Response struct {
|
||||
Name string `json:"name"`
|
||||
Status Status `json:"status"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Checks Responses `json:"checks,omitempty"`
|
||||
}
|
||||
|
||||
// HasCheck verifies whether the receiving Response has a check with the given name or not.
|
||||
func (r *Response) HasCheck(name string) bool {
|
||||
found := false
|
||||
for _, check := range r.Checks {
|
||||
if check.Name == name {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return found
|
||||
}
|
||||
|
||||
// Responses is a sortable collection of Response objects.
|
||||
type Responses []Response
|
||||
|
||||
func (r Responses) Len() int { return len(r) }
|
||||
|
||||
// Less defines the order in which responses are sorted.
|
||||
//
|
||||
// Failing responses are always sorted before passing responses. Responses with
|
||||
// the same status are then sorted according to the name of the check.
|
||||
func (r Responses) Less(i, j int) bool {
|
||||
if r[i].Status == r[j].Status {
|
||||
return r[i].Name < r[j].Name
|
||||
}
|
||||
return r[i].Status < r[j].Status
|
||||
}
|
||||
|
||||
func (r Responses) Swap(i, j int) { r[i], r[j] = r[j], r[i] }
|
|
@ -0,0 +1,107 @@
|
|||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// TODO: move to base directory
|
||||
|
||||
const (
|
||||
// InternalError indicates an unexpected error condition.
|
||||
InternalError = 1
|
||||
// MalformedData indicates malformed input, such as unparsable JSON.
|
||||
MalformedData = 2
|
||||
// InvalidData indicates that data is well-formed, but invalid.
|
||||
InvalidData = 3
|
||||
// Forbidden indicates a forbidden operation.
|
||||
Forbidden = 4
|
||||
// NotFound indicates a resource was not found.
|
||||
NotFound = 5
|
||||
)
|
||||
|
||||
// Error indicates an error with a reference code and an HTTP status code.
|
||||
type Error struct {
|
||||
Reference int `json:"referenceCode"`
|
||||
Code int `json:"statusCode"`
|
||||
Err string `json:"err"`
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e Error) Error() string {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// Errorf constructs an Error with the given reference code and format.
|
||||
func Errorf(ref int, format string, i ...interface{}) error {
|
||||
return Error{
|
||||
Reference: ref,
|
||||
Err: fmt.Sprintf(format, i...),
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a new error with a message and error code.
|
||||
func New(msg string, ref ...int) error {
|
||||
refCode := InternalError
|
||||
if len(ref) == 1 {
|
||||
refCode = ref[0]
|
||||
}
|
||||
return Error{
|
||||
Reference: refCode,
|
||||
Err: msg,
|
||||
}
|
||||
}
|
||||
|
||||
func Wrap(err error, msg string, ref ...int) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
e, ok := err.(Error)
|
||||
if ok {
|
||||
refCode := e.Reference
|
||||
if len(ref) == 1 {
|
||||
refCode = ref[0]
|
||||
}
|
||||
return Error{
|
||||
Reference: refCode,
|
||||
Code: e.Code,
|
||||
Err: fmt.Sprintf("%s: %s", msg, e.Err),
|
||||
}
|
||||
}
|
||||
refCode := InternalError
|
||||
if len(ref) == 1 {
|
||||
refCode = ref[0]
|
||||
}
|
||||
return Error{
|
||||
Reference: refCode,
|
||||
Err: fmt.Sprintf("%s: %s", msg, err.Error()),
|
||||
}
|
||||
}
|
||||
|
||||
// InternalErrorf constructs an InternalError with the given format.
|
||||
func InternalErrorf(format string, i ...interface{}) error {
|
||||
return Errorf(InternalError, format, i...)
|
||||
}
|
||||
|
||||
// MalformedDataf constructs a MalformedData error with the given format.
|
||||
func MalformedDataf(format string, i ...interface{}) error {
|
||||
return Errorf(MalformedData, format, i...)
|
||||
}
|
||||
|
||||
// InvalidDataf constructs an InvalidData error with the given format.
|
||||
func InvalidDataf(format string, i ...interface{}) error {
|
||||
return Errorf(InvalidData, format, i...)
|
||||
}
|
||||
|
||||
// Forbiddenf constructs a Forbidden error with the given format.
|
||||
func Forbiddenf(format string, i ...interface{}) error {
|
||||
return Errorf(Forbidden, format, i...)
|
||||
}
|
||||
|
||||
func BadRequestError(msg string) error {
|
||||
return Error{
|
||||
Reference: InvalidData,
|
||||
Code: http.StatusBadRequest,
|
||||
Err: msg,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// List represents a list of errors.
|
||||
type List struct {
|
||||
errs []error
|
||||
err error // cached error
|
||||
}
|
||||
|
||||
// Append adds err to the errors list.
|
||||
func (l *List) Append(err error) {
|
||||
l.errs = append(l.errs, err)
|
||||
l.err = nil
|
||||
}
|
||||
|
||||
// AppendString adds a new error that formats as the given text.
|
||||
func (l *List) AppendString(text string) {
|
||||
l.errs = append(l.errs, errors.New(text))
|
||||
l.err = nil
|
||||
}
|
||||
|
||||
// Clear removes all the previously appended errors from the list.
|
||||
func (l *List) Clear() {
|
||||
for i := range l.errs {
|
||||
l.errs[i] = nil
|
||||
}
|
||||
l.errs = l.errs[:0]
|
||||
l.err = nil
|
||||
}
|
||||
|
||||
// Err returns an error composed of the list of errors, separated by a new line, or nil if no errors
|
||||
// were appended.
|
||||
func (l *List) Err() error {
|
||||
if len(l.errs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if l.err != nil {
|
||||
switch len(l.errs) {
|
||||
case 1:
|
||||
l.err = l.errs[0]
|
||||
|
||||
default:
|
||||
var sb strings.Builder
|
||||
sb.WriteString(l.errs[0].Error())
|
||||
for _, err := range l.errs[1:] {
|
||||
sb.WriteRune('\n')
|
||||
sb.WriteString(err.Error())
|
||||
}
|
||||
l.err = errors.New(sb.String())
|
||||
}
|
||||
}
|
||||
|
||||
return l.err
|
||||
}
|
|
@ -0,0 +1,271 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/Masterminds/sprig"
|
||||
"github.com/influxdata/influxdb/kit/feature"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
const tmpl = `// Code generated by the feature package; DO NOT EDIT.
|
||||
|
||||
package feature
|
||||
|
||||
{{ .Qualify | import }}
|
||||
|
||||
{{ range $_, $flag := .Flags }}
|
||||
var {{ $flag.Key }} = {{ $.Qualify | package }}{{ $flag.Default | maker }}(
|
||||
{{ $flag.Name | quote }},
|
||||
{{ $flag.Key | quote }},
|
||||
{{ $flag.Contact | quote }},
|
||||
{{ $flag.Default | conditionalQuote }},
|
||||
{{ $.Qualify | package }}{{ $flag.Lifetime | lifetime }},
|
||||
{{ $flag.Expose }},
|
||||
)
|
||||
|
||||
// {{ $flag.Name | replace " " "_" | camelcase }} - {{ $flag.Description }}
|
||||
func {{ $flag.Name | replace " " "_" | camelcase }}() {{ $.Qualify | package }}{{ $flag.Default | flagType }} {
|
||||
return {{ $flag.Key }}
|
||||
}
|
||||
{{ end }}
|
||||
|
||||
var all = []{{ .Qualify | package }}Flag{
|
||||
{{ range $_, $flag := .Flags }} {{ $flag.Key }},
|
||||
{{ end }}}
|
||||
|
||||
var byKey = map[string]{{ $.Qualify | package }}Flag{
|
||||
{{ range $_, $flag := .Flags }} {{ $flag.Key | quote }}: {{ $flag.Key }},
|
||||
{{ end }}}
|
||||
`
|
||||
|
||||
type flagConfig struct {
|
||||
Name string
|
||||
Description string
|
||||
Key string
|
||||
Default interface{}
|
||||
Contact string
|
||||
Lifetime feature.Lifetime
|
||||
Expose bool
|
||||
}
|
||||
|
||||
func (f flagConfig) Valid() error {
|
||||
var problems []string
|
||||
if f.Key == "" {
|
||||
problems = append(problems, "missing key")
|
||||
}
|
||||
if f.Contact == "" {
|
||||
problems = append(problems, "missing contact")
|
||||
}
|
||||
if f.Default == nil {
|
||||
problems = append(problems, "missing default")
|
||||
}
|
||||
if f.Description == "" {
|
||||
problems = append(problems, "missing description")
|
||||
}
|
||||
|
||||
if len(problems) > 0 {
|
||||
name := f.Name
|
||||
if name == "" {
|
||||
if f.Key != "" {
|
||||
name = f.Key
|
||||
} else {
|
||||
name = "anonymous"
|
||||
}
|
||||
}
|
||||
// e.g. "my flag: missing key; missing default"
|
||||
return fmt.Errorf("%s: %s\n", name, strings.Join(problems, "; "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type flagValidationError struct {
|
||||
errs []error
|
||||
}
|
||||
|
||||
func newFlagValidationError(errs []error) *flagValidationError {
|
||||
if len(errs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &flagValidationError{errs}
|
||||
}
|
||||
|
||||
func (e *flagValidationError) Error() string {
|
||||
var s strings.Builder
|
||||
s.WriteString("flag validation error: \n")
|
||||
for _, err := range e.errs {
|
||||
s.WriteString(err.Error())
|
||||
}
|
||||
return s.String()
|
||||
}
|
||||
|
||||
func validate(flags []flagConfig) error {
|
||||
var (
|
||||
errs []error
|
||||
seen = make(map[string]bool, len(flags))
|
||||
)
|
||||
for _, flag := range flags {
|
||||
if err := flag.Valid(); err != nil {
|
||||
errs = append(errs, err)
|
||||
} else if _, repeated := seen[flag.Key]; repeated {
|
||||
errs = append(errs, fmt.Errorf("duplicate flag key '%s'\n", flag.Key))
|
||||
}
|
||||
seen[flag.Key] = true
|
||||
}
|
||||
if len(errs) != 0 {
|
||||
return newFlagValidationError(errs)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var argv = struct {
|
||||
in, out *string
|
||||
qualify *bool
|
||||
}{
|
||||
in: flag.String("in", "", "flag configuration path"),
|
||||
out: flag.String("out", "", "flag generation destination path"),
|
||||
qualify: flag.Bool("qualify", false, "qualify types with imported package name"),
|
||||
}
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func run() error {
|
||||
flag.Parse()
|
||||
|
||||
in, err := os.Open(*argv.in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer in.Close()
|
||||
|
||||
configuration, err := ioutil.ReadAll(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var flags []flagConfig
|
||||
err = yaml.Unmarshal(configuration, &flags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = validate(flags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t, err := template.New("flags").Funcs(templateFunctions()).Parse(tmpl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
out, err := os.Create(*argv.out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
var (
|
||||
buf = new(bytes.Buffer)
|
||||
vars = struct {
|
||||
Qualify bool
|
||||
Flags []flagConfig
|
||||
}{
|
||||
Qualify: *argv.qualify,
|
||||
Flags: flags,
|
||||
}
|
||||
)
|
||||
if err := t.Execute(buf, vars); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
raw, err := ioutil.ReadAll(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
formatted, err := format.Source(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = out.Write(formatted)
|
||||
return err
|
||||
}
|
||||
|
||||
func templateFunctions() template.FuncMap {
|
||||
functions := sprig.TxtFuncMap()
|
||||
|
||||
functions["lifetime"] = func(t interface{}) string {
|
||||
switch t {
|
||||
case feature.Permanent:
|
||||
return "Permanent"
|
||||
default:
|
||||
return "Temporary"
|
||||
}
|
||||
}
|
||||
|
||||
functions["conditionalQuote"] = func(t interface{}) string {
|
||||
switch t.(type) {
|
||||
case string:
|
||||
return fmt.Sprintf("%q", t)
|
||||
default:
|
||||
return fmt.Sprintf("%v", t)
|
||||
}
|
||||
}
|
||||
|
||||
functions["flagType"] = func(t interface{}) string {
|
||||
switch t.(type) {
|
||||
case bool:
|
||||
return "BoolFlag"
|
||||
case float64:
|
||||
return "FloatFlag"
|
||||
case int:
|
||||
return "IntFlag"
|
||||
default:
|
||||
return "StringFlag"
|
||||
}
|
||||
}
|
||||
|
||||
functions["maker"] = func(t interface{}) string {
|
||||
switch t.(type) {
|
||||
case bool:
|
||||
return "MakeBoolFlag"
|
||||
case float64:
|
||||
return "MakeFloatFlag"
|
||||
case int:
|
||||
return "MakeIntFlag"
|
||||
default:
|
||||
return "MakeStringFlag"
|
||||
}
|
||||
}
|
||||
|
||||
functions["package"] = func(t interface{}) string {
|
||||
if t.(bool) {
|
||||
return "feature."
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
functions["import"] = func(t interface{}) string {
|
||||
if t.(bool) {
|
||||
return "import \"github.com/influxdata/influxdb/kit/feature\""
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
return functions
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
// Package feature provides feature flagging capabilities for InfluxDB servers.
|
||||
// This document describes this package and how it is used to control
|
||||
// experimental features in `influxd`.
|
||||
//
|
||||
// Flags are configured in `flags.yml` at the top of this repository.
|
||||
// Running `make flags` generates Go code based on this configuration
|
||||
// to programmatically test flag values in a given request context.
|
||||
// Boolean flags are the most common case, but integers, floats and
|
||||
// strings are supported for more complicated experiments.
|
||||
//
|
||||
// The `Flagger` interface is the crux of this package.
|
||||
// It computes a map of feature flag values for a given request context.
|
||||
// The default implementation always returns the flag default configured
|
||||
// in `flags.yml`. The override implementation allows an operator to
|
||||
// override feature flag defaults at startup. Changing these overrides
|
||||
// requires a restart.
|
||||
//
|
||||
// In `influxd`, a `Flagger` instance is provided to a `Handler` middleware
|
||||
// configured to intercept all API requests and annotate their request context
|
||||
// with a map of feature flags.
|
||||
//
|
||||
// A flag can opt in to be exposed externally in `flags.yml`. If exposed,
|
||||
// this flag will be included in the response from the `/api/v2/flags`
|
||||
// endpoint. This allows the UI and other API clients to control their
|
||||
// behavior according to the flag in addition to the server itself.
|
||||
//
|
||||
// A concrete example to illustrate the above:
|
||||
//
|
||||
// I have a feature called "My Feature" that will involve turning on new code
|
||||
// in both the UI and the server.
|
||||
//
|
||||
// First, I add an entry to `flags.yml`.
|
||||
//
|
||||
// ```yaml
|
||||
// - name: My Feature
|
||||
// description: My feature is awesome
|
||||
// key: myFeature
|
||||
// default: false
|
||||
// expose: true
|
||||
// contact: My Name
|
||||
// ```
|
||||
//
|
||||
// My flag type is inferred to be boolean by my default of `false` when I run
|
||||
// `make flags` and the `feature` package now includes `func MyFeature() BoolFlag`.
|
||||
//
|
||||
// I use this to control my backend code with
|
||||
//
|
||||
// ```go
|
||||
// if feature.MyFeature.Enabled(ctx) {
|
||||
// // new code...
|
||||
// } else {
|
||||
// // new code...
|
||||
// }
|
||||
// ```
|
||||
//
|
||||
// and the `/api/v2/flags` response provides the same information to the frontend.
|
||||
//
|
||||
// ```json
|
||||
// {
|
||||
// "myFeature": false
|
||||
// }
|
||||
// ```
|
||||
//
|
||||
// While `false` by default, I can turn on my experimental feature by starting
|
||||
// my server with a flag override.
|
||||
//
|
||||
// ```
|
||||
// env INFLUXD_FEATURE_FLAGS="{\"flag1\":\value1\",\"key2\":\"value2\"}" influxd
|
||||
// ```
|
||||
//
|
||||
// ```
|
||||
// influxd --feature-flags flag1=value1,flag2=value2
|
||||
// ```
|
||||
//
|
||||
package feature
|
|
@ -0,0 +1,140 @@
|
|||
package feature
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const featureContextKey contextKey = "influx/feature/v1"
|
||||
|
||||
// Flagger returns flag values.
|
||||
type Flagger interface {
|
||||
// Flags returns a map of flag keys to flag values.
|
||||
//
|
||||
// If an authorization is present on the context, it may be used to compute flag
|
||||
// values according to the affiliated user ID and its organization and other mappings.
|
||||
// Otherwise, they should be computed generally or return a default.
|
||||
//
|
||||
// One or more flags may be provided to restrict the results.
|
||||
// Otherwise, all flags should be computed.
|
||||
Flags(context.Context, ...Flag) (map[string]interface{}, error)
|
||||
}
|
||||
|
||||
// Annotate the context with a map computed of computed flags.
|
||||
func Annotate(ctx context.Context, f Flagger, flags ...Flag) (context.Context, error) {
|
||||
computed, err := f.Flags(ctx, flags...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
span := opentracing.SpanFromContext(ctx)
|
||||
if span != nil {
|
||||
for k, v := range computed {
|
||||
span.LogKV(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
return context.WithValue(ctx, featureContextKey, computed), nil
|
||||
}
|
||||
|
||||
// FlagsFromContext returns the map of flags attached to the context
|
||||
// by Annotate, or nil if none is found.
|
||||
func FlagsFromContext(ctx context.Context) map[string]interface{} {
|
||||
v, ok := ctx.Value(featureContextKey).(map[string]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
type ByKeyFn func(string) (Flag, bool)
|
||||
|
||||
// ExposedFlagsFromContext returns the filtered map of exposed flags attached
|
||||
// to the context by Annotate, or nil if none is found.
|
||||
func ExposedFlagsFromContext(ctx context.Context, byKey ByKeyFn) map[string]interface{} {
|
||||
m := FlagsFromContext(ctx)
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
filtered := make(map[string]interface{})
|
||||
for k, v := range m {
|
||||
if flag, found := byKey(k); found && flag.Expose() {
|
||||
filtered[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
// Lifetime represents the intended lifetime of the feature flag.
|
||||
//
|
||||
// The zero value is Temporary, the most common case, but Permanent
|
||||
// is included to mark special cases where a flag is not intended
|
||||
// to be removed, e.g. enabling debug tracing for an organization.
|
||||
//
|
||||
// TODO(gavincabbage): This may become a stale date, which can then
|
||||
// be used to trigger a notification to the contact when the flag
|
||||
// has become stale, to encourage flag cleanup.
|
||||
type Lifetime int
|
||||
|
||||
const (
|
||||
// Temporary indicates a flag is intended to be removed after a feature is no longer in development.
|
||||
Temporary Lifetime = iota
|
||||
// Permanent indicates a flag is not intended to be removed.
|
||||
Permanent
|
||||
)
|
||||
|
||||
// UnmarshalYAML implements yaml.Unmarshaler and interprets a case-insensitive text
|
||||
// representation as a lifetime constant.
|
||||
func (l *Lifetime) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var s string
|
||||
if err := unmarshal(&s); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch strings.ToLower(s) {
|
||||
case "permanent":
|
||||
*l = Permanent
|
||||
default:
|
||||
*l = Temporary
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type defaultFlagger struct{}
|
||||
|
||||
// DefaultFlagger returns a flagger that always returns default values.
|
||||
func DefaultFlagger() Flagger {
|
||||
return &defaultFlagger{}
|
||||
}
|
||||
|
||||
// Flags returns a map of default values. It never returns an error.
|
||||
func (*defaultFlagger) Flags(_ context.Context, flags ...Flag) (map[string]interface{}, error) {
|
||||
if len(flags) == 0 {
|
||||
flags = Flags()
|
||||
}
|
||||
|
||||
m := make(map[string]interface{}, len(flags))
|
||||
for _, flag := range flags {
|
||||
m[flag.Key()] = flag.Default()
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Flags returns all feature flags.
|
||||
func Flags() []Flag {
|
||||
return all
|
||||
}
|
||||
|
||||
// ByKey returns the Flag corresponding to the given key.
|
||||
func ByKey(k string) (Flag, bool) {
|
||||
v, found := byKey[k]
|
||||
return v, found
|
||||
}
|
|
@ -0,0 +1,185 @@
|
|||
package feature_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/influxdata/influxdb/kit/feature"
|
||||
)
|
||||
|
||||
func Test_feature(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
flag feature.Flag
|
||||
err error
|
||||
values map[string]interface{}
|
||||
ctx context.Context
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
name: "bool happy path",
|
||||
flag: newFlag("test", false),
|
||||
values: map[string]interface{}{
|
||||
"test": true,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "int happy path",
|
||||
flag: newFlag("test", 0),
|
||||
values: map[string]interface{}{
|
||||
"test": int32(42),
|
||||
},
|
||||
expected: int32(42),
|
||||
},
|
||||
{
|
||||
name: "float happy path",
|
||||
flag: newFlag("test", 0.0),
|
||||
values: map[string]interface{}{
|
||||
"test": 42.42,
|
||||
},
|
||||
expected: 42.42,
|
||||
},
|
||||
{
|
||||
name: "string happy path",
|
||||
flag: newFlag("test", ""),
|
||||
values: map[string]interface{}{
|
||||
"test": "restaurantattheendoftheuniverse",
|
||||
},
|
||||
expected: "restaurantattheendoftheuniverse",
|
||||
},
|
||||
{
|
||||
name: "bool missing use default",
|
||||
flag: newFlag("test", false),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "bool missing use default true",
|
||||
flag: newFlag("test", true),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "int missing use default",
|
||||
flag: newFlag("test", 65),
|
||||
expected: int32(65),
|
||||
},
|
||||
{
|
||||
name: "float missing use default",
|
||||
flag: newFlag("test", 65.65),
|
||||
expected: 65.65,
|
||||
},
|
||||
{
|
||||
name: "string missing use default",
|
||||
flag: newFlag("test", "mydefault"),
|
||||
expected: "mydefault",
|
||||
},
|
||||
|
||||
{
|
||||
name: "bool invalid use default",
|
||||
flag: newFlag("test", true),
|
||||
values: map[string]interface{}{
|
||||
"test": "notabool",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "int invalid use default",
|
||||
flag: newFlag("test", 42),
|
||||
values: map[string]interface{}{
|
||||
"test": 99.99,
|
||||
},
|
||||
expected: int32(42),
|
||||
},
|
||||
{
|
||||
name: "float invalid use default",
|
||||
flag: newFlag("test", 42.42),
|
||||
values: map[string]interface{}{
|
||||
"test": 99,
|
||||
},
|
||||
expected: 42.42,
|
||||
},
|
||||
{
|
||||
name: "string invalid use default",
|
||||
flag: newFlag("test", "restaurantattheendoftheuniverse"),
|
||||
values: map[string]interface{}{
|
||||
"test": true,
|
||||
},
|
||||
expected: "restaurantattheendoftheuniverse",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
t.Run("flagger "+test.name, func(t *testing.T) {
|
||||
flagger := testFlagsFlagger{
|
||||
m: test.values,
|
||||
err: test.err,
|
||||
}
|
||||
|
||||
var actual interface{}
|
||||
switch flag := test.flag.(type) {
|
||||
case feature.BoolFlag:
|
||||
actual = flag.Enabled(test.ctx, flagger)
|
||||
case feature.FloatFlag:
|
||||
actual = flag.Float(test.ctx, flagger)
|
||||
case feature.IntFlag:
|
||||
actual = flag.Int(test.ctx, flagger)
|
||||
case feature.StringFlag:
|
||||
actual = flag.String(test.ctx, flagger)
|
||||
default:
|
||||
t.Errorf("unknown flag type %T (%#v)", flag, flag)
|
||||
}
|
||||
|
||||
if actual != test.expected {
|
||||
t.Errorf("unexpected flag value: got %v, want %v", actual, test.expected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("annotate "+test.name, func(t *testing.T) {
|
||||
flagger := testFlagsFlagger{
|
||||
m: test.values,
|
||||
err: test.err,
|
||||
}
|
||||
|
||||
ctx, err := feature.Annotate(context.Background(), flagger)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var actual interface{}
|
||||
switch flag := test.flag.(type) {
|
||||
case feature.BoolFlag:
|
||||
actual = flag.Enabled(ctx)
|
||||
case feature.FloatFlag:
|
||||
actual = flag.Float(ctx)
|
||||
case feature.IntFlag:
|
||||
actual = flag.Int(ctx)
|
||||
case feature.StringFlag:
|
||||
actual = flag.String(ctx)
|
||||
default:
|
||||
t.Errorf("unknown flag type %T (%#v)", flag, flag)
|
||||
}
|
||||
|
||||
if actual != test.expected {
|
||||
t.Errorf("unexpected flag value: got %v, want %v", actual, test.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testFlagsFlagger struct {
|
||||
m map[string]interface{}
|
||||
err error
|
||||
}
|
||||
|
||||
func (f testFlagsFlagger) Flags(ctx context.Context, flags ...feature.Flag) (map[string]interface{}, error) {
|
||||
if f.err != nil {
|
||||
return nil, f.err
|
||||
}
|
||||
|
||||
return f.m, nil
|
||||
}
|
||||
|
||||
func newFlag(key string, defaultValue interface{}) feature.Flag {
|
||||
return feature.MakeFlag(key, key, "", defaultValue, feature.Temporary, false)
|
||||
}
|
|
@ -0,0 +1,216 @@
|
|||
//go:generate go run ./_codegen/main.go --in ../../flags.yml --out ./list.go
|
||||
|
||||
package feature
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Flag represents a generic feature flag with a key and a default.
|
||||
type Flag interface {
|
||||
// Key returns the programmatic backend identifier for the flag.
|
||||
Key() string
|
||||
// Default returns the type-agnostic zero value for the flag.
|
||||
// Type-specific flag implementations may expose a typed default
|
||||
// (e.g. BoolFlag includes a boolean Default field).
|
||||
Default() interface{}
|
||||
// Expose the flag.
|
||||
Expose() bool
|
||||
}
|
||||
|
||||
// MakeFlag constructs a Flag. The concrete implementation is inferred from the provided default.
|
||||
func MakeFlag(name, key, owner string, defaultValue interface{}, lifetime Lifetime, expose bool) Flag {
|
||||
b := MakeBase(name, key, owner, defaultValue, lifetime, expose)
|
||||
switch v := defaultValue.(type) {
|
||||
case bool:
|
||||
return BoolFlag{b, v}
|
||||
case float64:
|
||||
return FloatFlag{b, v}
|
||||
case int32:
|
||||
return IntFlag{b, v}
|
||||
case int:
|
||||
return IntFlag{b, int32(v)}
|
||||
case string:
|
||||
return StringFlag{b, v}
|
||||
default:
|
||||
return StringFlag{b, fmt.Sprintf("%v", v)}
|
||||
}
|
||||
}
|
||||
|
||||
// flag base type.
|
||||
type Base struct {
|
||||
// name of the flag.
|
||||
name string
|
||||
// key is the programmatic backend identifier for the flag.
|
||||
key string
|
||||
// defaultValue for the flag.
|
||||
defaultValue interface{}
|
||||
// owner is an individual or team responsible for the flag.
|
||||
owner string
|
||||
// lifetime of the feature flag.
|
||||
lifetime Lifetime
|
||||
// expose the flag.
|
||||
expose bool
|
||||
}
|
||||
|
||||
var _ Flag = Base{}
|
||||
|
||||
// MakeBase constructs a flag flag.
|
||||
func MakeBase(name, key, owner string, defaultValue interface{}, lifetime Lifetime, expose bool) Base {
|
||||
return Base{
|
||||
name: name,
|
||||
key: key,
|
||||
owner: owner,
|
||||
defaultValue: defaultValue,
|
||||
lifetime: lifetime,
|
||||
expose: expose,
|
||||
}
|
||||
}
|
||||
|
||||
// Key returns the programmatic backend identifier for the flag.
|
||||
func (f Base) Key() string {
|
||||
return f.key
|
||||
}
|
||||
|
||||
// Default returns the type-agnostic zero value for the flag.
|
||||
func (f Base) Default() interface{} {
|
||||
return f.defaultValue
|
||||
}
|
||||
|
||||
// Expose the flag.
|
||||
func (f Base) Expose() bool {
|
||||
return f.expose
|
||||
}
|
||||
|
||||
func (f Base) value(ctx context.Context, flagger ...Flagger) (interface{}, bool) {
|
||||
var (
|
||||
m map[string]interface{}
|
||||
ok bool
|
||||
)
|
||||
if len(flagger) < 1 {
|
||||
m, ok = ctx.Value(featureContextKey).(map[string]interface{})
|
||||
} else {
|
||||
var err error
|
||||
m, err = flagger[0].Flags(ctx, f)
|
||||
ok = err == nil
|
||||
}
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
v, ok := m[f.Key()]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return v, true
|
||||
}
|
||||
|
||||
// StringFlag implements Flag for string values.
|
||||
type StringFlag struct {
|
||||
Base
|
||||
defaultString string
|
||||
}
|
||||
|
||||
var _ Flag = StringFlag{}
|
||||
|
||||
// MakeStringFlag returns a string flag with the given Base and default.
|
||||
func MakeStringFlag(name, key, owner string, defaultValue string, lifetime Lifetime, expose bool) StringFlag {
|
||||
b := MakeBase(name, key, owner, defaultValue, lifetime, expose)
|
||||
return StringFlag{b, defaultValue}
|
||||
}
|
||||
|
||||
// String value of the flag on the request context.
|
||||
func (f StringFlag) String(ctx context.Context, flagger ...Flagger) string {
|
||||
i, ok := f.value(ctx, flagger...)
|
||||
if !ok {
|
||||
return f.defaultString
|
||||
}
|
||||
s, ok := i.(string)
|
||||
if !ok {
|
||||
return f.defaultString
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// FloatFlag implements Flag for float values.
|
||||
type FloatFlag struct {
|
||||
Base
|
||||
defaultFloat float64
|
||||
}
|
||||
|
||||
var _ Flag = FloatFlag{}
|
||||
|
||||
// MakeFloatFlag returns a string flag with the given Base and default.
|
||||
func MakeFloatFlag(name, key, owner string, defaultValue float64, lifetime Lifetime, expose bool) FloatFlag {
|
||||
b := MakeBase(name, key, owner, defaultValue, lifetime, expose)
|
||||
return FloatFlag{b, defaultValue}
|
||||
}
|
||||
|
||||
// Float value of the flag on the request context.
|
||||
func (f FloatFlag) Float(ctx context.Context, flagger ...Flagger) float64 {
|
||||
i, ok := f.value(ctx, flagger...)
|
||||
if !ok {
|
||||
return f.defaultFloat
|
||||
}
|
||||
v, ok := i.(float64)
|
||||
if !ok {
|
||||
return f.defaultFloat
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// IntFlag implements Flag for integer values.
|
||||
type IntFlag struct {
|
||||
Base
|
||||
defaultInt int32
|
||||
}
|
||||
|
||||
var _ Flag = IntFlag{}
|
||||
|
||||
// MakeIntFlag returns a string flag with the given Base and default.
|
||||
func MakeIntFlag(name, key, owner string, defaultValue int32, lifetime Lifetime, expose bool) IntFlag {
|
||||
b := MakeBase(name, key, owner, defaultValue, lifetime, expose)
|
||||
return IntFlag{b, defaultValue}
|
||||
}
|
||||
|
||||
// Int value of the flag on the request context.
|
||||
func (f IntFlag) Int(ctx context.Context, flagger ...Flagger) int32 {
|
||||
i, ok := f.value(ctx, flagger...)
|
||||
if !ok {
|
||||
return f.defaultInt
|
||||
}
|
||||
v, ok := i.(int32)
|
||||
if !ok {
|
||||
return f.defaultInt
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// BoolFlag implements Flag for boolean values.
|
||||
type BoolFlag struct {
|
||||
Base
|
||||
defaultBool bool
|
||||
}
|
||||
|
||||
var _ Flag = BoolFlag{}
|
||||
|
||||
// MakeBoolFlag returns a string flag with the given Base and default.
|
||||
func MakeBoolFlag(name, key, owner string, defaultValue bool, lifetime Lifetime, expose bool) BoolFlag {
|
||||
b := MakeBase(name, key, owner, defaultValue, lifetime, expose)
|
||||
return BoolFlag{b, defaultValue}
|
||||
}
|
||||
|
||||
// Enabled indicates whether flag is true or false on the request context.
|
||||
func (f BoolFlag) Enabled(ctx context.Context, flagger ...Flagger) bool {
|
||||
i, ok := f.value(ctx, flagger...)
|
||||
if !ok {
|
||||
return f.defaultBool
|
||||
}
|
||||
v, ok := i.(bool)
|
||||
if !ok {
|
||||
return f.defaultBool
|
||||
}
|
||||
return v
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
package feature
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// HTTPProxy is an HTTP proxy that's guided by a feature flag. If the feature flag
|
||||
// presented to it is enabled, it will perform the proxying behavior. Otherwise
|
||||
// it will be a no-op.
|
||||
type HTTPProxy struct {
|
||||
proxy *httputil.ReverseProxy
|
||||
logger *zap.Logger
|
||||
enabler ProxyEnabler
|
||||
}
|
||||
|
||||
// NewHTTPProxy returns a new Proxy.
|
||||
func NewHTTPProxy(dest *url.URL, logger *zap.Logger, enabler ProxyEnabler) *HTTPProxy {
|
||||
return &HTTPProxy{
|
||||
proxy: newReverseProxy(dest, enabler.Key()),
|
||||
logger: logger,
|
||||
enabler: enabler,
|
||||
}
|
||||
}
|
||||
|
||||
// Do performs the proxying. It returns whether or not the request was proxied.
|
||||
func (p *HTTPProxy) Do(w http.ResponseWriter, r *http.Request) bool {
|
||||
if p.enabler.Enabled(r.Context()) {
|
||||
p.proxy.ServeHTTP(w, r)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const (
|
||||
// headerProxyFlag is the HTTP header for enriching the request and response
|
||||
// with the feature flag key that precipitated the proxying behavior.
|
||||
headerProxyFlag = "X-Platform-Proxy-Flag"
|
||||
)
|
||||
|
||||
// newReverseProxy creates a new single-host reverse proxy.
|
||||
func newReverseProxy(dest *url.URL, enablerKey string) *httputil.ReverseProxy {
|
||||
proxy := httputil.NewSingleHostReverseProxy(dest)
|
||||
|
||||
defaultDirector := proxy.Director
|
||||
proxy.Director = func(r *http.Request) {
|
||||
defaultDirector(r)
|
||||
|
||||
r.Header.Set(headerProxyFlag, enablerKey)
|
||||
|
||||
// Override r.Host to prevent us sending this request back to ourselves.
|
||||
// A bug in the stdlib causes this value to be preferred over the
|
||||
// r.URL.Host (which is set in the default Director) if r.Host isn't
|
||||
// empty (which it isn't).
|
||||
// https://github.com/golang/go/issues/28168
|
||||
r.Host = dest.Host
|
||||
}
|
||||
proxy.ModifyResponse = func(r *http.Response) error {
|
||||
r.Header.Set(headerProxyFlag, enablerKey)
|
||||
return nil
|
||||
}
|
||||
return proxy
|
||||
}
|
||||
|
||||
// ProxyEnabler is a boolean feature flag.
|
||||
type ProxyEnabler interface {
|
||||
Key() string
|
||||
Enabled(ctx context.Context, fs ...Flagger) bool
|
||||
}
|
|
@ -0,0 +1,137 @@
|
|||
package feature
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
const (
|
||||
destBody = "hello from destination"
|
||||
srcBody = "hello from source"
|
||||
flagKey = "fancy-feature"
|
||||
)
|
||||
|
||||
func TestHTTPProxy_Proxying(t *testing.T) {
|
||||
en := enabler{key: flagKey, state: true}
|
||||
logger := zaptest.NewLogger(t)
|
||||
resp, err := testHTTPProxy(logger, en)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
proxyFlag := resp.Header.Get("X-Platform-Proxy-Flag")
|
||||
if proxyFlag != flagKey {
|
||||
t.Error("X-Platform-Proxy-Flag header not populated")
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
bodyStr := string(body)
|
||||
if bodyStr != destBody {
|
||||
t.Errorf("expected body of destination handler, but got: %q", bodyStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPProxy_DefaultBehavior(t *testing.T) {
|
||||
en := enabler{key: flagKey, state: false}
|
||||
logger := zaptest.NewLogger(t)
|
||||
resp, err := testHTTPProxy(logger, en)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
proxyFlag := resp.Header.Get("X-Platform-Proxy-Flag")
|
||||
if proxyFlag != "" {
|
||||
t.Error("X-Platform-Proxy-Flag header populated")
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
bodyStr := string(body)
|
||||
if bodyStr != srcBody {
|
||||
t.Errorf("expected body of source handler, but got: %q", bodyStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPProxy_RequestHeader(t *testing.T) {
|
||||
h := func(w http.ResponseWriter, r *http.Request) {
|
||||
proxyFlag := r.Header.Get("X-Platform-Proxy-Flag")
|
||||
if proxyFlag != flagKey {
|
||||
t.Error("expected X-Proxy-Flag to contain feature flag key")
|
||||
}
|
||||
}
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(h))
|
||||
defer s.Close()
|
||||
|
||||
sURL, err := url.Parse(s.URL)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
en := enabler{key: flagKey, state: true}
|
||||
proxy := NewHTTPProxy(sURL, logger, en)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodGet, "http://example.com/foo", nil)
|
||||
srcHandler(proxy)(w, r)
|
||||
}
|
||||
|
||||
func testHTTPProxy(logger *zap.Logger, enabler ProxyEnabler) (*http.Response, error) {
|
||||
s := httptest.NewServer(http.HandlerFunc(destHandler))
|
||||
defer s.Close()
|
||||
|
||||
sURL, err := url.Parse(s.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
proxy := NewHTTPProxy(sURL, logger, enabler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodGet, "http://example.com/foo", nil)
|
||||
srcHandler(proxy)(w, r)
|
||||
|
||||
return w.Result(), nil
|
||||
}
|
||||
|
||||
func destHandler(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, destBody)
|
||||
}
|
||||
|
||||
func srcHandler(proxy *HTTPProxy) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if proxy.Do(w, r) {
|
||||
return
|
||||
}
|
||||
fmt.Fprint(w, srcBody)
|
||||
}
|
||||
}
|
||||
|
||||
type enabler struct {
|
||||
key string
|
||||
state bool
|
||||
}
|
||||
|
||||
func (e enabler) Key() string {
|
||||
return e.key
|
||||
}
|
||||
|
||||
func (e enabler) Enabled(context.Context, ...Flagger) bool {
|
||||
return e.state
|
||||
}
|
|
@ -0,0 +1,249 @@
|
|||
// Code generated by the feature package; DO NOT EDIT.
|
||||
|
||||
package feature
|
||||
|
||||
var appMetrics = MakeBoolFlag(
|
||||
"App Metrics",
|
||||
"appMetrics",
|
||||
"Bucky, Monitoring Team",
|
||||
false,
|
||||
Permanent,
|
||||
true,
|
||||
)
|
||||
|
||||
// AppMetrics - Send UI Telementry to Tools cluster - should always be false in OSS
|
||||
func AppMetrics() BoolFlag {
|
||||
return appMetrics
|
||||
}
|
||||
|
||||
var backendExample = MakeBoolFlag(
|
||||
"Backend Example",
|
||||
"backendExample",
|
||||
"Gavin Cabbage",
|
||||
false,
|
||||
Permanent,
|
||||
false,
|
||||
)
|
||||
|
||||
// BackendExample - A permanent backend example boolean flag
|
||||
func BackendExample() BoolFlag {
|
||||
return backendExample
|
||||
}
|
||||
|
||||
var communityTemplates = MakeBoolFlag(
|
||||
"Community Templates",
|
||||
"communityTemplates",
|
||||
"Bucky",
|
||||
true,
|
||||
Permanent,
|
||||
true,
|
||||
)
|
||||
|
||||
// CommunityTemplates - Replace current template uploading functionality with community driven templates
|
||||
func CommunityTemplates() BoolFlag {
|
||||
return communityTemplates
|
||||
}
|
||||
|
||||
var frontendExample = MakeIntFlag(
|
||||
"Frontend Example",
|
||||
"frontendExample",
|
||||
"Gavin Cabbage",
|
||||
42,
|
||||
Temporary,
|
||||
true,
|
||||
)
|
||||
|
||||
// FrontendExample - A temporary frontend example integer flag
|
||||
func FrontendExample() IntFlag {
|
||||
return frontendExample
|
||||
}
|
||||
|
||||
var groupWindowAggregateTranspose = MakeBoolFlag(
|
||||
"Group Window Aggregate Transpose",
|
||||
"groupWindowAggregateTranspose",
|
||||
"Query Team",
|
||||
false,
|
||||
Temporary,
|
||||
false,
|
||||
)
|
||||
|
||||
// GroupWindowAggregateTranspose - Enables the GroupWindowAggregateTransposeRule for all enabled window aggregates
|
||||
func GroupWindowAggregateTranspose() BoolFlag {
|
||||
return groupWindowAggregateTranspose
|
||||
}
|
||||
|
||||
var newLabels = MakeBoolFlag(
|
||||
"New Label Package",
|
||||
"newLabels",
|
||||
"Alirie Gray",
|
||||
false,
|
||||
Temporary,
|
||||
false,
|
||||
)
|
||||
|
||||
// NewLabelPackage - Enables the refactored labels api
|
||||
func NewLabelPackage() BoolFlag {
|
||||
return newLabels
|
||||
}
|
||||
|
||||
var memoryOptimizedFill = MakeBoolFlag(
|
||||
"Memory Optimized Fill",
|
||||
"memoryOptimizedFill",
|
||||
"Query Team",
|
||||
false,
|
||||
Temporary,
|
||||
false,
|
||||
)
|
||||
|
||||
// MemoryOptimizedFill - Enable the memory optimized fill()
|
||||
func MemoryOptimizedFill() BoolFlag {
|
||||
return memoryOptimizedFill
|
||||
}
|
||||
|
||||
var memoryOptimizedSchemaMutation = MakeBoolFlag(
|
||||
"Memory Optimized Schema Mutation",
|
||||
"memoryOptimizedSchemaMutation",
|
||||
"Query Team",
|
||||
false,
|
||||
Temporary,
|
||||
false,
|
||||
)
|
||||
|
||||
// MemoryOptimizedSchemaMutation - Enable the memory optimized schema mutation functions
|
||||
func MemoryOptimizedSchemaMutation() BoolFlag {
|
||||
return memoryOptimizedSchemaMutation
|
||||
}
|
||||
|
||||
var queryTracing = MakeBoolFlag(
|
||||
"Query Tracing",
|
||||
"queryTracing",
|
||||
"Query Team",
|
||||
false,
|
||||
Permanent,
|
||||
false,
|
||||
)
|
||||
|
||||
// QueryTracing - Turn on query tracing for queries that are sampled
|
||||
func QueryTracing() BoolFlag {
|
||||
return queryTracing
|
||||
}
|
||||
|
||||
var bandPlotType = MakeBoolFlag(
|
||||
"Band Plot Type",
|
||||
"bandPlotType",
|
||||
"Monitoring Team",
|
||||
false,
|
||||
Temporary,
|
||||
true,
|
||||
)
|
||||
|
||||
// BandPlotType - Enables the creation of a band plot in Dashboards
|
||||
func BandPlotType() BoolFlag {
|
||||
return bandPlotType
|
||||
}
|
||||
|
||||
var mosaicGraphType = MakeBoolFlag(
|
||||
"Mosaic Graph Type",
|
||||
"mosaicGraphType",
|
||||
"Monitoring Team",
|
||||
false,
|
||||
Temporary,
|
||||
true,
|
||||
)
|
||||
|
||||
// MosaicGraphType - Enables the creation of a mosaic graph in Dashboards
|
||||
func MosaicGraphType() BoolFlag {
|
||||
return mosaicGraphType
|
||||
}
|
||||
|
||||
var notebooks = MakeBoolFlag(
|
||||
"Notebooks",
|
||||
"notebooks",
|
||||
"Monitoring Team",
|
||||
false,
|
||||
Temporary,
|
||||
true,
|
||||
)
|
||||
|
||||
// Notebooks - Determine if the notebook feature's route and navbar icon are visible to the user
|
||||
func Notebooks() BoolFlag {
|
||||
return notebooks
|
||||
}
|
||||
|
||||
var injectLatestSuccessTime = MakeBoolFlag(
|
||||
"Inject Latest Success Time",
|
||||
"injectLatestSuccessTime",
|
||||
"Compute Team",
|
||||
false,
|
||||
Temporary,
|
||||
false,
|
||||
)
|
||||
|
||||
// InjectLatestSuccessTime - Inject the latest successful task run timestamp into a Task query extern when executing.
|
||||
func InjectLatestSuccessTime() BoolFlag {
|
||||
return injectLatestSuccessTime
|
||||
}
|
||||
|
||||
var enforceOrgDashboardLimits = MakeBoolFlag(
|
||||
"Enforce Organization Dashboard Limits",
|
||||
"enforceOrgDashboardLimits",
|
||||
"Compute Team",
|
||||
false,
|
||||
Temporary,
|
||||
false,
|
||||
)
|
||||
|
||||
// EnforceOrganizationDashboardLimits - Enforces the default limit params for the dashboards api when orgs are set
|
||||
func EnforceOrganizationDashboardLimits() BoolFlag {
|
||||
return enforceOrgDashboardLimits
|
||||
}
|
||||
|
||||
var timeFilterFlags = MakeBoolFlag(
|
||||
"Time Filter Flags",
|
||||
"timeFilterFlags",
|
||||
"Compute Team",
|
||||
false,
|
||||
Temporary,
|
||||
true,
|
||||
)
|
||||
|
||||
// TimeFilterFlags - Filter task run list based on before and after flags
|
||||
func TimeFilterFlags() BoolFlag {
|
||||
return timeFilterFlags
|
||||
}
|
||||
|
||||
var all = []Flag{
|
||||
appMetrics,
|
||||
backendExample,
|
||||
communityTemplates,
|
||||
frontendExample,
|
||||
groupWindowAggregateTranspose,
|
||||
newLabels,
|
||||
memoryOptimizedFill,
|
||||
memoryOptimizedSchemaMutation,
|
||||
queryTracing,
|
||||
bandPlotType,
|
||||
mosaicGraphType,
|
||||
notebooks,
|
||||
injectLatestSuccessTime,
|
||||
enforceOrgDashboardLimits,
|
||||
timeFilterFlags,
|
||||
}
|
||||
|
||||
var byKey = map[string]Flag{
|
||||
"appMetrics": appMetrics,
|
||||
"backendExample": backendExample,
|
||||
"communityTemplates": communityTemplates,
|
||||
"frontendExample": frontendExample,
|
||||
"groupWindowAggregateTranspose": groupWindowAggregateTranspose,
|
||||
"newLabels": newLabels,
|
||||
"memoryOptimizedFill": memoryOptimizedFill,
|
||||
"memoryOptimizedSchemaMutation": memoryOptimizedSchemaMutation,
|
||||
"queryTracing": queryTracing,
|
||||
"bandPlotType": bandPlotType,
|
||||
"mosaicGraphType": mosaicGraphType,
|
||||
"notebooks": notebooks,
|
||||
"injectLatestSuccessTime": injectLatestSuccessTime,
|
||||
"enforceOrgDashboardLimits": enforceOrgDashboardLimits,
|
||||
"timeFilterFlags": timeFilterFlags,
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
package feature
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Handler is a middleware that annotates the context with a map of computed feature flags.
|
||||
// To accurately compute identity-scoped flags, this middleware should be executed after any
|
||||
// authorization middleware has annotated the request context with an authorizer.
|
||||
type Handler struct {
|
||||
log *zap.Logger
|
||||
next http.Handler
|
||||
flagger Flagger
|
||||
flags []Flag
|
||||
}
|
||||
|
||||
// NewHandler returns a configured feature flag middleware that will annotate request context
|
||||
// with a computed map of the given flags using the provided Flagger.
|
||||
func NewHandler(log *zap.Logger, flagger Flagger, flags []Flag, next http.Handler) http.Handler {
|
||||
return &Handler{
|
||||
log: log,
|
||||
next: next,
|
||||
flagger: flagger,
|
||||
flags: flags,
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP annotates the request context with a map of computed feature flags before
|
||||
// continuing to serve the request.
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, err := Annotate(r.Context(), h.flagger, h.flags...)
|
||||
if err != nil {
|
||||
h.log.Warn("Unable to annotate context with feature flags", zap.Error(err))
|
||||
} else {
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
|
||||
if h.next != nil {
|
||||
h.next.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPErrorHandler is an influxdb.HTTPErrorHandler. It's defined here instead
|
||||
// of referencing the other interface type, because we want to try our best to
|
||||
// avoid cyclical dependencies when feature package is used throughout the
|
||||
// codebase.
|
||||
type HTTPErrorHandler interface {
|
||||
HandleHTTPError(ctx context.Context, err error, w http.ResponseWriter)
|
||||
}
|
||||
|
||||
// NewFlagsHandler returns a handler that returns the map of computed feature flags on the request context.
|
||||
func NewFlagsHandler(errorHandler HTTPErrorHandler, byKey ByKeyFn) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
var (
|
||||
ctx = r.Context()
|
||||
flags = ExposedFlagsFromContext(ctx, byKey)
|
||||
)
|
||||
if err := json.NewEncoder(w).Encode(flags); err != nil {
|
||||
errorHandler.HandleHTTPError(ctx, err, w)
|
||||
}
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package feature_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/influxdata/influxdb/kit/feature"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
func Test_Handler(t *testing.T) {
|
||||
var (
|
||||
w = &httptest.ResponseRecorder{}
|
||||
r = httptest.NewRequest(http.MethodGet, "http://nowhere.test", new(bytes.Buffer)).
|
||||
WithContext(context.Background())
|
||||
|
||||
original = r.Context()
|
||||
)
|
||||
|
||||
handler := &checkHandler{t: t, f: func(t *testing.T, r *http.Request) {
|
||||
if r.Context() == original {
|
||||
t.Error("expected annotated context")
|
||||
}
|
||||
}}
|
||||
|
||||
subject := feature.NewHandler(zaptest.NewLogger(t), feature.DefaultFlagger(), feature.Flags(), handler)
|
||||
|
||||
subject.ServeHTTP(w, r)
|
||||
|
||||
if !handler.called {
|
||||
t.Error("expected handler to be called")
|
||||
}
|
||||
}
|
||||
|
||||
type checkHandler struct {
|
||||
t *testing.T
|
||||
f func(t *testing.T, r *http.Request)
|
||||
called bool
|
||||
}
|
||||
|
||||
func (h *checkHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) {
|
||||
h.called = true
|
||||
h.f(h.t, r)
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
package override
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/influxdata/influxdb/kit/feature"
|
||||
)
|
||||
|
||||
// Flagger can override default flag values.
|
||||
type Flagger struct {
|
||||
overrides map[string]string
|
||||
byKey feature.ByKeyFn
|
||||
}
|
||||
|
||||
// Make a Flagger that returns defaults with any overrides parsed from the string.
|
||||
func Make(overrides map[string]string, byKey feature.ByKeyFn) (Flagger, error) {
|
||||
if byKey == nil {
|
||||
byKey = feature.ByKey
|
||||
}
|
||||
|
||||
// Check all provided override keys correspond to an existing Flag.
|
||||
var missing []string
|
||||
for k := range overrides {
|
||||
if _, found := byKey(k); !found {
|
||||
missing = append(missing, k)
|
||||
}
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
return Flagger{}, fmt.Errorf("configured overrides for non-existent flags: %s", strings.Join(missing, ","))
|
||||
}
|
||||
|
||||
return Flagger{
|
||||
overrides: overrides,
|
||||
byKey: byKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Flags returns a map of default values with overrides applied. It never returns an error.
|
||||
func (f Flagger) Flags(_ context.Context, flags ...feature.Flag) (map[string]interface{}, error) {
|
||||
if len(flags) == 0 {
|
||||
flags = feature.Flags()
|
||||
}
|
||||
|
||||
m := make(map[string]interface{}, len(flags))
|
||||
for _, flag := range flags {
|
||||
if s, overridden := f.overrides[flag.Key()]; overridden {
|
||||
iface, err := f.coerce(s, flag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m[flag.Key()] = iface
|
||||
} else {
|
||||
m[flag.Key()] = flag.Default()
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (f Flagger) coerce(s string, flag feature.Flag) (iface interface{}, err error) {
|
||||
if base, ok := flag.(feature.Base); ok {
|
||||
flag, _ = f.byKey(base.Key())
|
||||
}
|
||||
|
||||
switch flag.(type) {
|
||||
case feature.BoolFlag:
|
||||
iface, err = strconv.ParseBool(s)
|
||||
case feature.IntFlag:
|
||||
iface, err = strconv.Atoi(s)
|
||||
case feature.FloatFlag:
|
||||
iface, err = strconv.ParseFloat(s, 64)
|
||||
default:
|
||||
iface = s
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("coercing string %q based on flag type %T: %v", s, flag, err)
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,183 @@
|
|||
package override
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/influxdata/influxdb/kit/feature"
|
||||
)
|
||||
|
||||
func TestFlagger(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
env map[string]string
|
||||
defaults []feature.Flag
|
||||
expected map[string]interface{}
|
||||
expectMakeErr bool
|
||||
expectFlagsErr bool
|
||||
byKey feature.ByKeyFn
|
||||
}{
|
||||
{
|
||||
name: "enabled happy path filtering",
|
||||
env: map[string]string{
|
||||
"flag1": "new1",
|
||||
"flag3": "new3",
|
||||
},
|
||||
defaults: []feature.Flag{
|
||||
newFlag("flag0", "original0"),
|
||||
newFlag("flag1", "original1"),
|
||||
newFlag("flag2", "original2"),
|
||||
newFlag("flag3", "original3"),
|
||||
newFlag("flag4", "original4"),
|
||||
},
|
||||
byKey: newByKey(map[string]feature.Flag{
|
||||
"flag0": newFlag("flag0", "original0"),
|
||||
"flag1": newFlag("flag1", "original1"),
|
||||
"flag2": newFlag("flag2", "original2"),
|
||||
"flag3": newFlag("flag3", "original3"),
|
||||
"flag4": newFlag("flag4", "original4"),
|
||||
}),
|
||||
expected: map[string]interface{}{
|
||||
"flag0": "original0",
|
||||
"flag1": "new1",
|
||||
"flag2": "original2",
|
||||
"flag3": "new3",
|
||||
"flag4": "original4",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enabled happy path types",
|
||||
env: map[string]string{
|
||||
"intflag": "43",
|
||||
"floatflag": "43.43",
|
||||
"boolflag": "true",
|
||||
},
|
||||
defaults: []feature.Flag{
|
||||
newFlag("intflag", 42),
|
||||
newFlag("floatflag", 42.42),
|
||||
newFlag("boolflag", false),
|
||||
},
|
||||
byKey: newByKey(map[string]feature.Flag{
|
||||
"intflag": newFlag("intflag", 42),
|
||||
"floatflag": newFlag("floatflag", 43.43),
|
||||
"boolflag": newFlag("boolflag", false),
|
||||
}),
|
||||
expected: map[string]interface{}{
|
||||
"intflag": 43,
|
||||
"floatflag": 43.43,
|
||||
"boolflag": true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "type coerce error",
|
||||
env: map[string]string{
|
||||
"key": "not_an_int",
|
||||
},
|
||||
defaults: []feature.Flag{
|
||||
newFlag("key", 42),
|
||||
},
|
||||
byKey: newByKey(map[string]feature.Flag{
|
||||
"key": newFlag("key", 42),
|
||||
}),
|
||||
expectFlagsErr: true,
|
||||
},
|
||||
{
|
||||
name: "typed base flags",
|
||||
env: map[string]string{
|
||||
"flag1": "411",
|
||||
"flag2": "new2",
|
||||
"flag3": "true",
|
||||
},
|
||||
defaults: []feature.Flag{
|
||||
newBaseFlag("flag0", "original0"),
|
||||
newBaseFlag("flag1", 41),
|
||||
newBaseFlag("flag2", "original2"),
|
||||
newBaseFlag("flag3", false),
|
||||
newBaseFlag("flag4", "original4"),
|
||||
},
|
||||
byKey: newByKey(map[string]feature.Flag{
|
||||
"flag0": newFlag("flag0", "original0"),
|
||||
"flag1": newFlag("flag1", 41),
|
||||
"flag2": newFlag("flag2", "original2"),
|
||||
"flag3": newFlag("flag3", false),
|
||||
"flag4": newFlag("flag4", "original4"),
|
||||
}),
|
||||
expected: map[string]interface{}{
|
||||
"flag0": "original0",
|
||||
"flag1": 411,
|
||||
"flag2": "new2",
|
||||
"flag3": true,
|
||||
"flag4": "original4",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "override for non-existent flag",
|
||||
env: map[string]string{
|
||||
"dne": "foobar",
|
||||
},
|
||||
defaults: []feature.Flag{
|
||||
newBaseFlag("key", "value"),
|
||||
},
|
||||
byKey: newByKey(map[string]feature.Flag{
|
||||
"key": newFlag("key", "value"),
|
||||
}),
|
||||
expectMakeErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
subject, err := Make(test.env, test.byKey)
|
||||
if err != nil {
|
||||
if test.expectMakeErr {
|
||||
return
|
||||
}
|
||||
t.Fatalf("unexpected error making Flagger: %v", err)
|
||||
}
|
||||
|
||||
computed, err := subject.Flags(context.Background(), test.defaults...)
|
||||
if err != nil {
|
||||
if test.expectFlagsErr {
|
||||
return
|
||||
}
|
||||
t.Fatalf("unexpected error calling Flags: %v", err)
|
||||
}
|
||||
|
||||
if len(computed) != len(test.expected) {
|
||||
t.Fatalf("incorrect number of flags computed: expected %d, got %d", len(test.expected), len(computed))
|
||||
}
|
||||
|
||||
// check for extra or incorrect keys
|
||||
for k, v := range computed {
|
||||
if xv, found := test.expected[k]; !found {
|
||||
t.Errorf("unexpected key %s", k)
|
||||
} else if v != xv {
|
||||
t.Errorf("incorrect value for key %s: expected %v [%T], got %v [%T]", k, xv, xv, v, v)
|
||||
}
|
||||
}
|
||||
|
||||
// check for missing keys
|
||||
for k := range test.expected {
|
||||
if _, found := computed[k]; !found {
|
||||
t.Errorf("missing expected key %s", k)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newFlag(key string, defaultValue interface{}) feature.Flag {
|
||||
return feature.MakeFlag(key, key, "", defaultValue, feature.Temporary, false)
|
||||
}
|
||||
|
||||
func newBaseFlag(key string, defaultValue interface{}) feature.Base {
|
||||
return feature.MakeBase(key, key, "", defaultValue, feature.Temporary, false)
|
||||
}
|
||||
|
||||
func newByKey(m map[string]feature.Flag) feature.ByKeyFn {
|
||||
return func(k string) (feature.Flag, bool) {
|
||||
v, found := m[k]
|
||||
return v, found
|
||||
}
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
package io
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
var ErrReadLimitExceeded = errors.New("read limit exceeded")
|
||||
|
||||
// LimitedReadCloser wraps an io.ReadCloser in limiting behavior using
|
||||
// io.LimitedReader. It allows us to obtain the limit error at the time of close
|
||||
// instead of just when writing.
|
||||
type LimitedReadCloser struct {
|
||||
R io.ReadCloser // underlying reader
|
||||
N int64 // max bytes remaining
|
||||
err error
|
||||
closed bool
|
||||
limitExceeded bool
|
||||
}
|
||||
|
||||
// NewLimitedReadCloser returns a new LimitedReadCloser.
|
||||
func NewLimitedReadCloser(r io.ReadCloser, n int64) *LimitedReadCloser {
|
||||
return &LimitedReadCloser{
|
||||
R: r,
|
||||
N: n,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *LimitedReadCloser) Read(p []byte) (n int, err error) {
|
||||
if l.N <= 0 {
|
||||
l.limitExceeded = true
|
||||
return 0, io.EOF
|
||||
}
|
||||
if int64(len(p)) > l.N {
|
||||
p = p[0:l.N]
|
||||
}
|
||||
n, err = l.R.Read(p)
|
||||
l.N -= int64(n)
|
||||
return
|
||||
}
|
||||
|
||||
// Close returns an ErrReadLimitExceeded when the wrapped reader exceeds the set
|
||||
// limit for number of bytes. This is safe to call more than once but not
|
||||
// concurrently.
|
||||
func (l *LimitedReadCloser) Close() (err error) {
|
||||
if l.limitExceeded {
|
||||
l.err = ErrReadLimitExceeded
|
||||
}
|
||||
if l.closed {
|
||||
// Close has already been called.
|
||||
return l.err
|
||||
}
|
||||
if err := l.R.Close(); err != nil && l.err == nil {
|
||||
l.err = err
|
||||
}
|
||||
// Prevent l.closer.Close from being called again.
|
||||
l.closed = true
|
||||
return l.err
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
package io
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLimitedReadCloser_Exceeded(t *testing.T) {
|
||||
b := &closer{Reader: bytes.NewBufferString("howdy")}
|
||||
rc := NewLimitedReadCloser(b, 3)
|
||||
|
||||
out, err := ioutil.ReadAll(rc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("how"), out)
|
||||
assert.Equal(t, ErrReadLimitExceeded, rc.Close())
|
||||
}
|
||||
|
||||
func TestLimitedReadCloser_Happy(t *testing.T) {
|
||||
b := &closer{Reader: bytes.NewBufferString("ho")}
|
||||
rc := NewLimitedReadCloser(b, 2)
|
||||
|
||||
out, err := ioutil.ReadAll(rc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("ho"), out)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestLimitedReadCloseWithErrorAndLimitExceeded(t *testing.T) {
|
||||
b := &closer{
|
||||
Reader: bytes.NewBufferString("howdy"),
|
||||
err: errors.New("some error"),
|
||||
}
|
||||
rc := NewLimitedReadCloser(b, 3)
|
||||
|
||||
out, err := ioutil.ReadAll(rc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("how"), out)
|
||||
// LimitExceeded error trumps the close error.
|
||||
assert.Equal(t, ErrReadLimitExceeded, rc.Close())
|
||||
}
|
||||
|
||||
func TestLimitedReadCloseWithError(t *testing.T) {
|
||||
closeErr := errors.New("some error")
|
||||
b := &closer{
|
||||
Reader: bytes.NewBufferString("howdy"),
|
||||
err: closeErr,
|
||||
}
|
||||
rc := NewLimitedReadCloser(b, 10)
|
||||
|
||||
out, err := ioutil.ReadAll(rc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("howdy"), out)
|
||||
assert.Equal(t, closeErr, rc.Close())
|
||||
}
|
||||
|
||||
func TestMultipleCloseOnlyClosesOnce(t *testing.T) {
|
||||
closeErr := errors.New("some error")
|
||||
b := &closer{
|
||||
Reader: bytes.NewBufferString("howdy"),
|
||||
err: closeErr,
|
||||
}
|
||||
rc := NewLimitedReadCloser(b, 10)
|
||||
|
||||
out, err := ioutil.ReadAll(rc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("howdy"), out)
|
||||
assert.Equal(t, closeErr, rc.Close())
|
||||
assert.Equal(t, closeErr, rc.Close())
|
||||
assert.Equal(t, 1, b.closeCount)
|
||||
}
|
||||
|
||||
type closer struct {
|
||||
io.Reader
|
||||
err error
|
||||
closeCount int
|
||||
}
|
||||
|
||||
func (c *closer) Close() error {
|
||||
c.closeCount++
|
||||
return c.err
|
||||
}
|
|
@ -0,0 +1,152 @@
|
|||
package metric
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/influxdb/kit/platform/errors"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// REDClient is a metrics client for collection RED metrics.
|
||||
type REDClient struct {
|
||||
metrics []metricCollector
|
||||
}
|
||||
|
||||
// New creates a new REDClient.
|
||||
func New(reg prometheus.Registerer, service string, opts ...ClientOptFn) *REDClient {
|
||||
opt := metricOpts{
|
||||
namespace: "service",
|
||||
service: service,
|
||||
counterMetrics: map[string]VecOpts{
|
||||
"call_total": {
|
||||
Help: "Number of calls",
|
||||
LabelNames: []string{"method"},
|
||||
CounterFn: func(vec *prometheus.CounterVec, o CollectFnOpts) {
|
||||
vec.With(prometheus.Labels{"method": o.Method}).Inc()
|
||||
},
|
||||
},
|
||||
"error_total": {
|
||||
Help: "Number of errors encountered",
|
||||
LabelNames: []string{"method", "code"},
|
||||
CounterFn: func(vec *prometheus.CounterVec, o CollectFnOpts) {
|
||||
if o.Err != nil {
|
||||
vec.With(prometheus.Labels{
|
||||
"method": o.Method,
|
||||
"code": errors.ErrorCode(o.Err),
|
||||
}).Inc()
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
histogramMetrics: map[string]VecOpts{
|
||||
"duration": {
|
||||
Help: "Duration of calls",
|
||||
LabelNames: []string{"method"},
|
||||
HistogramFn: func(vec *prometheus.HistogramVec, o CollectFnOpts) {
|
||||
vec.
|
||||
With(prometheus.Labels{"method": o.Method}).
|
||||
Observe(time.Since(o.Start).Seconds())
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(&opt)
|
||||
}
|
||||
|
||||
client := new(REDClient)
|
||||
for metricName, vecOpts := range opt.counterMetrics {
|
||||
client.metrics = append(client.metrics, &counter{
|
||||
fn: vecOpts.CounterFn,
|
||||
CounterVec: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: opt.namespace,
|
||||
Subsystem: opt.serviceName(),
|
||||
Name: metricName,
|
||||
Help: vecOpts.Help,
|
||||
}, vecOpts.LabelNames),
|
||||
})
|
||||
}
|
||||
|
||||
for metricName, vecOpts := range opt.histogramMetrics {
|
||||
client.metrics = append(client.metrics, &histogram{
|
||||
fn: vecOpts.HistogramFn,
|
||||
HistogramVec: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: opt.namespace,
|
||||
Subsystem: opt.serviceName(),
|
||||
Name: metricName,
|
||||
Help: vecOpts.Help,
|
||||
}, vecOpts.LabelNames),
|
||||
})
|
||||
}
|
||||
|
||||
reg.MustRegister(client.collectors()...)
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
type RecordFn func(err error, opts ...func(opts *CollectFnOpts)) error
|
||||
|
||||
// RecordAdditional provides an extension to the base method, err data provided
|
||||
// to the metrics.
|
||||
func RecordAdditional(props map[string]interface{}) func(opts *CollectFnOpts) {
|
||||
return func(opts *CollectFnOpts) {
|
||||
opts.AdditionalProps = props
|
||||
}
|
||||
}
|
||||
|
||||
// Record returns a record fn that is called on any given return err. If an error is encountered
|
||||
// it will register the err metric. The err is never altered.
|
||||
func (c *REDClient) Record(method string) RecordFn {
|
||||
start := time.Now()
|
||||
return func(err error, opts ...func(opts *CollectFnOpts)) error {
|
||||
opt := CollectFnOpts{
|
||||
Method: method,
|
||||
Start: start,
|
||||
Err: err,
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(&opt)
|
||||
}
|
||||
|
||||
for _, metric := range c.metrics {
|
||||
metric.collect(opt)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *REDClient) collectors() []prometheus.Collector {
|
||||
var collectors []prometheus.Collector
|
||||
for _, metric := range c.metrics {
|
||||
collectors = append(collectors, metric)
|
||||
}
|
||||
return collectors
|
||||
}
|
||||
|
||||
type metricCollector interface {
|
||||
prometheus.Collector
|
||||
|
||||
collect(o CollectFnOpts)
|
||||
}
|
||||
|
||||
type counter struct {
|
||||
*prometheus.CounterVec
|
||||
|
||||
fn CounterFn
|
||||
}
|
||||
|
||||
func (c *counter) collect(o CollectFnOpts) {
|
||||
c.fn(c.CounterVec, o)
|
||||
}
|
||||
|
||||
type histogram struct {
|
||||
*prometheus.HistogramVec
|
||||
|
||||
fn HistogramFn
|
||||
}
|
||||
|
||||
func (h *histogram) collect(o CollectFnOpts) {
|
||||
h.fn(h.HistogramVec, o)
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
package metric
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
type (
|
||||
// CollectFnOpts provides arguments to the collect operation of a metric.
|
||||
CollectFnOpts struct {
|
||||
Method string
|
||||
Start time.Time
|
||||
Err error
|
||||
AdditionalProps map[string]interface{}
|
||||
}
|
||||
|
||||
CounterFn func(vec *prometheus.CounterVec, o CollectFnOpts)
|
||||
|
||||
HistogramFn func(vec *prometheus.HistogramVec, o CollectFnOpts)
|
||||
|
||||
// VecOpts expands on the
|
||||
VecOpts struct {
|
||||
Name string
|
||||
Help string
|
||||
LabelNames []string
|
||||
|
||||
CounterFn CounterFn
|
||||
HistogramFn HistogramFn
|
||||
}
|
||||
)
|
||||
|
||||
type metricOpts struct {
|
||||
namespace string
|
||||
service string
|
||||
serviceSuffix string
|
||||
counterMetrics map[string]VecOpts
|
||||
histogramMetrics map[string]VecOpts
|
||||
}
|
||||
|
||||
func (o metricOpts) serviceName() string {
|
||||
if o.serviceSuffix != "" {
|
||||
return fmt.Sprintf("%s_%s", o.service, o.serviceSuffix)
|
||||
}
|
||||
return o.service
|
||||
}
|
||||
|
||||
// ClientOptFn is an option used by a metric middleware.
|
||||
type ClientOptFn func(*metricOpts)
|
||||
|
||||
// WithVec sets a new counter vector to be collected.
|
||||
func WithVec(opts VecOpts) ClientOptFn {
|
||||
return func(o *metricOpts) {
|
||||
if opts.CounterFn != nil {
|
||||
if o.counterMetrics == nil {
|
||||
o.counterMetrics = make(map[string]VecOpts)
|
||||
}
|
||||
o.counterMetrics[opts.Name] = opts
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithSuffix returns a metric option that applies a suffix to the service name of the metric.
|
||||
func WithSuffix(suffix string) ClientOptFn {
|
||||
return func(opts *metricOpts) {
|
||||
opts.serviceSuffix = suffix
|
||||
}
|
||||
}
|
||||
|
||||
func ApplyMetricOpts(opts ...ClientOptFn) *metricOpts {
|
||||
o := metricOpts{}
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
return &o
|
||||
}
|
||||
|
||||
func (o *metricOpts) ApplySuffix(prefix string) string {
|
||||
if o.serviceSuffix != "" {
|
||||
return fmt.Sprintf("%s_%s", prefix, o.serviceSuffix)
|
||||
}
|
||||
return prefix
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
package errors
|
||||
|
||||
// ChronografError is a domain error encountered while processing chronograf requests.
|
||||
type ChronografError string
|
||||
|
||||
// ChronografError returns the string of an error.
|
||||
func (e ChronografError) Error() string {
|
||||
return string(e)
|
||||
}
|
|
@ -0,0 +1,264 @@
|
|||
package errors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Some error code constant, ideally we want define common platform codes here
|
||||
// projects on use platform's error, should have their own central place like this.
|
||||
// Any time this set of constants changes, you must also update the swagger for Error.properties.code.enum.
|
||||
const (
|
||||
EInternal = "internal error"
|
||||
ENotImplemented = "not implemented"
|
||||
ENotFound = "not found"
|
||||
EConflict = "conflict" // action cannot be performed
|
||||
EInvalid = "invalid" // validation failed
|
||||
EUnprocessableEntity = "unprocessable entity" // data type is correct, but out of range
|
||||
EEmptyValue = "empty value"
|
||||
EUnavailable = "unavailable"
|
||||
EForbidden = "forbidden"
|
||||
ETooManyRequests = "too many requests"
|
||||
EUnauthorized = "unauthorized"
|
||||
EMethodNotAllowed = "method not allowed"
|
||||
ETooLarge = "request too large"
|
||||
)
|
||||
|
||||
// Error is the error struct of platform.
|
||||
//
|
||||
// Errors may have error codes, human-readable messages,
|
||||
// and a logical stack trace.
|
||||
//
|
||||
// The Code targets automated handlers so that recovery can occur.
|
||||
// Msg is used by the system operator to help diagnose and fix the problem.
|
||||
// Op and Err chain errors together in a logical stack trace to
|
||||
// further help operators.
|
||||
//
|
||||
// To create a simple error,
|
||||
// &Error{
|
||||
// Code:ENotFound,
|
||||
// }
|
||||
// To show where the error happens, add Op.
|
||||
// &Error{
|
||||
// Code: ENotFound,
|
||||
// Op: "bolt.FindUserByID"
|
||||
// }
|
||||
// To show an error with a unpredictable value, add the value in Msg.
|
||||
// &Error{
|
||||
// Code: EConflict,
|
||||
// Message: fmt.Sprintf("organization with name %s already exist", aName),
|
||||
// }
|
||||
// To show an error wrapped with another error.
|
||||
// &Error{
|
||||
// Code:EInternal,
|
||||
// Err: err,
|
||||
// }.
|
||||
type Error struct {
|
||||
Code string
|
||||
Msg string
|
||||
Op string
|
||||
Err error
|
||||
}
|
||||
|
||||
// NewError returns an instance of an error.
|
||||
func NewError(options ...func(*Error)) *Error {
|
||||
err := &Error{}
|
||||
for _, o := range options {
|
||||
o(err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// WithErrorErr sets the err on the error.
|
||||
func WithErrorErr(err error) func(*Error) {
|
||||
return func(e *Error) {
|
||||
e.Err = err
|
||||
}
|
||||
}
|
||||
|
||||
// WithErrorCode sets the code on the error.
|
||||
func WithErrorCode(code string) func(*Error) {
|
||||
return func(e *Error) {
|
||||
e.Code = code
|
||||
}
|
||||
}
|
||||
|
||||
// WithErrorMsg sets the message on the error.
|
||||
func WithErrorMsg(msg string) func(*Error) {
|
||||
return func(e *Error) {
|
||||
e.Msg = msg
|
||||
}
|
||||
}
|
||||
|
||||
// WithErrorOp sets the message on the error.
|
||||
func WithErrorOp(op string) func(*Error) {
|
||||
return func(e *Error) {
|
||||
e.Op = op
|
||||
}
|
||||
}
|
||||
|
||||
// Error implements the error interface by writing out the recursive messages.
|
||||
func (e *Error) Error() string {
|
||||
if e.Msg != "" && e.Err != nil {
|
||||
var b strings.Builder
|
||||
b.WriteString(e.Msg)
|
||||
b.WriteString(": ")
|
||||
b.WriteString(e.Err.Error())
|
||||
return b.String()
|
||||
} else if e.Msg != "" {
|
||||
return e.Msg
|
||||
} else if e.Err != nil {
|
||||
return e.Err.Error()
|
||||
}
|
||||
return fmt.Sprintf("<%s>", e.Code)
|
||||
}
|
||||
|
||||
// ErrorCode returns the code of the root error, if available; otherwise returns EINTERNAL.
|
||||
func ErrorCode(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
e, ok := err.(*Error)
|
||||
if !ok {
|
||||
return EInternal
|
||||
}
|
||||
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if e.Code != "" {
|
||||
return e.Code
|
||||
}
|
||||
|
||||
if e.Err != nil {
|
||||
return ErrorCode(e.Err)
|
||||
}
|
||||
|
||||
return EInternal
|
||||
}
|
||||
|
||||
// ErrorOp returns the op of the error, if available; otherwise return empty string.
|
||||
func ErrorOp(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
e, ok := err.(*Error)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if e.Op != "" {
|
||||
return e.Op
|
||||
}
|
||||
|
||||
if e.Err != nil {
|
||||
return ErrorOp(e.Err)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// ErrorMessage returns the human-readable message of the error, if available.
|
||||
// Otherwise returns a generic error message.
|
||||
func ErrorMessage(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
e, ok := err.(*Error)
|
||||
if !ok {
|
||||
return "An internal error has occurred."
|
||||
}
|
||||
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if e.Msg != "" {
|
||||
return e.Msg
|
||||
}
|
||||
|
||||
if e.Err != nil {
|
||||
return ErrorMessage(e.Err)
|
||||
}
|
||||
|
||||
return "An internal error has occurred."
|
||||
}
|
||||
|
||||
// errEncode an JSON encoding helper that is needed to handle the recursive stack of errors.
|
||||
type errEncode struct {
|
||||
Code string `json:"code"` // Code is the machine-readable error code.
|
||||
Msg string `json:"message,omitempty"` // Msg is a human-readable message.
|
||||
Op string `json:"op,omitempty"` // Op describes the logical code operation during error.
|
||||
Err interface{} `json:"error,omitempty"` // Err is a stack of additional errors.
|
||||
}
|
||||
|
||||
// MarshalJSON recursively marshals the stack of Err.
|
||||
func (e *Error) MarshalJSON() (result []byte, err error) {
|
||||
ee := errEncode{
|
||||
Code: e.Code,
|
||||
Msg: e.Msg,
|
||||
Op: e.Op,
|
||||
}
|
||||
if e.Err != nil {
|
||||
if _, ok := e.Err.(*Error); ok {
|
||||
_, err := e.Err.(*Error).MarshalJSON()
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
ee.Err = e.Err
|
||||
} else {
|
||||
ee.Err = e.Err.Error()
|
||||
}
|
||||
}
|
||||
return json.Marshal(ee)
|
||||
}
|
||||
|
||||
// UnmarshalJSON recursively unmarshals the error stack.
|
||||
func (e *Error) UnmarshalJSON(b []byte) (err error) {
|
||||
ee := new(errEncode)
|
||||
err = json.Unmarshal(b, ee)
|
||||
e.Code = ee.Code
|
||||
e.Msg = ee.Msg
|
||||
e.Op = ee.Op
|
||||
e.Err = decodeInternalError(ee.Err)
|
||||
return err
|
||||
}
|
||||
|
||||
func decodeInternalError(target interface{}) error {
|
||||
if errStr, ok := target.(string); ok {
|
||||
return errors.New(errStr)
|
||||
}
|
||||
if internalErrMap, ok := target.(map[string]interface{}); ok {
|
||||
internalErr := new(Error)
|
||||
if code, ok := internalErrMap["code"].(string); ok {
|
||||
internalErr.Code = code
|
||||
}
|
||||
if msg, ok := internalErrMap["message"].(string); ok {
|
||||
internalErr.Msg = msg
|
||||
}
|
||||
if op, ok := internalErrMap["op"].(string); ok {
|
||||
internalErr.Op = op
|
||||
}
|
||||
internalErr.Err = decodeInternalError(internalErrMap["error"])
|
||||
return internalErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HTTPErrorHandler is the interface to handle http error.
|
||||
type HTTPErrorHandler interface {
|
||||
HandleHTTPError(ctx context.Context, err error, w http.ResponseWriter)
|
||||
}
|
|
@ -0,0 +1,86 @@
|
|||
# errors.go
|
||||
|
||||
This is inspired from Ben Johnson's blog post [Failure is Your Domain](https://middlemost.com/failure-is-your-domain/)
|
||||
|
||||
## The Error struct
|
||||
|
||||
```go
|
||||
|
||||
type Error struct {
|
||||
Code string
|
||||
Msg string
|
||||
Op string
|
||||
Err error
|
||||
}
|
||||
```
|
||||
|
||||
* Code is the machine readable code, for reference purpose. All the codes should be a constant string. For example. `const ENotFound = "source not found"`.
|
||||
|
||||
* Msg is the human readable message for end user. For example, `Your credit card is declined.`
|
||||
|
||||
* Op is the logical Operator, should be a constant defined inside the function. For example: "bolt.UserCreate".
|
||||
|
||||
* Err is the embed error. You may embed either a third party error or and platform.Error.
|
||||
|
||||
## Use Case Example
|
||||
|
||||
We implement the following interface
|
||||
|
||||
```go
|
||||
|
||||
type OrganizationService interface {
|
||||
FindOrganizationByID(ctx context.Context, id ID) (*Organization, error)
|
||||
}
|
||||
|
||||
func (c *Client)FindOrganizationByID(ctx context.Context, id platform.ID) (*platform.Organization, error) {
|
||||
var o *platform.Organization
|
||||
const op = "bolt.FindOrganizationByID"
|
||||
err := c.db.View(func(tx *bolt.Tx) error {
|
||||
org, err := c.findOrganizationByID(ctx, tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o = org
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, &platform.Error{
|
||||
Code: platform.ENotFound,
|
||||
Op: op,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
```
|
||||
|
||||
To check the error code
|
||||
|
||||
```go
|
||||
|
||||
if platform.ErrorCode(err) == platform.ENotFound {
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
To serialize the error
|
||||
|
||||
```go
|
||||
|
||||
b, err := json.Marshal(err)
|
||||
```
|
||||
|
||||
To deserialize the error
|
||||
|
||||
```go
|
||||
|
||||
e := new(platform.Error)
|
||||
err := json.Unmarshal(b, e)
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,300 @@
|
|||
package errors_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
errors2 "github.com/influxdata/influxdb/kit/platform/errors"
|
||||
)
|
||||
|
||||
const EFailedToGetStorageHost = "failed to get the storage host"
|
||||
|
||||
func TestErrorMsg(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
msg string
|
||||
}{
|
||||
{
|
||||
name: "simple error",
|
||||
err: &errors2.Error{Code: errors2.ENotFound},
|
||||
msg: "<not found>",
|
||||
},
|
||||
{
|
||||
name: "with message",
|
||||
err: &errors2.Error{
|
||||
Code: errors2.ENotFound,
|
||||
Msg: "bucket not found",
|
||||
},
|
||||
msg: "bucket not found",
|
||||
},
|
||||
{
|
||||
name: "with a third party error and no message",
|
||||
err: &errors2.Error{
|
||||
Code: EFailedToGetStorageHost,
|
||||
Err: errors.New("empty value"),
|
||||
},
|
||||
msg: "empty value",
|
||||
},
|
||||
{
|
||||
name: "with a third party error and a message",
|
||||
err: &errors2.Error{
|
||||
Code: EFailedToGetStorageHost,
|
||||
Msg: "failed to get storage hosts",
|
||||
Err: errors.New("empty value"),
|
||||
},
|
||||
msg: "failed to get storage hosts: empty value",
|
||||
},
|
||||
{
|
||||
name: "with an internal error and no message",
|
||||
err: &errors2.Error{
|
||||
Code: EFailedToGetStorageHost,
|
||||
Err: &errors2.Error{
|
||||
Code: errors2.EEmptyValue,
|
||||
Msg: "empty value",
|
||||
},
|
||||
},
|
||||
msg: "empty value",
|
||||
},
|
||||
{
|
||||
name: "with an internal error and a message",
|
||||
err: &errors2.Error{
|
||||
Code: EFailedToGetStorageHost,
|
||||
Msg: "failed to get storage hosts",
|
||||
Err: &errors2.Error{
|
||||
Code: errors2.EEmptyValue,
|
||||
Msg: "empty value",
|
||||
},
|
||||
},
|
||||
msg: "failed to get storage hosts: empty value",
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if c.msg != c.err.Error() {
|
||||
t.Errorf("%s failed, want %s, got %s", c.name, c.msg, c.err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorMessage(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
},
|
||||
{
|
||||
name: "nil error of type *platform.Error",
|
||||
err: (*errors2.Error)(nil),
|
||||
},
|
||||
{
|
||||
name: "simple error",
|
||||
err: &errors2.Error{Msg: "simple error"},
|
||||
want: "simple error",
|
||||
},
|
||||
{
|
||||
name: "embedded error",
|
||||
err: &errors2.Error{Err: &errors2.Error{Msg: "embedded error"}},
|
||||
want: "embedded error",
|
||||
},
|
||||
{
|
||||
name: "default error",
|
||||
err: errors.New("s"),
|
||||
want: "An internal error has occurred.",
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if result := errors2.ErrorMessage(c.err); c.want != result {
|
||||
t.Errorf("%s failed, want %s, got %s", c.name, c.want, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorOp(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
},
|
||||
{
|
||||
name: "nil error of type *platform.Error",
|
||||
err: (*errors2.Error)(nil),
|
||||
},
|
||||
{
|
||||
name: "simple error",
|
||||
err: &errors2.Error{Op: "op1"},
|
||||
want: "op1",
|
||||
},
|
||||
{
|
||||
name: "embedded error",
|
||||
err: &errors2.Error{Op: "op1", Err: &errors2.Error{Code: errors2.EInvalid}},
|
||||
want: "op1",
|
||||
},
|
||||
{
|
||||
name: "embedded error without op in root level",
|
||||
err: &errors2.Error{Err: &errors2.Error{Code: errors2.EInvalid, Op: "op2"}},
|
||||
want: "op2",
|
||||
},
|
||||
{
|
||||
name: "default error",
|
||||
err: errors.New("s"),
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if result := errors2.ErrorOp(c.err); c.want != result {
|
||||
t.Errorf("%s failed, want %s, got %s", c.name, c.want, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
func TestErrorCode(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
},
|
||||
{
|
||||
name: "nil error of type *platform.Error",
|
||||
err: (*errors2.Error)(nil),
|
||||
},
|
||||
{
|
||||
name: "simple error",
|
||||
err: &errors2.Error{Code: errors2.ENotFound},
|
||||
want: errors2.ENotFound,
|
||||
},
|
||||
{
|
||||
name: "embedded error",
|
||||
err: &errors2.Error{Code: errors2.ENotFound, Err: &errors2.Error{Code: errors2.EInvalid}},
|
||||
want: errors2.ENotFound,
|
||||
},
|
||||
{
|
||||
name: "embedded error with root level code",
|
||||
err: &errors2.Error{Err: &errors2.Error{Code: errors2.EInvalid}},
|
||||
want: errors2.EInvalid,
|
||||
},
|
||||
{
|
||||
name: "default error",
|
||||
err: errors.New("s"),
|
||||
want: errors2.EInternal,
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if result := errors2.ErrorCode(c.err); c.want != result {
|
||||
t.Errorf("%s failed, want %s, got %s", c.name, c.want, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSON(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
err *errors2.Error
|
||||
encoded string
|
||||
}{
|
||||
{
|
||||
name: "simple error",
|
||||
err: &errors2.Error{Code: errors2.ENotFound},
|
||||
encoded: `{"code":"not found"}`,
|
||||
},
|
||||
{
|
||||
name: "with op",
|
||||
err: &errors2.Error{
|
||||
Code: errors2.ENotFound,
|
||||
Op: "bolt.FindAuthorizationByID",
|
||||
},
|
||||
encoded: `{"code":"not found","op":"bolt.FindAuthorizationByID"}`,
|
||||
},
|
||||
{
|
||||
name: "with op and value",
|
||||
err: &errors2.Error{
|
||||
Code: errors2.ENotFound,
|
||||
Op: "bolt/FindAuthorizationByID",
|
||||
Msg: fmt.Sprintf("with ID %d", 323),
|
||||
},
|
||||
encoded: `{"code":"not found","message":"with ID 323","op":"bolt/FindAuthorizationByID"}`,
|
||||
},
|
||||
{
|
||||
name: "with a third party error",
|
||||
err: &errors2.Error{
|
||||
Code: EFailedToGetStorageHost,
|
||||
Op: "cmd/fluxd.injectDeps",
|
||||
Err: errors.New("empty value"),
|
||||
},
|
||||
encoded: `{"code":"failed to get the storage host","op":"cmd/fluxd.injectDeps","error":"empty value"}`,
|
||||
},
|
||||
{
|
||||
name: "with a internal error",
|
||||
err: &errors2.Error{
|
||||
Code: EFailedToGetStorageHost,
|
||||
Op: "cmd/fluxd.injectDeps",
|
||||
Err: &errors2.Error{Code: errors2.EEmptyValue, Op: "cmd/fluxd.getStrList"},
|
||||
},
|
||||
encoded: `{"code":"failed to get the storage host","op":"cmd/fluxd.injectDeps","error":{"code":"empty value","op":"cmd/fluxd.getStrList"}}`,
|
||||
},
|
||||
{
|
||||
name: "with a deep internal error",
|
||||
err: &errors2.Error{
|
||||
Code: EFailedToGetStorageHost,
|
||||
Op: "cmd/fluxd.injectDeps",
|
||||
Err: &errors2.Error{
|
||||
Code: errors2.EInvalid,
|
||||
Op: "cmd/fluxd.getStrList",
|
||||
Err: &errors2.Error{
|
||||
Code: errors2.EEmptyValue,
|
||||
Err: errors.New("an err"),
|
||||
},
|
||||
},
|
||||
},
|
||||
encoded: `{"code":"failed to get the storage host","op":"cmd/fluxd.injectDeps","error":{"code":"invalid","op":"cmd/fluxd.getStrList","error":{"code":"empty value","error":"an err"}}}`,
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
result, err := json.Marshal(c.err)
|
||||
// encode testing
|
||||
if err != nil {
|
||||
t.Errorf("%s encode failed, want err: %v, should be nil", c.name, err)
|
||||
}
|
||||
|
||||
if string(result) != c.encoded {
|
||||
t.Errorf("%s encode failed, want result: %s, got %s", c.name, c.encoded, string(result))
|
||||
}
|
||||
// decode testing
|
||||
got := new(errors2.Error)
|
||||
err = json.Unmarshal(result, got)
|
||||
if err != nil {
|
||||
t.Errorf("%s decode failed, want err: %v, should be nil", c.name, err)
|
||||
}
|
||||
decodeEqual(t, c.err, got, "decode: "+c.name)
|
||||
}
|
||||
}
|
||||
|
||||
func decodeEqual(t *testing.T, want, result *errors2.Error, caseName string) {
|
||||
if want.Code != result.Code {
|
||||
t.Errorf("%s code failed, want %s, got %s", caseName, want.Code, result.Code)
|
||||
}
|
||||
if want.Op != result.Op {
|
||||
t.Errorf("%s op failed, want %s, got %s", caseName, want.Op, result.Op)
|
||||
}
|
||||
if want.Msg != result.Msg {
|
||||
t.Errorf("%s msg failed, want %s, got %s", caseName, want.Msg, result.Msg)
|
||||
}
|
||||
if want.Err != nil {
|
||||
if _, ok := want.Err.(*errors2.Error); ok {
|
||||
decodeEqual(t, want.Err.(*errors2.Error), result.Err.(*errors2.Error), caseName)
|
||||
} else {
|
||||
if want.Err.Error() != result.Err.Error() {
|
||||
t.Errorf("%s Err failed, want %s, got %s", caseName, want.Err.Error(), result.Err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,145 @@
|
|||
package platform
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"strconv"
|
||||
"unsafe"
|
||||
|
||||
"github.com/influxdata/influxdb/kit/platform/errors"
|
||||
)
|
||||
|
||||
// IDLength is the exact length a string (or a byte slice representing it) must have in order to be decoded into a valid ID.
|
||||
const IDLength = 16
|
||||
|
||||
var (
|
||||
// ErrInvalidID signifies invalid IDs.
|
||||
ErrInvalidID = &errors.Error{
|
||||
Code: errors.EInvalid,
|
||||
Msg: "invalid ID",
|
||||
}
|
||||
|
||||
// ErrInvalidIDLength is returned when an ID has the incorrect number of bytes.
|
||||
ErrInvalidIDLength = &errors.Error{
|
||||
Code: errors.EInvalid,
|
||||
Msg: "id must have a length of 16 bytes",
|
||||
}
|
||||
)
|
||||
|
||||
// ErrCorruptID means the ID stored in the Store is corrupt.
|
||||
func ErrCorruptID(err error) *errors.Error {
|
||||
return &errors.Error{
|
||||
Code: errors.EInvalid,
|
||||
Msg: "corrupt ID provided",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// ID is a unique identifier.
|
||||
//
|
||||
// Its zero value is not a valid ID.
|
||||
type ID uint64
|
||||
|
||||
// IDGenerator represents a generator for IDs.
|
||||
type IDGenerator interface {
|
||||
// ID creates unique byte slice ID.
|
||||
ID() ID
|
||||
}
|
||||
|
||||
// IDFromString creates an ID from a given string.
|
||||
//
|
||||
// It errors if the input string does not match a valid ID.
|
||||
func IDFromString(str string) (*ID, error) {
|
||||
var id ID
|
||||
err := id.DecodeFromString(str)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &id, nil
|
||||
}
|
||||
|
||||
// InvalidID returns a zero ID.
|
||||
func InvalidID() ID {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Decode parses b as a hex-encoded byte-slice-string.
|
||||
//
|
||||
// It errors if the input byte slice does not have the correct length
|
||||
// or if it contains all zeros.
|
||||
func (i *ID) Decode(b []byte) error {
|
||||
if len(b) != IDLength {
|
||||
return ErrInvalidIDLength
|
||||
}
|
||||
|
||||
res, err := strconv.ParseUint(unsafeBytesToString(b), 16, 64)
|
||||
if err != nil {
|
||||
return ErrInvalidID
|
||||
}
|
||||
|
||||
if *i = ID(res); !i.Valid() {
|
||||
return ErrInvalidID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func unsafeBytesToString(in []byte) string {
|
||||
return *(*string)(unsafe.Pointer(&in))
|
||||
}
|
||||
|
||||
// DecodeFromString parses s as a hex-encoded string.
|
||||
func (i *ID) DecodeFromString(s string) error {
|
||||
return i.Decode([]byte(s))
|
||||
}
|
||||
|
||||
// Encode converts ID to a hex-encoded byte-slice-string.
|
||||
//
|
||||
// It errors if the receiving ID holds its zero value.
|
||||
func (i ID) Encode() ([]byte, error) {
|
||||
if !i.Valid() {
|
||||
return nil, ErrInvalidID
|
||||
}
|
||||
|
||||
b := make([]byte, hex.DecodedLen(IDLength))
|
||||
binary.BigEndian.PutUint64(b, uint64(i))
|
||||
|
||||
dst := make([]byte, hex.EncodedLen(len(b)))
|
||||
hex.Encode(dst, b)
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
// Valid checks whether the receiving ID is a valid one or not.
|
||||
func (i ID) Valid() bool {
|
||||
return i != 0
|
||||
}
|
||||
|
||||
// String returns the ID as a hex encoded string.
|
||||
//
|
||||
// Returns an empty string in the case the ID is invalid.
|
||||
func (i ID) String() string {
|
||||
enc, _ := i.Encode()
|
||||
return string(enc)
|
||||
}
|
||||
|
||||
// GoString formats the ID the same as the String method.
|
||||
// Without this, when using the %#v verb, an ID would be printed as a uint64,
|
||||
// so you would see e.g. 0x2def021097c6000 instead of 02def021097c6000
|
||||
// (note the leading 0x, which means the former doesn't show up in searches for the latter).
|
||||
func (i ID) GoString() string {
|
||||
return `"` + i.String() + `"`
|
||||
}
|
||||
|
||||
// MarshalText encodes i as text.
|
||||
// Providing this method is a fallback for json.Marshal,
|
||||
// with the added benefit that IDs encoded as map keys will be the expected string encoding,
|
||||
// rather than the effective fmt.Sprintf("%d", i) that json.Marshal uses by default for integer types.
|
||||
func (i ID) MarshalText() ([]byte, error) {
|
||||
return i.Encode()
|
||||
}
|
||||
|
||||
// UnmarshalText decodes i from a byte slice.
|
||||
// Providing this method is also a fallback for json.Unmarshal,
|
||||
// also relevant when IDs are used as map keys.
|
||||
func (i *ID) UnmarshalText(b []byte) error {
|
||||
return i.Decode(b)
|
||||
}
|
|
@ -0,0 +1,259 @@
|
|||
package platform_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/influxdata/influxdb/kit/platform"
|
||||
)
|
||||
|
||||
func MustIDBase16(s string) platform.ID {
|
||||
id, err := platform.IDFromString(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return *id
|
||||
}
|
||||
|
||||
func TestIDFromString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id string
|
||||
want platform.ID
|
||||
wantErr bool
|
||||
err string
|
||||
}{
|
||||
{
|
||||
name: "Should be able to decode an all zeros ID",
|
||||
id: "0000000000000000",
|
||||
wantErr: true,
|
||||
err: platform.ErrInvalidID.Error(),
|
||||
},
|
||||
{
|
||||
name: "Should be able to decode an all f ID",
|
||||
id: "ffffffffffffffff",
|
||||
want: MustIDBase16("ffffffffffffffff"),
|
||||
},
|
||||
{
|
||||
name: "Should be able to decode an ID",
|
||||
id: "020f755c3c082000",
|
||||
want: MustIDBase16("020f755c3c082000"),
|
||||
},
|
||||
{
|
||||
name: "Should not be able to decode a non hex ID",
|
||||
id: "gggggggggggggggg",
|
||||
wantErr: true,
|
||||
err: platform.ErrInvalidID.Error(),
|
||||
},
|
||||
{
|
||||
name: "Should not be able to decode inputs with length less than 16 bytes",
|
||||
id: "abc",
|
||||
wantErr: true,
|
||||
err: platform.ErrInvalidIDLength.Error(),
|
||||
},
|
||||
{
|
||||
name: "Should not be able to decode inputs with length greater than 16 bytes",
|
||||
id: "abcdabcdabcdabcd0",
|
||||
wantErr: true,
|
||||
err: platform.ErrInvalidIDLength.Error(),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := platform.IDFromString(tt.id)
|
||||
|
||||
// Check negative test cases
|
||||
if (err != nil) && tt.wantErr {
|
||||
if tt.err != err.Error() {
|
||||
t.Errorf("IDFromString() errors out \"%s\", want \"%s\"", err, tt.err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Check positive test cases
|
||||
if !reflect.DeepEqual(*got, tt.want) && !tt.wantErr {
|
||||
t.Errorf("IDFromString() outputs %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeFromString(t *testing.T) {
|
||||
var id platform.ID
|
||||
err := id.DecodeFromString("020f755c3c082000")
|
||||
if err != nil {
|
||||
t.Errorf(err.Error())
|
||||
}
|
||||
want := []byte{48, 50, 48, 102, 55, 53, 53, 99, 51, 99, 48, 56, 50, 48, 48, 48}
|
||||
got, _ := id.Encode()
|
||||
if !bytes.Equal(want, got) {
|
||||
t.Errorf("got %s not equal to wanted %s", string(got), string(want))
|
||||
}
|
||||
if id.String() != "020f755c3c082000" {
|
||||
t.Errorf("expecting string representation to contain the right value")
|
||||
}
|
||||
if !id.Valid() {
|
||||
t.Errorf("expecting ID to be a valid one")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncode(t *testing.T) {
|
||||
var id platform.ID
|
||||
if _, err := id.Encode(); err == nil {
|
||||
t.Errorf("encoding an invalid ID should not be possible")
|
||||
}
|
||||
|
||||
id.DecodeFromString("5ca1ab1eba5eba11")
|
||||
want := []byte{53, 99, 97, 49, 97, 98, 49, 101, 98, 97, 53, 101, 98, 97, 49, 49}
|
||||
got, _ := id.Encode()
|
||||
if !bytes.Equal(want, got) {
|
||||
t.Errorf("encoding error")
|
||||
}
|
||||
if id.String() != "5ca1ab1eba5eba11" {
|
||||
t.Errorf("expecting string representation to contain the right value")
|
||||
}
|
||||
if !id.Valid() {
|
||||
t.Errorf("expecting ID to be a valid one")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeFromAllZeros(t *testing.T) {
|
||||
var id platform.ID
|
||||
err := id.Decode(make([]byte, platform.IDLength))
|
||||
if err == nil {
|
||||
t.Errorf("expecting all zeros ID to not be a valid ID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeFromShorterString(t *testing.T) {
|
||||
var id platform.ID
|
||||
err := id.DecodeFromString("020f75")
|
||||
if err == nil {
|
||||
t.Errorf("expecting shorter inputs to error")
|
||||
}
|
||||
if id.String() != "" {
|
||||
t.Errorf("expecting invalid ID to be serialized into empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeFromLongerString(t *testing.T) {
|
||||
var id platform.ID
|
||||
err := id.DecodeFromString("020f755c3c082000aaa")
|
||||
if err == nil {
|
||||
t.Errorf("expecting shorter inputs to error")
|
||||
}
|
||||
if id.String() != "" {
|
||||
t.Errorf("expecting invalid ID to be serialized into empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeFromEmptyString(t *testing.T) {
|
||||
var id platform.ID
|
||||
err := id.DecodeFromString("")
|
||||
if err == nil {
|
||||
t.Errorf("expecting empty inputs to error")
|
||||
}
|
||||
if id.String() != "" {
|
||||
t.Errorf("expecting invalid ID to be serialized into empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalling(t *testing.T) {
|
||||
var id0 platform.ID
|
||||
_, err := json.Marshal(id0)
|
||||
if err == nil {
|
||||
t.Errorf("expecting empty ID to not be a valid one")
|
||||
}
|
||||
|
||||
init := "ca55e77eca55e77e"
|
||||
id1, err := platform.IDFromString(init)
|
||||
if err != nil {
|
||||
t.Errorf(err.Error())
|
||||
}
|
||||
|
||||
serialized, err := json.Marshal(id1)
|
||||
if err != nil {
|
||||
t.Errorf(err.Error())
|
||||
}
|
||||
|
||||
var id2 platform.ID
|
||||
json.Unmarshal(serialized, &id2)
|
||||
|
||||
bytes1, _ := id1.Encode()
|
||||
bytes2, _ := id2.Encode()
|
||||
|
||||
if !bytes.Equal(bytes1, bytes2) {
|
||||
t.Errorf("error marshalling/unmarshalling ID")
|
||||
}
|
||||
|
||||
// When used as a map key, IDs must use their string encoding.
|
||||
// If you only implement json.Marshaller, they will be encoded with Go's default integer encoding.
|
||||
b, err := json.Marshal(map[platform.ID]int{0x1234: 5678})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
const exp = `{"0000000000001234":5678}`
|
||||
if string(b) != exp {
|
||||
t.Errorf("expected map to json.Marshal as %s; got %s", exp, string(b))
|
||||
}
|
||||
|
||||
var idMap map[platform.ID]int
|
||||
if err := json.Unmarshal(b, &idMap); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if len(idMap) != 1 {
|
||||
t.Errorf("expected length 1, got %d", len(idMap))
|
||||
}
|
||||
if idMap[0x1234] != 5678 {
|
||||
t.Errorf("unmarshalled incorrectly; exp 0x1234:5678, got %v", idMap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValid(t *testing.T) {
|
||||
var id platform.ID
|
||||
if id.Valid() {
|
||||
t.Errorf("expecting initial ID to be invalid")
|
||||
}
|
||||
|
||||
if platform.InvalidID() != 0 {
|
||||
t.Errorf("expecting invalid ID to return a zero ID, thus invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestID_GoString(t *testing.T) {
|
||||
type idGoStringTester struct {
|
||||
ID platform.ID
|
||||
}
|
||||
var x idGoStringTester
|
||||
|
||||
const idString = "02def021097c6000"
|
||||
if err := x.ID.DecodeFromString(idString); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sharpV := fmt.Sprintf("%#v", x)
|
||||
want := `platform_test.idGoStringTester{ID:"` + idString + `"}`
|
||||
if sharpV != want {
|
||||
t.Fatalf("bad GoString: got %q, want %q", sharpV, want)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIDEncode(b *testing.B) {
|
||||
var id platform.ID
|
||||
id.DecodeFromString("5ca1ab1eba5eba11")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b, _ := id.Encode()
|
||||
_ = b
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIDDecode(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var id platform.ID
|
||||
id.DecodeFromString("5ca1ab1eba5eba11")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
package prom_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/influxdb/kit/prom"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RandomHandler implements an HTTP endpoint that prints a random float,
|
||||
// and it tracks prometheus metrics about the numbers it returns.
|
||||
type RandomHandler struct {
|
||||
// Cumulative sum of values served.
|
||||
valueCounter prometheus.Counter
|
||||
|
||||
// Total times page served.
|
||||
serveCounter prometheus.Counter
|
||||
}
|
||||
|
||||
var (
|
||||
_ http.Handler = (*RandomHandler)(nil)
|
||||
_ prom.PrometheusCollector = (*RandomHandler)(nil)
|
||||
)
|
||||
|
||||
func NewRandomHandler() *RandomHandler {
|
||||
return &RandomHandler{
|
||||
valueCounter: prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "value_counter",
|
||||
Help: "Cumulative sum of values served.",
|
||||
}),
|
||||
serveCounter: prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "serve_counter",
|
||||
Help: "Counter of times page has been served.",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP serves a random float value and updates rh's internal metrics.
|
||||
func (rh *RandomHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Increment serveCounter every time we serve a page.
|
||||
rh.serveCounter.Inc()
|
||||
|
||||
n := rand.Float64()
|
||||
// Track the cumulative values served.
|
||||
rh.valueCounter.Add(n)
|
||||
|
||||
fmt.Fprintf(w, "%v", n)
|
||||
}
|
||||
|
||||
// PrometheusCollectors implements prom.PrometheusCollector.
|
||||
func (rh *RandomHandler) PrometheusCollectors() []prometheus.Collector {
|
||||
return []prometheus.Collector{rh.valueCounter, rh.serveCounter}
|
||||
}
|
||||
|
||||
func Example() {
|
||||
// A collection of endpoints and http.Handlers.
|
||||
handlers := map[string]http.Handler{
|
||||
"/random": NewRandomHandler(),
|
||||
"/time": http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, time.Now().String())
|
||||
}),
|
||||
}
|
||||
|
||||
// Use a local registry, not the global registry in the prometheus package.
|
||||
reg := prom.NewRegistry(zap.NewNop())
|
||||
|
||||
// Build the mux out of handlers from above.
|
||||
mux := http.NewServeMux()
|
||||
for path, h := range handlers {
|
||||
mux.Handle(path, h)
|
||||
|
||||
// Only register those handlers which implement prom.PrometheusCollector.
|
||||
if pc, ok := h.(prom.PrometheusCollector); ok {
|
||||
reg.MustRegister(pc.PrometheusCollectors()...)
|
||||
}
|
||||
}
|
||||
|
||||
// Add metrics to registry.
|
||||
mux.Handle("/metrics", reg.HTTPHandler())
|
||||
|
||||
http.ListenAndServe("localhost:8080", mux)
|
||||
}
|
|
@ -0,0 +1,146 @@
|
|||
// Package promtest provides helpers for parsing and extracting prometheus metrics.
|
||||
// These functions are only intended to be called from test files,
|
||||
// as there is a dependency on the standard library testing package.
|
||||
package promtest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
"github.com/prometheus/common/expfmt"
|
||||
)
|
||||
|
||||
// FromHTTPResponse parses the prometheus metrics from the given *http.Response.
|
||||
// It relies on properly set response headers to correctly parse.
|
||||
// It will unconditionally close the response body.
|
||||
//
|
||||
// This is particularly helpful when testing the output of the /metrics endpoint of a service.
|
||||
// However, for comprehensive testing of metrics, it usually makes more sense to
|
||||
// add collectors to a registry and call Registry.Gather to get the metrics without involving HTTP.
|
||||
func FromHTTPResponse(r *http.Response) ([]*dto.MetricFamily, error) {
|
||||
defer r.Body.Close()
|
||||
|
||||
dec := expfmt.NewDecoder(r.Body, expfmt.ResponseFormat(r.Header))
|
||||
var mfs []*dto.MetricFamily
|
||||
for {
|
||||
mf := new(dto.MetricFamily)
|
||||
if err := dec.Decode(mf); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
mfs = append(mfs, mf)
|
||||
}
|
||||
|
||||
return mfs, nil
|
||||
}
|
||||
|
||||
// FindMetric iterates through mfs to find the first metric family matching name.
|
||||
// If a metric family matches, then the metrics inside the family are searched,
|
||||
// and the first metric whose labels match the given labels are returned.
|
||||
// If no matches are found, FindMetric returns nil.
|
||||
//
|
||||
// FindMetric assumes that the labels on the metric family are well formed,
|
||||
// i.e. there are no duplicate label names, and the label values are not empty strings.
|
||||
func FindMetric(mfs []*dto.MetricFamily, name string, labels map[string]string) *dto.Metric {
|
||||
_, m := findMetric(mfs, name, labels)
|
||||
return m
|
||||
}
|
||||
|
||||
// MustFindMetric returns the matching metric, or if no matching metric could be found,
|
||||
// it calls tb.Log with helpful output of what was actually available, before calling tb.FailNow.
|
||||
func MustFindMetric(tb testing.TB, mfs []*dto.MetricFamily, name string, labels map[string]string) *dto.Metric {
|
||||
tb.Helper()
|
||||
|
||||
fam, m := findMetric(mfs, name, labels)
|
||||
if fam == nil {
|
||||
tb.Logf("metric family with name %q not found", name)
|
||||
tb.Log("available names:")
|
||||
for _, mf := range mfs {
|
||||
tb.Logf("\t%s", mf.GetName())
|
||||
}
|
||||
tb.FailNow()
|
||||
return nil // Need an explicit return here for test.
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
tb.Logf("found metric family with name %q, but metric with labels %v not found", name, labels)
|
||||
tb.Logf("available labels on metric family %q:", name)
|
||||
|
||||
for _, m := range fam.Metric {
|
||||
pairs := make([]string, len(m.Label))
|
||||
for i, l := range m.Label {
|
||||
pairs[i] = fmt.Sprintf("%q: %q", l.GetName(), l.GetValue())
|
||||
}
|
||||
tb.Logf("\t%s", strings.Join(pairs, ", "))
|
||||
}
|
||||
tb.FailNow()
|
||||
return nil // Need an explicit return here for test.
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// findMetric is a helper that returns the matching family and the matching metric.
|
||||
// The exported FindMetric function specifically only finds the metric, not the family,
|
||||
// but for test it is more helpful to identify whether the family was matched.
|
||||
func findMetric(mfs []*dto.MetricFamily, name string, labels map[string]string) (*dto.MetricFamily, *dto.Metric) {
|
||||
var fam *dto.MetricFamily
|
||||
|
||||
for _, mf := range mfs {
|
||||
if mf.GetName() == name {
|
||||
fam = mf
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if fam == nil {
|
||||
// No family matching the name.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, m := range fam.Metric {
|
||||
if len(m.Label) != len(labels) {
|
||||
continue
|
||||
}
|
||||
|
||||
match := true
|
||||
for _, l := range m.Label {
|
||||
if labels[l.GetName()] != l.GetValue() {
|
||||
match = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !match {
|
||||
continue
|
||||
}
|
||||
|
||||
// All labels matched.
|
||||
return fam, m
|
||||
}
|
||||
|
||||
// Didn't find a metric whose labels all matched.
|
||||
return fam, nil
|
||||
}
|
||||
|
||||
// MustGather calls g.Gather and calls tb.Fatal if there was an error.
|
||||
func MustGather(tb testing.TB, g prometheus.Gatherer) []*dto.MetricFamily {
|
||||
tb.Helper()
|
||||
|
||||
mfs, err := g.Gather()
|
||||
if err != nil {
|
||||
tb.Fatalf("error while gathering metrics: %v", err)
|
||||
return nil // Need an explicit return here for test.
|
||||
}
|
||||
|
||||
return mfs
|
||||
}
|
|
@ -0,0 +1,210 @@
|
|||
package promtest_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/influxdata/influxdb/kit/prom"
|
||||
"github.com/influxdata/influxdb/kit/prom/promtest"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
func helperCollectors() []prometheus.Collector {
|
||||
myCounter := prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: "my",
|
||||
Subsystem: "random",
|
||||
Name: "counter",
|
||||
Help: "Just a random counter.",
|
||||
})
|
||||
myGaugeVec := prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: "my",
|
||||
Subsystem: "random",
|
||||
Name: "gaugevec",
|
||||
Help: "Just a random gauge vector.",
|
||||
}, []string{"label1", "label2"})
|
||||
|
||||
myCounter.Inc()
|
||||
myGaugeVec.WithLabelValues("one", "two").Set(3)
|
||||
|
||||
return []prometheus.Collector{myCounter, myGaugeVec}
|
||||
}
|
||||
|
||||
func TestFindMetric(t *testing.T) {
|
||||
reg := prom.NewRegistry(zaptest.NewLogger(t))
|
||||
reg.MustRegister(helperCollectors()...)
|
||||
|
||||
mfs, err := reg.Gather()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c := promtest.MustFindMetric(t, mfs, "my_random_counter", nil)
|
||||
if got := c.GetCounter().GetValue(); got != 1 {
|
||||
t.Fatalf("expected counter to be 1, got %v", got)
|
||||
}
|
||||
|
||||
g := promtest.MustFindMetric(t, mfs, "my_random_gaugevec", map[string]string{"label1": "one", "label2": "two"})
|
||||
if got := g.GetGauge().GetValue(); got != 3 {
|
||||
t.Fatalf("expected gauge to be 3, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// fakeT helps us to assert that MustFindMetric calls FailNow when the metric isn't found.
|
||||
type fakeT struct {
|
||||
// Embed a T so we don't have to reimplement everything.
|
||||
// It's fine to leave T nil - fakeT will panic if calling a method we haven't implemented.
|
||||
*testing.T
|
||||
|
||||
logBuf bytes.Buffer
|
||||
|
||||
failed bool
|
||||
}
|
||||
|
||||
func (t *fakeT) Helper() {}
|
||||
|
||||
func (t *fakeT) Log(args ...interface{}) {
|
||||
fmt.Fprint(&t.logBuf, args...)
|
||||
t.logBuf.WriteString("\n")
|
||||
}
|
||||
|
||||
func (t *fakeT) Logf(format string, args ...interface{}) {
|
||||
fmt.Fprintf(&t.logBuf, format, args...)
|
||||
t.logBuf.WriteString("\n")
|
||||
}
|
||||
|
||||
func (t *fakeT) FailNow() {
|
||||
t.failed = true
|
||||
}
|
||||
|
||||
func (t *fakeT) Fatalf(format string, args ...interface{}) {
|
||||
t.Logf(format, args...)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
func TestMustFindMetric(t *testing.T) {
|
||||
reg := prom.NewRegistry(zaptest.NewLogger(t))
|
||||
reg.MustRegister(helperCollectors()...)
|
||||
|
||||
mfs, err := reg.Gather()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ft := new(fakeT)
|
||||
|
||||
// Doesn't log when metric is found.
|
||||
_ = promtest.MustFindMetric(ft, mfs, "my_random_counter", nil)
|
||||
if ft.failed {
|
||||
t.Fatalf("MustFindMetric failed when it should not have. message: %s", ft.logBuf.String())
|
||||
}
|
||||
|
||||
// Logs and fails when family name not found.
|
||||
ft = new(fakeT)
|
||||
_ = promtest.MustFindMetric(ft, mfs, "missing_name", nil)
|
||||
if !ft.failed {
|
||||
t.Fatal("MustFindMetric should have failed but didn't")
|
||||
}
|
||||
logged := ft.logBuf.String()
|
||||
if !strings.Contains(logged, `name "missing_name" not found`) {
|
||||
t.Fatalf("did not log the looked up name which was not found. message: %s", logged)
|
||||
}
|
||||
if !strings.Contains(logged, "my_random_counter") || !strings.Contains(logged, "my_random_gaugevec") {
|
||||
t.Fatalf("did not log the available metric names. message: %s", logged)
|
||||
}
|
||||
|
||||
// Logs and fails when family name found but metric labels mismatch.
|
||||
ft = new(fakeT)
|
||||
_ = promtest.MustFindMetric(ft, mfs, "my_random_counter", map[string]string{"unknown": "label"})
|
||||
if !ft.failed {
|
||||
t.Fatal("MustFindMetric should have failed but didn't")
|
||||
}
|
||||
|
||||
ft = new(fakeT)
|
||||
_ = promtest.MustFindMetric(ft, mfs, "my_random_gaugevec", map[string]string{"unknown": "label"})
|
||||
if !ft.failed {
|
||||
t.Fatal("MustFindMetric should have failed but didn't")
|
||||
}
|
||||
logged = ft.logBuf.String()
|
||||
if !strings.Contains(logged, `"label1": "one"`) || !strings.Contains(logged, `"label2": "two"`) {
|
||||
t.Fatalf("did not log the available label names. message: %s", logged)
|
||||
}
|
||||
|
||||
ft = new(fakeT)
|
||||
_ = promtest.MustFindMetric(ft, mfs, "my_random_gaugevec", map[string]string{"label1": "one", "label2": "two", "label3": "imaginary"})
|
||||
if !ft.failed {
|
||||
t.Fatal("MustFindMetric should have failed but didn't")
|
||||
}
|
||||
logged = ft.logBuf.String()
|
||||
if !strings.Contains(logged, `"label1": "one"`) || !strings.Contains(logged, `"label2": "two"`) {
|
||||
t.Fatalf("did not log the available label names. message: %s", logged)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMustGather(t *testing.T) {
|
||||
expErr := errors.New("failed to gather")
|
||||
g := prometheus.GathererFunc(func() ([]*dto.MetricFamily, error) {
|
||||
return nil, expErr
|
||||
})
|
||||
|
||||
ft := new(fakeT)
|
||||
_ = promtest.MustGather(ft, g)
|
||||
if !ft.failed {
|
||||
t.Fatal("MustGather should have failed but didn't")
|
||||
}
|
||||
logged := ft.logBuf.String()
|
||||
if !strings.HasPrefix(logged, "error while gathering metrics:") || !strings.Contains(logged, expErr.Error()) {
|
||||
t.Fatalf("did not log the expected error message: %s", logged)
|
||||
}
|
||||
|
||||
expMF := []*dto.MetricFamily{} // Use a non-nil, zero-length slice for a simple-ish check.
|
||||
g = prometheus.GathererFunc(func() ([]*dto.MetricFamily, error) {
|
||||
return expMF, nil
|
||||
})
|
||||
ft = new(fakeT)
|
||||
gotMF := promtest.MustGather(ft, g)
|
||||
if ft.failed {
|
||||
t.Fatalf("MustGather should not have failed")
|
||||
}
|
||||
if gotMF == nil || len(gotMF) != 0 {
|
||||
t.Fatalf("exp: %v, got: %v", expMF, gotMF)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromHTTPResponse(t *testing.T) {
|
||||
reg := prom.NewRegistry(zaptest.NewLogger(t))
|
||||
reg.MustRegister(helperCollectors()...)
|
||||
|
||||
s := httptest.NewServer(reg.HTTPHandler())
|
||||
defer s.Close()
|
||||
|
||||
resp, err := http.Get(s.URL) // Didn't specify a path for the handler, so any path should be fine.
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mfs, err := promtest.FromHTTPResponse(resp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(mfs) != 2 {
|
||||
t.Fatalf("expected 2 metrics but got %d", len(mfs))
|
||||
}
|
||||
|
||||
c := promtest.MustFindMetric(t, mfs, "my_random_counter", nil)
|
||||
if got := c.GetCounter().GetValue(); got != 1 {
|
||||
t.Fatalf("expected counter to be 1, got %v", got)
|
||||
}
|
||||
|
||||
g := promtest.MustFindMetric(t, mfs, "my_random_gaugevec", map[string]string{"label1": "one", "label2": "two"})
|
||||
if got := g.GetGauge().GetValue(); got != 3 {
|
||||
t.Fatalf("expected gauge to be 3, got %v", got)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
// Package prom provides a wrapper around a prometheus metrics registry
|
||||
// so that all services are unified in how they expose prometheus metrics.
|
||||
package prom
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// PrometheusCollector is the interface for a type to expose prometheus metrics.
|
||||
// This interface is provided as a convention, so that you can optionally check
|
||||
// if a type implements it and then pass its collectors to (*Registry).MustRegister.
|
||||
type PrometheusCollector interface {
|
||||
// PrometheusCollectors returns a slice of prometheus collectors
|
||||
// containing metrics for the underlying instance.
|
||||
PrometheusCollectors() []prometheus.Collector
|
||||
}
|
||||
|
||||
// Registry embeds a prometheus registry and adds a couple convenience methods.
|
||||
type Registry struct {
|
||||
*prometheus.Registry
|
||||
|
||||
log *zap.Logger
|
||||
}
|
||||
|
||||
// NewRegistry returns a new registry.
|
||||
func NewRegistry(log *zap.Logger) *Registry {
|
||||
return &Registry{
|
||||
Registry: prometheus.NewRegistry(),
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPHandler returns an http.Handler for the registry,
|
||||
// so that the /metrics HTTP handler is uniformly configured across all apps in the platform.
|
||||
func (r *Registry) HTTPHandler() http.Handler {
|
||||
opts := promhttp.HandlerOpts{
|
||||
ErrorLog: promLogger{r: r},
|
||||
// TODO(mr): decide if we want to set MaxRequestsInFlight or Timeout.
|
||||
}
|
||||
return promhttp.HandlerFor(r.Registry, opts)
|
||||
}
|
||||
|
||||
// promLogger satisfies the promhttp.logger interface with the registry.
|
||||
// Because normal usage is that WithLogger is called after HTTPHandler,
|
||||
// we refer to the Registry rather than its logger.
|
||||
type promLogger struct {
|
||||
r *Registry
|
||||
}
|
||||
|
||||
var _ promhttp.Logger = (*promLogger)(nil)
|
||||
|
||||
// Println implements promhttp.logger.
|
||||
func (pl promLogger) Println(v ...interface{}) {
|
||||
pl.r.log.Sugar().Info(v...)
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
package prom_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/influxdata/influxdb/kit/prom"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zaptest/observer"
|
||||
)
|
||||
|
||||
func TestRegistry_Logger(t *testing.T) {
|
||||
core, logs := observer.New(zap.DebugLevel)
|
||||
reg := prom.NewRegistry(zap.New(core))
|
||||
|
||||
// Normal use: HTTP handler is created immediately...
|
||||
s := httptest.NewServer(reg.HTTPHandler())
|
||||
defer s.Close()
|
||||
|
||||
// Force an error with a fake collector.
|
||||
reg.MustRegister(errorCollector{})
|
||||
resp, err := http.Get(s.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
foundLog := false
|
||||
for _, le := range logs.All() {
|
||||
if strings.Contains(le.Message, "invalid metric from errorCollector") {
|
||||
foundLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundLog {
|
||||
t.Fatalf("registry logger did not log error from metric collection")
|
||||
}
|
||||
}
|
||||
|
||||
type errorCollector struct{}
|
||||
|
||||
var _ prometheus.Collector = errorCollector{}
|
||||
|
||||
var ecDesc = prometheus.NewDesc("error_collector_desc", "A required description for the error collector", nil, nil)
|
||||
|
||||
func (errorCollector) Describe(ch chan<- *prometheus.Desc) {
|
||||
ch <- ecDesc
|
||||
}
|
||||
|
||||
func (errorCollector) Collect(ch chan<- prometheus.Metric) {
|
||||
ch <- prometheus.NewInvalidMetric(
|
||||
ecDesc,
|
||||
errors.New("invalid metric from errorCollector"),
|
||||
)
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package signals
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// WithSignals returns a context that is canceled with any signal in sigs.
|
||||
func WithSignals(ctx context.Context, sigs ...os.Signal) context.Context {
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, sigs...)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
defer cancel()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-sigCh:
|
||||
return
|
||||
}
|
||||
}()
|
||||
return ctx
|
||||
}
|
||||
|
||||
// WithStandardSignals cancels the context on os.Interrupt, syscall.SIGTERM.
|
||||
func WithStandardSignals(ctx context.Context) context.Context {
|
||||
return WithSignals(ctx, os.Interrupt, syscall.SIGTERM)
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
package signals
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func ExampleWithSignals() {
|
||||
ctx := WithSignals(context.Background(), syscall.SIGUSR1)
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond) // after some time SIGUSR1 is sent
|
||||
// mimicking a signal from the outside
|
||||
syscall.Kill(syscall.Getpid(), syscall.SIGUSR1)
|
||||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
fmt.Println("finished")
|
||||
// Output:
|
||||
// finished
|
||||
}
|
||||
|
||||
func Example_withUnregisteredSignals() {
|
||||
dctx, cancel := context.WithTimeout(context.TODO(), time.Millisecond*100)
|
||||
defer cancel()
|
||||
|
||||
ctx := WithSignals(dctx, syscall.SIGUSR1)
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond) // after some time SIGUSR2 is sent
|
||||
// mimicking a signal from the outside, WithSignals will not handle it
|
||||
syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
|
||||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
fmt.Println("finished")
|
||||
// Output:
|
||||
// finished
|
||||
}
|
||||
|
||||
func TestWithSignals(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
sigs []os.Signal
|
||||
wantSignal bool
|
||||
}{
|
||||
{
|
||||
name: "sending signal SIGUSR2 should exit context.",
|
||||
ctx: context.Background(),
|
||||
sigs: []os.Signal{syscall.SIGUSR2},
|
||||
wantSignal: true,
|
||||
},
|
||||
{
|
||||
name: "sending signal SIGUSR2 should NOT exit context.",
|
||||
ctx: context.Background(),
|
||||
sigs: []os.Signal{syscall.SIGUSR1},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := WithSignals(tt.ctx, tt.sigs...)
|
||||
syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
|
||||
timer := time.NewTimer(500 * time.Millisecond)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if !tt.wantSignal {
|
||||
t.Errorf("unexpected exit with signal")
|
||||
}
|
||||
case <-timer.C:
|
||||
if tt.wantSignal {
|
||||
t.Errorf("expected to exit with signal but did not")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package testing
|
||||
|
||||
import (
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/uber/jaeger-client-go"
|
||||
)
|
||||
|
||||
// SetupInMemoryTracing sets the global tracer to an in memory Jaeger instance for testing.
|
||||
// The returned function should be deferred by the caller to tear down this setup after testing is complete.
|
||||
func SetupInMemoryTracing(name string) func() {
|
||||
var (
|
||||
old = opentracing.GlobalTracer()
|
||||
tracer, closer = jaeger.NewTracer(name,
|
||||
jaeger.NewConstSampler(true),
|
||||
jaeger.NewInMemoryReporter(),
|
||||
)
|
||||
)
|
||||
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
return func() {
|
||||
_ = closer.Close()
|
||||
opentracing.SetGlobalTracer(old)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,225 @@
|
|||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/uber/jaeger-client-go"
|
||||
|
||||
"github.com/influxdata/httprouter"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/opentracing/opentracing-go/log"
|
||||
)
|
||||
|
||||
// LogError adds a span log for an error.
|
||||
// Returns unchanged error, so useful to wrap as in:
|
||||
// return 0, tracing.LogError(err)
|
||||
func LogError(span opentracing.Span, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get caller frame.
|
||||
var pcs [1]uintptr
|
||||
n := runtime.Callers(2, pcs[:])
|
||||
if n < 1 {
|
||||
span.LogFields(log.Error(err))
|
||||
span.LogFields(log.Error(errors.New("runtime.Callers failed")))
|
||||
return err
|
||||
}
|
||||
|
||||
file, line := runtime.FuncForPC(pcs[0]).FileLine(pcs[0])
|
||||
span.LogFields(log.String("filename", file), log.Int("line", line), log.Error(err))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// InjectToHTTPRequest adds tracing headers to an HTTP request.
|
||||
// Easier than adding this boilerplate everywhere.
|
||||
func InjectToHTTPRequest(span opentracing.Span, req *http.Request) {
|
||||
err := opentracing.GlobalTracer().Inject(span.Context(), opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header))
|
||||
if err != nil {
|
||||
LogError(span, err)
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractFromHTTPRequest gets a child span of the parent referenced in HTTP request headers.
|
||||
// Returns the request with updated tracing context.
|
||||
// Easier than adding this boilerplate everywhere.
|
||||
func ExtractFromHTTPRequest(req *http.Request, handlerName string) (opentracing.Span, *http.Request) {
|
||||
spanContext, err := opentracing.GlobalTracer().Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header))
|
||||
if err != nil {
|
||||
span, ctx := opentracing.StartSpanFromContext(req.Context(), "request")
|
||||
annotateSpan(span, handlerName, req)
|
||||
|
||||
_ = LogError(span, err)
|
||||
|
||||
return span, req.WithContext(ctx)
|
||||
}
|
||||
|
||||
span := opentracing.StartSpan("request", opentracing.ChildOf(spanContext), ext.RPCServerOption(spanContext))
|
||||
annotateSpan(span, handlerName, req)
|
||||
|
||||
return span, req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
|
||||
}
|
||||
|
||||
func annotateSpan(span opentracing.Span, handlerName string, req *http.Request) {
|
||||
if route := httprouter.MatchedRouteFromContext(req.Context()); route != "" {
|
||||
span.SetTag("route", route)
|
||||
}
|
||||
span.SetTag("method", req.Method)
|
||||
if ctx := chi.RouteContext(req.Context()); ctx != nil {
|
||||
span.SetTag("route", ctx.RoutePath)
|
||||
}
|
||||
span.SetTag("handler", handlerName)
|
||||
span.LogKV("path", req.URL.Path)
|
||||
}
|
||||
|
||||
// span is a simple wrapper around opentracing.Span in order to
|
||||
// get access to the duration of the span for metrics reporting.
|
||||
type Span struct {
|
||||
opentracing.Span
|
||||
start time.Time
|
||||
Duration time.Duration
|
||||
hist prometheus.Observer
|
||||
gauge prometheus.Gauge
|
||||
}
|
||||
|
||||
func StartSpanFromContextWithPromMetrics(ctx context.Context, operationName string, hist prometheus.Observer, gauge prometheus.Gauge, opts ...opentracing.StartSpanOption) (*Span, context.Context) {
|
||||
start := time.Now()
|
||||
s, sctx := StartSpanFromContextWithOperationName(ctx, operationName, opentracing.StartTime(start))
|
||||
gauge.Inc()
|
||||
return &Span{s, start, 0, hist, gauge}, sctx
|
||||
}
|
||||
|
||||
func (s *Span) Finish() {
|
||||
finish := time.Now()
|
||||
s.Duration = finish.Sub(s.start)
|
||||
s.Span.FinishWithOptions(opentracing.FinishOptions{
|
||||
FinishTime: finish,
|
||||
})
|
||||
s.hist.Observe(s.Duration.Seconds())
|
||||
s.gauge.Dec()
|
||||
}
|
||||
|
||||
// StartSpanFromContext is an improved opentracing.StartSpanFromContext.
|
||||
// Uses the calling function as the operation name, and logs the filename and line number.
|
||||
//
|
||||
// Passing nil context induces panic.
|
||||
// Context without parent span reference triggers root span construction.
|
||||
// This function never returns nil values.
|
||||
//
|
||||
// Performance
|
||||
//
|
||||
// This function incurs a small performance penalty, roughly 1000 ns/op, 376 B/op, 6 allocs/op.
|
||||
// Jaeger timestamp and duration precision is only µs, so this is pretty negligible.
|
||||
//
|
||||
// Alternatives
|
||||
//
|
||||
// If this performance penalty is too much, try these, which are also demonstrated in benchmark tests:
|
||||
// // Create a root span
|
||||
// span := opentracing.StartSpan("operation name")
|
||||
// ctx := opentracing.ContextWithSpan(context.Background(), span)
|
||||
//
|
||||
// // Create a child span
|
||||
// span := opentracing.StartSpan("operation name", opentracing.ChildOf(sc))
|
||||
// ctx := opentracing.ContextWithSpan(context.Background(), span)
|
||||
//
|
||||
// // Sugar to create a child span
|
||||
// span, ctx := opentracing.StartSpanFromContext(ctx, "operation name")
|
||||
func StartSpanFromContext(ctx context.Context, opts ...opentracing.StartSpanOption) (opentracing.Span, context.Context) {
|
||||
if ctx == nil {
|
||||
panic("StartSpanFromContext called with nil context")
|
||||
}
|
||||
|
||||
// Get caller frame.
|
||||
var pcs [1]uintptr
|
||||
n := runtime.Callers(2, pcs[:])
|
||||
if n < 1 {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "unknown", opts...)
|
||||
span.LogFields(log.Error(errors.New("runtime.Callers failed")))
|
||||
return span, ctx
|
||||
}
|
||||
fn := runtime.FuncForPC(pcs[0])
|
||||
name := fn.Name()
|
||||
if lastSlash := strings.LastIndexByte(name, '/'); lastSlash > 0 {
|
||||
name = name[lastSlash+1:]
|
||||
}
|
||||
|
||||
var span opentracing.Span
|
||||
if parentSpan := opentracing.SpanFromContext(ctx); parentSpan != nil {
|
||||
// Create a child span.
|
||||
opts = append(opts, opentracing.ChildOf(parentSpan.Context()))
|
||||
span = opentracing.StartSpan(name, opts...)
|
||||
} else {
|
||||
// Create a root span.
|
||||
span = opentracing.StartSpan(name)
|
||||
}
|
||||
// New context references this span, not the parent (if there was one).
|
||||
ctx = opentracing.ContextWithSpan(ctx, span)
|
||||
|
||||
file, line := fn.FileLine(pcs[0])
|
||||
span.LogFields(log.String("filename", file), log.Int("line", line))
|
||||
|
||||
return span, ctx
|
||||
}
|
||||
|
||||
// StartSpanFromContextWithOperationName is like StartSpanFromContext, but the caller determines the operation name.
|
||||
func StartSpanFromContextWithOperationName(ctx context.Context, operationName string, opts ...opentracing.StartSpanOption) (opentracing.Span, context.Context) {
|
||||
if ctx == nil {
|
||||
panic("StartSpanFromContextWithOperationName called with nil context")
|
||||
}
|
||||
|
||||
// Get caller frame.
|
||||
var pcs [1]uintptr
|
||||
n := runtime.Callers(2, pcs[:])
|
||||
if n < 1 {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, operationName, opts...)
|
||||
span.LogFields(log.Error(errors.New("runtime.Callers failed")))
|
||||
return span, ctx
|
||||
}
|
||||
file, line := runtime.FuncForPC(pcs[0]).FileLine(pcs[0])
|
||||
|
||||
var span opentracing.Span
|
||||
if parentSpan := opentracing.SpanFromContext(ctx); parentSpan != nil {
|
||||
opts = append(opts, opentracing.ChildOf(parentSpan.Context()))
|
||||
// Create a child span.
|
||||
span = opentracing.StartSpan(operationName, opts...)
|
||||
} else {
|
||||
// Create a root span.
|
||||
span = opentracing.StartSpan(operationName, opts...)
|
||||
}
|
||||
// New context references this span, not the parent (if there was one).
|
||||
ctx = opentracing.ContextWithSpan(ctx, span)
|
||||
|
||||
span.LogFields(log.String("filename", file), log.Int("line", line))
|
||||
|
||||
return span, ctx
|
||||
}
|
||||
|
||||
// InfoFromSpan returns the traceID and if it was sampled from the span, given it is a jaeger span.
|
||||
// It returns whether a span associated to the context has been found.
|
||||
func InfoFromSpan(span opentracing.Span) (traceID string, sampled bool, found bool) {
|
||||
if spanContext, ok := span.Context().(jaeger.SpanContext); ok {
|
||||
traceID = spanContext.TraceID().String()
|
||||
sampled = spanContext.IsSampled()
|
||||
return traceID, sampled, true
|
||||
}
|
||||
return "", false, false
|
||||
}
|
||||
|
||||
// InfoFromContext returns the traceID and if it was sampled from the Jaeger span
|
||||
// found in the given context. It returns whether a span associated to the context has been found.
|
||||
func InfoFromContext(ctx context.Context) (traceID string, sampled bool, found bool) {
|
||||
if span := opentracing.SpanFromContext(ctx); span != nil {
|
||||
return InfoFromSpan(span)
|
||||
}
|
||||
return "", false, false
|
||||
}
|
|
@ -0,0 +1,333 @@
|
|||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/influxdata/httprouter"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/opentracing/opentracing-go/mocktracer"
|
||||
)
|
||||
|
||||
func TestInjectAndExtractHTTPRequest(t *testing.T) {
|
||||
tracer := mocktracer.New()
|
||||
|
||||
oldTracer := opentracing.GlobalTracer()
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
defer opentracing.SetGlobalTracer(oldTracer)
|
||||
|
||||
request, err := http.NewRequest(http.MethodPost, "http://localhost/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
span := tracer.StartSpan("operation name")
|
||||
|
||||
InjectToHTTPRequest(span, request)
|
||||
gotSpan, _ := ExtractFromHTTPRequest(request, "MyStruct")
|
||||
|
||||
if span.(*mocktracer.MockSpan).SpanContext.TraceID != gotSpan.(*mocktracer.MockSpan).SpanContext.TraceID {
|
||||
t.Error("injected and extracted traceIDs not equal")
|
||||
}
|
||||
|
||||
if span.(*mocktracer.MockSpan).SpanContext.SpanID != gotSpan.(*mocktracer.MockSpan).ParentID {
|
||||
t.Error("injected span ID does not match extracted span parent ID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractHTTPRequest(t *testing.T) {
|
||||
var (
|
||||
tracer = mocktracer.New()
|
||||
oldTracer = opentracing.GlobalTracer()
|
||||
ctx = context.Background()
|
||||
)
|
||||
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
defer opentracing.SetGlobalTracer(oldTracer)
|
||||
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
handlerName string
|
||||
path string
|
||||
ctx context.Context
|
||||
tags map[string]interface{}
|
||||
method string
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
handlerName: "WriteHandler",
|
||||
ctx: context.WithValue(ctx, httprouter.MatchedRouteKey, "/api/v2/write"),
|
||||
method: http.MethodGet,
|
||||
path: "/api/v2/write",
|
||||
tags: map[string]interface{}{
|
||||
"route": "/api/v2/write",
|
||||
"handler": "WriteHandler",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "happy path bucket handler",
|
||||
handlerName: "BucketHandler",
|
||||
ctx: context.WithValue(ctx, httprouter.MatchedRouteKey, "/api/v2/buckets/:bucket_id"),
|
||||
path: "/api/v2/buckets/12345",
|
||||
method: http.MethodGet,
|
||||
tags: map[string]interface{}{
|
||||
"route": "/api/v2/buckets/:bucket_id",
|
||||
"handler": "BucketHandler",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "happy path bucket handler (chi)",
|
||||
handlerName: "BucketHandler",
|
||||
ctx: context.WithValue(
|
||||
ctx,
|
||||
chi.RouteCtxKey,
|
||||
&chi.Context{RoutePath: "/api/v2/buckets/:bucket_id", RouteMethod: "GET"},
|
||||
),
|
||||
path: "/api/v2/buckets/12345",
|
||||
method: http.MethodGet,
|
||||
tags: map[string]interface{}{
|
||||
"route": "/api/v2/buckets/:bucket_id",
|
||||
"method": "GET",
|
||||
"handler": "BucketHandler",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty path",
|
||||
handlerName: "Home",
|
||||
ctx: ctx,
|
||||
method: http.MethodGet,
|
||||
tags: map[string]interface{}{
|
||||
"handler": "Home",
|
||||
},
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
request, err := http.NewRequest(test.method, "http://localhost"+test.path, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
span := tracer.StartSpan("operation name")
|
||||
|
||||
InjectToHTTPRequest(span, request)
|
||||
gotSpan, _ := ExtractFromHTTPRequest(request.WithContext(test.ctx), test.handlerName)
|
||||
|
||||
if op := gotSpan.(*mocktracer.MockSpan).OperationName; op != "request" {
|
||||
t.Fatalf("operation name %q != request", op)
|
||||
}
|
||||
|
||||
tags := gotSpan.(*mocktracer.MockSpan).Tags()
|
||||
for k, v := range test.tags {
|
||||
found, ok := tags[k]
|
||||
if !ok {
|
||||
t.Errorf("tag not found in span %q", k)
|
||||
continue
|
||||
}
|
||||
|
||||
if found != v {
|
||||
t.Errorf("expected %v, found %v for tag %q", v, found, k)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartSpanFromContext(t *testing.T) {
|
||||
tracer := mocktracer.New()
|
||||
|
||||
oldTracer := opentracing.GlobalTracer()
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
defer opentracing.SetGlobalTracer(oldTracer)
|
||||
|
||||
type testCase struct {
|
||||
ctx context.Context
|
||||
expectPanic bool
|
||||
expectParent bool
|
||||
}
|
||||
var testCases []testCase
|
||||
|
||||
testCases = append(testCases,
|
||||
testCase{
|
||||
ctx: nil,
|
||||
expectPanic: true,
|
||||
expectParent: false,
|
||||
},
|
||||
testCase{
|
||||
ctx: context.Background(),
|
||||
expectPanic: false,
|
||||
expectParent: false,
|
||||
})
|
||||
|
||||
parentSpan := opentracing.StartSpan("parent operation name")
|
||||
testCases = append(testCases, testCase{
|
||||
ctx: opentracing.ContextWithSpan(context.Background(), parentSpan),
|
||||
expectPanic: false,
|
||||
expectParent: true,
|
||||
})
|
||||
|
||||
for i, tc := range testCases {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
var span opentracing.Span
|
||||
var ctx context.Context
|
||||
var gotPanic bool
|
||||
|
||||
func(inputCtx context.Context) {
|
||||
defer func() {
|
||||
if recover() != nil {
|
||||
gotPanic = true
|
||||
}
|
||||
}()
|
||||
span, ctx = StartSpanFromContext(inputCtx)
|
||||
}(tc.ctx)
|
||||
|
||||
if tc.expectPanic != gotPanic {
|
||||
t.Errorf("panic: expect %v got %v", tc.expectPanic, gotPanic)
|
||||
}
|
||||
if tc.expectPanic {
|
||||
// No other valid checks if panic.
|
||||
return
|
||||
}
|
||||
if ctx == nil {
|
||||
t.Error("never expect non-nil ctx")
|
||||
}
|
||||
if span == nil {
|
||||
t.Error("never expect non-nil Span")
|
||||
}
|
||||
foundParent := span.(*mocktracer.MockSpan).ParentID != 0
|
||||
if tc.expectParent != foundParent {
|
||||
t.Errorf("parent: expect %v got %v", tc.expectParent, foundParent)
|
||||
}
|
||||
if ctx == tc.ctx {
|
||||
t.Errorf("always expect fresh context")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogErrorNil(t *testing.T) {
|
||||
tracer := mocktracer.New()
|
||||
span := tracer.StartSpan("test").(*mocktracer.MockSpan)
|
||||
|
||||
var err error
|
||||
if err2 := LogError(span, err); err2 != nil {
|
||||
t.Errorf("expected nil err, got '%s'", err2.Error())
|
||||
}
|
||||
|
||||
if len(span.Logs()) > 0 {
|
||||
t.Errorf("expected zero new span logs, got %d", len(span.Logs()))
|
||||
println(span.Logs()[0].Fields[0].Key)
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
BenchmarkLocal_StartSpanFromContext-8 2000000 681 ns/op 224 B/op 4 allocs/op
|
||||
BenchmarkLocal_StartSpanFromContext_runtimeCaller-8 3000000 534 ns/op
|
||||
BenchmarkLocal_StartSpanFromContext_runtimeCallers-8 10000000 196 ns/op
|
||||
BenchmarkLocal_StartSpanFromContext_runtimeFuncForPC-8 200000000 7.28 ns/op
|
||||
BenchmarkLocal_StartSpanFromContext_runtimeCallersFrames-8 10000000 234 ns/op
|
||||
BenchmarkLocal_StartSpanFromContext_runtimeFuncFileLine-8 20000000 103 ns/op
|
||||
BenchmarkOpentracing_StartSpanFromContext-8 10000000 155 ns/op 96 B/op 3 allocs/op
|
||||
BenchmarkOpentracing_StartSpan_root-8 200000000 7.68 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkOpentracing_StartSpan_child-8 20000000 71.2 ns/op 48 B/op 2 allocs/op
|
||||
*/
|
||||
|
||||
func BenchmarkLocal_StartSpanFromContext(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
parentSpan := opentracing.StartSpan("parent operation name")
|
||||
ctx := opentracing.ContextWithSpan(context.Background(), parentSpan)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
StartSpanFromContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLocal_StartSpanFromContext_runtimeCaller(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
_, _, _, _ = runtime.Caller(1)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLocal_StartSpanFromContext_runtimeCallers(b *testing.B) {
|
||||
var pcs [1]uintptr
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
_ = runtime.Callers(2, pcs[:])
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLocal_StartSpanFromContext_runtimeFuncForPC(b *testing.B) {
|
||||
var pcs [1]uintptr
|
||||
_ = runtime.Callers(2, pcs[:])
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
_ = runtime.FuncForPC(pcs[0])
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLocal_StartSpanFromContext_runtimeCallersFrames(b *testing.B) {
|
||||
pc, _, _, ok := runtime.Caller(1)
|
||||
if !ok {
|
||||
b.Fatal("runtime.Caller failed")
|
||||
}
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
_, _ = runtime.CallersFrames([]uintptr{pc}).Next()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLocal_StartSpanFromContext_runtimeFuncFileLine(b *testing.B) {
|
||||
var pcs [1]uintptr
|
||||
_ = runtime.Callers(2, pcs[:])
|
||||
fn := runtime.FuncForPC(pcs[0])
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
_, _ = fn.FileLine(pcs[0])
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkOpentracing_StartSpanFromContext(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
parentSpan := opentracing.StartSpan("parent operation name")
|
||||
ctx := opentracing.ContextWithSpan(context.Background(), parentSpan)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
_, _ = opentracing.StartSpanFromContext(ctx, "operation name")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkOpentracing_StartSpan_root(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
_ = opentracing.StartSpan("operation name")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkOpentracing_StartSpan_child(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
parentSpan := opentracing.StartSpan("parent operation name")
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
_ = opentracing.StartSpan("operation name", opentracing.ChildOf(parentSpan.Context()))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkOpentracing_ExtractFromHTTPRequest(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
req := &http.Request{
|
||||
URL: &url.URL{Path: "/api/v2/organization/12345"},
|
||||
}
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
_, _ = ExtractFromHTTPRequest(req, "OrganizationHandler")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,267 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/influxdata/influxdb/kit/platform/errors"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// PlatformErrorCodeHeader shows the error code of platform error.
|
||||
const PlatformErrorCodeHeader = "X-Platform-Error-Code"
|
||||
|
||||
// API provides a consolidated means for handling API interface concerns.
|
||||
// Concerns such as decoding/encoding request and response bodies as well
|
||||
// as adding headers for content type and content encoding.
|
||||
type API struct {
|
||||
logger *zap.Logger
|
||||
|
||||
prettyJSON bool
|
||||
encodeGZIP bool
|
||||
|
||||
unmarshalErrFn func(encoding string, err error) error
|
||||
okErrFn func(err error) error
|
||||
errFn func(ctx context.Context, err error) (interface{}, int, error)
|
||||
}
|
||||
|
||||
// APIOptFn is a functional option for setting fields on the API type.
|
||||
type APIOptFn func(*API)
|
||||
|
||||
// WithLog sets the logger.
|
||||
func WithLog(logger *zap.Logger) APIOptFn {
|
||||
return func(api *API) {
|
||||
api.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// WithErrFn sets the err handling func for issues when writing to the response body.
|
||||
func WithErrFn(fn func(ctx context.Context, err error) (interface{}, int, error)) APIOptFn {
|
||||
return func(api *API) {
|
||||
api.errFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
// WithOKErrFn is an error handler for failing validation for request bodies.
|
||||
func WithOKErrFn(fn func(err error) error) APIOptFn {
|
||||
return func(api *API) {
|
||||
api.okErrFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
// WithPrettyJSON sets the json encoder to marshal indent or not.
|
||||
func WithPrettyJSON(b bool) APIOptFn {
|
||||
return func(api *API) {
|
||||
api.prettyJSON = b
|
||||
}
|
||||
}
|
||||
|
||||
// WithEncodeGZIP sets the encoder to gzip contents.
|
||||
func WithEncodeGZIP() APIOptFn {
|
||||
return func(api *API) {
|
||||
api.encodeGZIP = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnmarshalErrFn sets the error handler for errors that occur when unmarshalling
|
||||
// the request body.
|
||||
func WithUnmarshalErrFn(fn func(encoding string, err error) error) APIOptFn {
|
||||
return func(api *API) {
|
||||
api.unmarshalErrFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
// NewAPI creates a new API type.
|
||||
func NewAPI(opts ...APIOptFn) *API {
|
||||
api := API{
|
||||
logger: zap.NewNop(),
|
||||
prettyJSON: true,
|
||||
unmarshalErrFn: func(encoding string, err error) error {
|
||||
return &errors.Error{
|
||||
Code: errors.EInvalid,
|
||||
Msg: fmt.Sprintf("failed to unmarshal %s: %s", encoding, err),
|
||||
}
|
||||
},
|
||||
errFn: func(ctx context.Context, err error) (interface{}, int, error) {
|
||||
msg := err.Error()
|
||||
if msg == "" {
|
||||
msg = "an internal error has occurred"
|
||||
}
|
||||
code := errors.ErrorCode(err)
|
||||
return ErrBody{
|
||||
Code: code,
|
||||
Msg: msg,
|
||||
}, ErrorCodeToStatusCode(ctx, code), nil
|
||||
},
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(&api)
|
||||
}
|
||||
return &api
|
||||
}
|
||||
|
||||
// DecodeJSON decodes reader with json.
|
||||
func (a *API) DecodeJSON(r io.Reader, v interface{}) error {
|
||||
return a.decode("json", json.NewDecoder(r), v)
|
||||
}
|
||||
|
||||
// DecodeGob decodes reader with gob.
|
||||
func (a *API) DecodeGob(r io.Reader, v interface{}) error {
|
||||
return a.decode("gob", gob.NewDecoder(r), v)
|
||||
}
|
||||
|
||||
type (
|
||||
decoder interface {
|
||||
Decode(interface{}) error
|
||||
}
|
||||
|
||||
oker interface {
|
||||
OK() error
|
||||
}
|
||||
)
|
||||
|
||||
func (a *API) decode(encoding string, dec decoder, v interface{}) error {
|
||||
if err := dec.Decode(v); err != nil {
|
||||
if a != nil && a.unmarshalErrFn != nil {
|
||||
return a.unmarshalErrFn(encoding, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if vv, ok := v.(oker); ok {
|
||||
err := vv.OK()
|
||||
if a != nil && a.okErrFn != nil {
|
||||
return a.okErrFn(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Respond writes to the response writer, handling all errors in writing.
|
||||
func (a *API) Respond(w http.ResponseWriter, r *http.Request, status int, v interface{}) {
|
||||
if status == http.StatusNoContent {
|
||||
w.WriteHeader(status)
|
||||
return
|
||||
}
|
||||
|
||||
var writer io.WriteCloser = noopCloser{Writer: w}
|
||||
// we'll double close to make sure its always closed even
|
||||
//on issues before the write
|
||||
defer writer.Close()
|
||||
|
||||
if a != nil && a.encodeGZIP {
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
writer = gzip.NewWriter(w)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
|
||||
// this marshal block is to catch failures before they hit the http writer.
|
||||
// default behavior for http.ResponseWriter is when body is written and no
|
||||
// status is set, it writes a 200. Or if a status is set before encoding
|
||||
// and an error occurs, there is no means to write a proper status code
|
||||
// (i.e. 500) when that is to occur. This brings that step out before
|
||||
// and then writes the data and sets the status code after marshaling
|
||||
// succeeds.
|
||||
var (
|
||||
b []byte
|
||||
err error
|
||||
)
|
||||
if a == nil || a.prettyJSON {
|
||||
b, err = json.MarshalIndent(v, "", "\t")
|
||||
} else {
|
||||
b, err = json.Marshal(v)
|
||||
}
|
||||
if err != nil {
|
||||
a.Err(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
a.write(w, writer, status, b)
|
||||
}
|
||||
|
||||
// Write allows the user to write raw bytes to the response writer. This
|
||||
// operation does not have a fail case, all failures here will be logged.
|
||||
func (a *API) Write(w http.ResponseWriter, status int, b []byte) {
|
||||
if status == http.StatusNoContent {
|
||||
w.WriteHeader(status)
|
||||
return
|
||||
}
|
||||
|
||||
var writer io.WriteCloser = noopCloser{Writer: w}
|
||||
// we'll double close to make sure its always closed even
|
||||
//on issues before the write
|
||||
defer writer.Close()
|
||||
|
||||
if a != nil && a.encodeGZIP {
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
writer = gzip.NewWriter(w)
|
||||
}
|
||||
|
||||
a.write(w, writer, status, b)
|
||||
}
|
||||
|
||||
func (a *API) write(w http.ResponseWriter, wc io.WriteCloser, status int, b []byte) {
|
||||
w.WriteHeader(status)
|
||||
if _, err := wc.Write(b); err != nil {
|
||||
a.logErr("failed to write to response writer", zap.Error(err))
|
||||
}
|
||||
|
||||
if err := wc.Close(); err != nil {
|
||||
a.logErr("failed to close response writer", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Err is used for writing an error to the response.
|
||||
func (a *API) Err(w http.ResponseWriter, r *http.Request, err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
a.logErr("api error encountered", zap.Error(err))
|
||||
|
||||
v, status, err := a.errFn(r.Context(), err)
|
||||
if err != nil {
|
||||
a.logErr("failed to write err to response writer", zap.Error(err))
|
||||
a.Respond(w, r, http.StatusInternalServerError, ErrBody{
|
||||
Code: "internal error",
|
||||
Msg: "an unexpected error occurred",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if eb, ok := v.(ErrBody); ok {
|
||||
w.Header().Set(PlatformErrorCodeHeader, eb.Code)
|
||||
}
|
||||
|
||||
a.Respond(w, r, status, v)
|
||||
}
|
||||
|
||||
func (a *API) logErr(msg string, fields ...zap.Field) {
|
||||
if a == nil || a.logger == nil {
|
||||
return
|
||||
}
|
||||
a.logger.Error(msg, fields...)
|
||||
}
|
||||
|
||||
type noopCloser struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (n noopCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ErrBody is an err response body.
|
||||
type ErrBody struct {
|
||||
Code string `json:"code"`
|
||||
Msg string `json:"message"`
|
||||
}
|
|
@ -0,0 +1,309 @@
|
|||
package http_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/influxdata/influxdb/kit/platform/errors"
|
||||
kithttp "github.com/influxdata/influxdb/kit/transport/http"
|
||||
"github.com/influxdata/influxdb/pkg/testttp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_API(t *testing.T) {
|
||||
t.Run("Decode", func(t *testing.T) {
|
||||
t.Run("valid foo no errors", func(t *testing.T) {
|
||||
expected := validatFoo{
|
||||
Foo: "valid",
|
||||
Bar: 10,
|
||||
}
|
||||
|
||||
t.Run("json", func(t *testing.T) {
|
||||
var api *kithttp.API // shows it is safe to use a nil value
|
||||
|
||||
var out validatFoo
|
||||
err := api.DecodeJSON(encodeJSON(t, expected), &out)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
if expected != out {
|
||||
t.Fatalf("unexpected vals:\n\texpected: %#v\n\tgot: %#v", expected, out)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("gob", func(t *testing.T) {
|
||||
var out validatFoo
|
||||
err := kithttp.NewAPI().DecodeGob(encodeGob(t, expected), &out)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
if expected != out {
|
||||
t.Fatalf("unexpected vals:\n\texpected: %#v\n\tgot: %#v", expected, out)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("unmarshals fine with ok error", func(t *testing.T) {
|
||||
badFoo := validatFoo{
|
||||
Foo: "",
|
||||
Bar: 0,
|
||||
}
|
||||
|
||||
t.Run("json", func(t *testing.T) {
|
||||
var out validatFoo
|
||||
err := kithttp.NewAPI().DecodeJSON(encodeJSON(t, badFoo), &out)
|
||||
if err == nil {
|
||||
t.Fatal("expected an err")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("gob", func(t *testing.T) {
|
||||
var out validatFoo
|
||||
err := kithttp.NewAPI().DecodeGob(encodeGob(t, badFoo), &out)
|
||||
if err == nil {
|
||||
t.Fatal("expected an err")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("unmarshal error", func(t *testing.T) {
|
||||
invalidBody := []byte("[}-{]")
|
||||
|
||||
var out validatFoo
|
||||
err := kithttp.NewAPI().DecodeJSON(bytes.NewReader(invalidBody), &out)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unmarshal err fn wraps unmarshalling error", func(t *testing.T) {
|
||||
t.Run("json", func(t *testing.T) {
|
||||
invalidBody := []byte("[}-{]")
|
||||
|
||||
api := kithttp.NewAPI(kithttp.WithUnmarshalErrFn(unmarshalErrFn))
|
||||
|
||||
var out validatFoo
|
||||
err := api.DecodeJSON(bytes.NewReader(invalidBody), &out)
|
||||
expectInfluxdbError(t, errors.EInvalid, err)
|
||||
})
|
||||
|
||||
t.Run("gob", func(t *testing.T) {
|
||||
invalidBody := []byte("[}-{]")
|
||||
|
||||
api := kithttp.NewAPI(kithttp.WithUnmarshalErrFn(unmarshalErrFn))
|
||||
|
||||
var out validatFoo
|
||||
err := api.DecodeGob(bytes.NewReader(invalidBody), &out)
|
||||
expectInfluxdbError(t, errors.EInvalid, err)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("ok error fn wraps ok error", func(t *testing.T) {
|
||||
badFoo := validatFoo{Foo: ""}
|
||||
|
||||
t.Run("json", func(t *testing.T) {
|
||||
api := kithttp.NewAPI(kithttp.WithOKErrFn(okErrFn))
|
||||
|
||||
var out validatFoo
|
||||
err := api.DecodeJSON(encodeJSON(t, badFoo), &out)
|
||||
expectInfluxdbError(t, errors.EUnprocessableEntity, err)
|
||||
})
|
||||
|
||||
t.Run("gob", func(t *testing.T) {
|
||||
api := kithttp.NewAPI(kithttp.WithOKErrFn(okErrFn))
|
||||
|
||||
var out validatFoo
|
||||
err := api.DecodeGob(encodeGob(t, badFoo), &out)
|
||||
expectInfluxdbError(t, errors.EUnprocessableEntity, err)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Respond", func(t *testing.T) {
|
||||
tests := []int{
|
||||
http.StatusCreated,
|
||||
http.StatusOK,
|
||||
http.StatusNoContent,
|
||||
http.StatusForbidden,
|
||||
http.StatusInternalServerError,
|
||||
}
|
||||
|
||||
for _, statusCode := range tests {
|
||||
fn := func(t *testing.T) {
|
||||
responder := kithttp.NewAPI()
|
||||
|
||||
svr := func(w http.ResponseWriter, r *http.Request) {
|
||||
responder.Respond(w, r, statusCode, map[string]string{
|
||||
"foo": "bar",
|
||||
})
|
||||
}
|
||||
|
||||
expectedBodyFn := func(body *bytes.Buffer) {
|
||||
var resp map[string]string
|
||||
require.NoError(t, json.NewDecoder(body).Decode(&resp))
|
||||
assert.Equal(t, "bar", resp["foo"])
|
||||
}
|
||||
if statusCode == http.StatusNoContent {
|
||||
expectedBodyFn = func(body *bytes.Buffer) {
|
||||
require.Zero(t, body.Len())
|
||||
}
|
||||
}
|
||||
|
||||
testttp.
|
||||
Get(t, "/foo").
|
||||
Do(http.HandlerFunc(svr)).
|
||||
ExpectStatus(statusCode).
|
||||
ExpectBody(expectedBodyFn)
|
||||
}
|
||||
t.Run(http.StatusText(statusCode), fn)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Err", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
statusCode int
|
||||
expectedErr *errors.Error
|
||||
}{
|
||||
{
|
||||
statusCode: http.StatusBadRequest,
|
||||
expectedErr: &errors.Error{
|
||||
Code: errors.EInvalid,
|
||||
Msg: "failed to unmarshal",
|
||||
},
|
||||
},
|
||||
{
|
||||
statusCode: http.StatusForbidden,
|
||||
expectedErr: &errors.Error{
|
||||
Code: errors.EForbidden,
|
||||
Msg: "forbidden",
|
||||
},
|
||||
},
|
||||
{
|
||||
statusCode: http.StatusUnprocessableEntity,
|
||||
expectedErr: &errors.Error{
|
||||
Code: errors.EUnprocessableEntity,
|
||||
Msg: "failed validation",
|
||||
},
|
||||
},
|
||||
{
|
||||
statusCode: http.StatusInternalServerError,
|
||||
expectedErr: &errors.Error{
|
||||
Code: errors.EInternal,
|
||||
Msg: "internal error",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
fn := func(t *testing.T) {
|
||||
responder := kithttp.NewAPI()
|
||||
|
||||
svr := func(w http.ResponseWriter, r *http.Request) {
|
||||
responder.Err(w, r, tt.expectedErr)
|
||||
}
|
||||
|
||||
testttp.
|
||||
Get(t, "/foo").
|
||||
Do(http.HandlerFunc(svr)).
|
||||
ExpectStatus(tt.statusCode).
|
||||
ExpectBody(func(body *bytes.Buffer) {
|
||||
var err kithttp.ErrBody
|
||||
require.NoError(t, json.NewDecoder(body).Decode(&err))
|
||||
assert.Equal(t, tt.expectedErr.Msg, err.Msg)
|
||||
assert.Equal(t, tt.expectedErr.Code, err.Code)
|
||||
})
|
||||
}
|
||||
t.Run(http.StatusText(tt.statusCode), fn)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func expectInfluxdbError(t *testing.T, expectedCode string, err error) {
|
||||
t.Helper()
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected an error")
|
||||
}
|
||||
iErr, ok := err.(*errors.Error)
|
||||
if !ok {
|
||||
t.Fatalf("expected an influxdb error; got=%#v", err)
|
||||
}
|
||||
|
||||
if got := iErr.Code; expectedCode != got {
|
||||
t.Fatalf("unexpected error code; expected=%s got=%s", expectedCode, got)
|
||||
}
|
||||
}
|
||||
|
||||
func encodeGob(t *testing.T, v interface{}) io.Reader {
|
||||
t.Helper()
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := gob.NewEncoder(&buf).Encode(v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &buf
|
||||
}
|
||||
|
||||
func encodeJSON(t *testing.T, v interface{}) io.Reader {
|
||||
t.Helper()
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &buf
|
||||
}
|
||||
|
||||
func okErrFn(err error) error {
|
||||
return &errors.Error{
|
||||
Code: errors.EUnprocessableEntity,
|
||||
Msg: "failed validation",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalErrFn(encoding string, err error) error {
|
||||
return &errors.Error{
|
||||
Code: errors.EInvalid,
|
||||
Msg: fmt.Sprintf("invalid %s request body", encoding),
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
type validatFoo struct {
|
||||
Foo string `gob:"foo"`
|
||||
Bar int `gob:"bar"`
|
||||
}
|
||||
|
||||
func (v *validatFoo) OK() error {
|
||||
var errs multiErr
|
||||
if v.Foo == "" {
|
||||
errs = append(errs, "foo must be at least 1 char")
|
||||
}
|
||||
if v.Bar < 0 {
|
||||
errs = append(errs, "bar must be a positive real number")
|
||||
}
|
||||
return errs.toError()
|
||||
}
|
||||
|
||||
type multiErr []string
|
||||
|
||||
func (m multiErr) toError() error {
|
||||
if len(m) > 0 {
|
||||
return m
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m multiErr) Error() string {
|
||||
return strings.Join(m, "; ")
|
||||
}
|
|
@ -0,0 +1,188 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
errors2 "github.com/influxdata/influxdb/kit/platform/errors"
|
||||
)
|
||||
|
||||
// ErrorHandler is the error handler in http package.
|
||||
type ErrorHandler int
|
||||
|
||||
// HandleHTTPError encodes err with the appropriate status code and format,
|
||||
// sets the X-Platform-Error-Code headers on the response.
|
||||
// We're no longer using X-Influx-Error and X-Influx-Reference.
|
||||
// and sets the response status to the corresponding status code.
|
||||
func (h ErrorHandler) HandleHTTPError(ctx context.Context, err error, w http.ResponseWriter) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
code := errors2.ErrorCode(err)
|
||||
var msg string
|
||||
if err, ok := err.(*errors2.Error); ok {
|
||||
msg = err.Error()
|
||||
} else {
|
||||
msg = "An internal error has occurred"
|
||||
}
|
||||
|
||||
WriteErrorResponse(ctx, w, code, msg)
|
||||
}
|
||||
|
||||
func WriteErrorResponse(ctx context.Context, w http.ResponseWriter, code string, msg string) {
|
||||
w.Header().Set(PlatformErrorCodeHeader, code)
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(ErrorCodeToStatusCode(ctx, code))
|
||||
e := struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}{
|
||||
Code: code,
|
||||
Message: msg,
|
||||
}
|
||||
b, _ := json.Marshal(e)
|
||||
_, _ = w.Write(b)
|
||||
}
|
||||
|
||||
// StatusCodeToErrorCode maps a http status code integer to an
|
||||
// influxdb error code string.
|
||||
func StatusCodeToErrorCode(statusCode int) string {
|
||||
errorCode, ok := httpStatusCodeToInfluxDBError[statusCode]
|
||||
if ok {
|
||||
return errorCode
|
||||
}
|
||||
|
||||
return errors2.EInternal
|
||||
}
|
||||
|
||||
// ErrorCodeToStatusCode maps an influxdb error code string to a
|
||||
// http status code integer.
|
||||
func ErrorCodeToStatusCode(ctx context.Context, code string) int {
|
||||
// If the client disconnects early or times out then return a different
|
||||
// error than the passed in error code. Client timeouts return a 408
|
||||
// while disconnections return a non-standard Nginx HTTP 499 code.
|
||||
if err := ctx.Err(); err == context.DeadlineExceeded {
|
||||
return http.StatusRequestTimeout
|
||||
} else if err == context.Canceled {
|
||||
return 499 // https://httpstatuses.com/499
|
||||
}
|
||||
|
||||
// Otherwise map internal error codes to HTTP status codes.
|
||||
statusCode, ok := influxDBErrorToStatusCode[code]
|
||||
if ok {
|
||||
return statusCode
|
||||
}
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
// influxDBErrorToStatusCode is a mapping of ErrorCode to http status code.
|
||||
var influxDBErrorToStatusCode = map[string]int{
|
||||
errors2.EInternal: http.StatusInternalServerError,
|
||||
errors2.ENotImplemented: http.StatusNotImplemented,
|
||||
errors2.EInvalid: http.StatusBadRequest,
|
||||
errors2.EUnprocessableEntity: http.StatusUnprocessableEntity,
|
||||
errors2.EEmptyValue: http.StatusBadRequest,
|
||||
errors2.EConflict: http.StatusUnprocessableEntity,
|
||||
errors2.ENotFound: http.StatusNotFound,
|
||||
errors2.EUnavailable: http.StatusServiceUnavailable,
|
||||
errors2.EForbidden: http.StatusForbidden,
|
||||
errors2.ETooManyRequests: http.StatusTooManyRequests,
|
||||
errors2.EUnauthorized: http.StatusUnauthorized,
|
||||
errors2.EMethodNotAllowed: http.StatusMethodNotAllowed,
|
||||
errors2.ETooLarge: http.StatusRequestEntityTooLarge,
|
||||
}
|
||||
|
||||
var httpStatusCodeToInfluxDBError = map[int]string{}
|
||||
|
||||
func init() {
|
||||
for k, v := range influxDBErrorToStatusCode {
|
||||
httpStatusCodeToInfluxDBError[v] = k
|
||||
}
|
||||
}
|
||||
|
||||
// CheckErrorStatus for status and any error in the response.
|
||||
func CheckErrorStatus(code int, res *http.Response) error {
|
||||
err := CheckError(res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.StatusCode != code {
|
||||
return fmt.Errorf("unexpected status code: %s", res.Status)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckError reads the http.Response and returns an error if one exists.
|
||||
// It will automatically recognize the errors returned by Influx services
|
||||
// and decode the error into an internal error type. If the error cannot
|
||||
// be determined in that way, it will create a generic error message.
|
||||
//
|
||||
// If there is no error, then this returns nil.
|
||||
func CheckError(resp *http.Response) (err error) {
|
||||
switch resp.StatusCode / 100 {
|
||||
case 4, 5:
|
||||
// We will attempt to parse this error outside of this block.
|
||||
case 2:
|
||||
return nil
|
||||
default:
|
||||
// TODO(jsternberg): Figure out what to do here?
|
||||
return &errors2.Error{
|
||||
Code: errors2.EInternal,
|
||||
Msg: fmt.Sprintf("unexpected status code: %d %s", resp.StatusCode, resp.Status),
|
||||
}
|
||||
}
|
||||
|
||||
perr := &errors2.Error{
|
||||
Code: StatusCodeToErrorCode(resp.StatusCode),
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnsupportedMediaType {
|
||||
perr.Msg = fmt.Sprintf("invalid media type: %q", resp.Header.Get("Content-Type"))
|
||||
return perr
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
// Assume JSON if there is no content-type.
|
||||
contentType = "application/json"
|
||||
}
|
||||
mediatype, _, _ := mime.ParseMediaType(contentType)
|
||||
|
||||
var buf bytes.Buffer
|
||||
if _, err := io.Copy(&buf, resp.Body); err != nil {
|
||||
perr.Msg = "failed to read error response"
|
||||
perr.Err = err
|
||||
return perr
|
||||
}
|
||||
|
||||
switch mediatype {
|
||||
case "application/json":
|
||||
if err := json.Unmarshal(buf.Bytes(), perr); err != nil {
|
||||
perr.Msg = fmt.Sprintf("attempted to unmarshal error as JSON but failed: %q", err)
|
||||
perr.Err = firstLineAsError(buf)
|
||||
}
|
||||
default:
|
||||
perr.Err = firstLineAsError(buf)
|
||||
}
|
||||
|
||||
if perr.Code == "" {
|
||||
// given it was unset during attempt to unmarshal as JSON
|
||||
perr.Code = StatusCodeToErrorCode(resp.StatusCode)
|
||||
}
|
||||
|
||||
return perr
|
||||
}
|
||||
func firstLineAsError(buf bytes.Buffer) error {
|
||||
line, _ := buf.ReadString('\n')
|
||||
return errors.New(strings.TrimSuffix(line, "\n"))
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
package http_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/influxdata/influxdb/kit/platform/errors"
|
||||
kithttp "github.com/influxdata/influxdb/kit/transport/http"
|
||||
)
|
||||
|
||||
func TestEncodeError(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
kithttp.ErrorHandler(0).HandleHTTPError(ctx, nil, w)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Errorf("expected status code 200, got: %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeErrorWithError(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
err := &errors.Error{
|
||||
Code: errors.EInternal,
|
||||
Msg: "an error occurred",
|
||||
Err: fmt.Errorf("there's an error here, be aware"),
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
kithttp.ErrorHandler(0).HandleHTTPError(ctx, err, w)
|
||||
|
||||
if w.Code != 500 {
|
||||
t.Errorf("expected status code 500, got: %d", w.Code)
|
||||
}
|
||||
|
||||
errHeader := w.Header().Get("X-Platform-Error-Code")
|
||||
if errHeader != errors.EInternal {
|
||||
t.Errorf("expected X-Platform-Error-Code: %s, got: %s", errors.EInternal, errHeader)
|
||||
}
|
||||
|
||||
// The http handler will flatten the message and it will not
|
||||
// have an error property, so reading the serialization results
|
||||
// in a different error.
|
||||
pe := kithttp.CheckError(w.Result()).(*errors.Error)
|
||||
if want, got := errors.EInternal, pe.Code; want != got {
|
||||
t.Errorf("unexpected code -want/+got:\n\t- %q\n\t+ %q", want, got)
|
||||
}
|
||||
if want, got := "an error occurred: there's an error here, be aware", pe.Msg; want != got {
|
||||
t.Errorf("unexpected message -want/+got:\n\t- %q\n\t+ %q", want, got)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/influxdata/influxdb/kit/feature"
|
||||
)
|
||||
|
||||
// Enabler allows the switching between two HTTP Handlers
|
||||
type Enabler interface {
|
||||
Enabled(ctx context.Context, flagger ...feature.Flagger) bool
|
||||
}
|
||||
|
||||
// FeatureHandler is used to switch requests between an existing and a feature flagged
|
||||
// HTTP Handler on a per-request basis
|
||||
type FeatureHandler struct {
|
||||
enabler Enabler
|
||||
flagger feature.Flagger
|
||||
oldHandler http.Handler
|
||||
newHandler http.Handler
|
||||
prefix string
|
||||
}
|
||||
|
||||
func NewFeatureHandler(e Enabler, f feature.Flagger, old, new http.Handler, prefix string) *FeatureHandler {
|
||||
return &FeatureHandler{e, f, old, new, prefix}
|
||||
}
|
||||
|
||||
func (h *FeatureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if h.enabler.Enabled(r.Context(), h.flagger) {
|
||||
h.newHandler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
h.oldHandler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (h *FeatureHandler) Prefix() string {
|
||||
return h.prefix
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
package http
|
||||
|
||||
import "net/http"
|
||||
|
||||
// ResourceHandler is an HTTP handler for a resource. The prefix
|
||||
// describes the url path prefix that relates to the handler
|
||||
// endpoints.
|
||||
type ResourceHandler interface {
|
||||
Prefix() string
|
||||
http.Handler
|
||||
}
|
|
@ -0,0 +1,182 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/influxdb/kit/platform"
|
||||
"github.com/influxdata/influxdb/kit/platform/errors"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/influxdata/influxdb/kit/tracing"
|
||||
ua "github.com/mileusna/useragent"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// Middleware constructor.
|
||||
type Middleware func(http.Handler) http.Handler
|
||||
|
||||
func SetCORS(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
if origin := r.Header.Get("Origin"); origin != "" {
|
||||
// Access-Control-Allow-Origin must be present in every response
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
if r.Method == http.MethodOptions {
|
||||
// allow and stop processing in pre-flight requests
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, Authorization, User-Agent")
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
|
||||
func Metrics(name string, reqMetric *prometheus.CounterVec, durMetric *prometheus.HistogramVec) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
statusW := NewStatusResponseWriter(w)
|
||||
|
||||
defer func(start time.Time) {
|
||||
label := prometheus.Labels{
|
||||
"handler": name,
|
||||
"method": r.Method,
|
||||
"path": normalizePath(r.URL.Path),
|
||||
"status": statusW.StatusCodeClass(),
|
||||
"response_code": fmt.Sprintf("%d", statusW.Code()),
|
||||
"user_agent": UserAgent(r),
|
||||
}
|
||||
durMetric.With(label).Observe(time.Since(start).Seconds())
|
||||
reqMetric.With(label).Inc()
|
||||
}(time.Now())
|
||||
|
||||
next.ServeHTTP(statusW, r)
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
}
|
||||
|
||||
func SkipOptions(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
// Preflight CORS requests from the browser will send an options request,
|
||||
// so we need to make sure we satisfy them
|
||||
if origin := r.Header.Get("Origin"); origin == "" && r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
|
||||
func Trace(name string) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
span, r := tracing.ExtractFromHTTPRequest(r, name)
|
||||
defer span.Finish()
|
||||
|
||||
span.LogKV("user_agent", UserAgent(r))
|
||||
for k, v := range r.Header {
|
||||
if len(v) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if k == "Authorization" || k == "User-Agent" {
|
||||
continue
|
||||
}
|
||||
|
||||
// If header has multiple values, only the first value will be logged on the trace.
|
||||
span.LogKV(k, v[0])
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
}
|
||||
|
||||
func UserAgent(r *http.Request) string {
|
||||
header := r.Header.Get("User-Agent")
|
||||
if header == "" {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
return ua.Parse(header).Name
|
||||
}
|
||||
|
||||
func normalizePath(p string) string {
|
||||
var parts []string
|
||||
for head, tail := shiftPath(p); ; head, tail = shiftPath(tail) {
|
||||
piece := head
|
||||
if len(piece) == platform.IDLength {
|
||||
if _, err := platform.IDFromString(head); err == nil {
|
||||
piece = ":id"
|
||||
}
|
||||
}
|
||||
parts = append(parts, piece)
|
||||
if tail == "/" {
|
||||
break
|
||||
}
|
||||
}
|
||||
return "/" + path.Join(parts...)
|
||||
}
|
||||
|
||||
func shiftPath(p string) (head, tail string) {
|
||||
p = path.Clean("/" + p)
|
||||
i := strings.Index(p[1:], "/") + 1
|
||||
if i <= 0 {
|
||||
return p[1:], "/"
|
||||
}
|
||||
return p[1:i], p[i:]
|
||||
}
|
||||
|
||||
type OrgContext string
|
||||
|
||||
const CtxOrgKey OrgContext = "orgID"
|
||||
|
||||
// ValidResource make sure a resource exists when a sub system needs to be mounted to an api
|
||||
func ValidResource(api *API, lookupOrgByResourceID func(context.Context, platform.ID) (platform.ID, error)) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
statusW := NewStatusResponseWriter(w)
|
||||
id, err := platform.IDFromString(chi.URLParam(r, "id"))
|
||||
if err != nil {
|
||||
api.Err(w, r, platform.ErrCorruptID(err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
orgID, err := lookupOrgByResourceID(ctx, *id)
|
||||
if err != nil {
|
||||
// if this function returns an error we will squash the error message and replace it with a not found error
|
||||
api.Err(w, r, &errors.Error{
|
||||
Code: errors.ENotFound,
|
||||
Msg: "404 page not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// embed OrgID into context
|
||||
next.ServeHTTP(statusW, r.WithContext(context.WithValue(ctx, CtxOrgKey, orgID)))
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
}
|
||||
|
||||
// OrgIDFromContext ....
|
||||
func OrgIDFromContext(ctx context.Context) *platform.ID {
|
||||
v := ctx.Value(CtxOrgKey)
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
id := v.(platform.ID)
|
||||
return &id
|
||||
}
|
|
@ -0,0 +1,95 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/influxdata/influxdb/kit/platform"
|
||||
|
||||
"github.com/influxdata/influxdb/pkg/testttp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_normalizePath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "1",
|
||||
path: path.Join("/api/v2/organizations", platform.ID(2).String()),
|
||||
expected: "/api/v2/organizations/:id",
|
||||
},
|
||||
{
|
||||
name: "2",
|
||||
path: "/api/v2/organizations",
|
||||
expected: "/api/v2/organizations",
|
||||
},
|
||||
{
|
||||
name: "3",
|
||||
path: "/",
|
||||
expected: "/",
|
||||
},
|
||||
{
|
||||
name: "4",
|
||||
path: path.Join("/api/v2/organizations", platform.ID(2).String(), "users", platform.ID(3).String()),
|
||||
expected: "/api/v2/organizations/:id/users/:id",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actual := normalizePath(tt.path)
|
||||
assert.Equal(t, tt.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCors(t *testing.T) {
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("nextHandler"))
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
headers []string
|
||||
expectedStatus int
|
||||
expectedHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "OPTIONS without Origin",
|
||||
method: "OPTIONS",
|
||||
expectedStatus: http.StatusMethodNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "OPTIONS with Origin",
|
||||
method: "OPTIONS",
|
||||
headers: []string{"Origin", "http://myapp.com"},
|
||||
expectedStatus: http.StatusNoContent,
|
||||
},
|
||||
{
|
||||
name: "GET with Origin",
|
||||
method: "GET",
|
||||
headers: []string{"Origin", "http://anotherapp.com"},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"Access-Control-Allow-Origin": "http://anotherapp.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
svr := SkipOptions(SetCORS(nextHandler))
|
||||
|
||||
testttp.
|
||||
HTTP(t, tt.method, "/", nil).
|
||||
Headers("", "", tt.headers...).
|
||||
Do(svr).
|
||||
ExpectStatus(tt.expectedStatus).
|
||||
ExpectHeaders(tt.expectedHeaders)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
package http
|
||||
|
||||
import "net/http"
|
||||
|
||||
type StatusResponseWriter struct {
|
||||
statusCode int
|
||||
responseBytes int
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
func NewStatusResponseWriter(w http.ResponseWriter) *StatusResponseWriter {
|
||||
return &StatusResponseWriter{
|
||||
ResponseWriter: w,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *StatusResponseWriter) Write(b []byte) (int, error) {
|
||||
n, err := w.ResponseWriter.Write(b)
|
||||
w.responseBytes += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
// WriteHeader writes the header and captures the status code.
|
||||
func (w *StatusResponseWriter) WriteHeader(statusCode int) {
|
||||
w.statusCode = statusCode
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *StatusResponseWriter) Code() int {
|
||||
code := w.statusCode
|
||||
if code == 0 {
|
||||
// When statusCode is 0 then WriteHeader was never called and we can assume that
|
||||
// the ResponseWriter wrote an http.StatusOK.
|
||||
code = http.StatusOK
|
||||
}
|
||||
return code
|
||||
}
|
||||
|
||||
func (w *StatusResponseWriter) ResponseBytes() int {
|
||||
return w.responseBytes
|
||||
}
|
||||
|
||||
func (w *StatusResponseWriter) StatusCodeClass() string {
|
||||
class := "XXX"
|
||||
switch w.Code() / 100 {
|
||||
case 1:
|
||||
class = "1XX"
|
||||
case 2:
|
||||
class = "2XX"
|
||||
case 3:
|
||||
class = "3XX"
|
||||
case 4:
|
||||
class = "4XX"
|
||||
case 5:
|
||||
class = "5XX"
|
||||
}
|
||||
return class
|
||||
}
|
|
@ -0,0 +1,195 @@
|
|||
package testttp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Req is a request builder.
|
||||
type Req struct {
|
||||
t testing.TB
|
||||
|
||||
req *http.Request
|
||||
}
|
||||
|
||||
// HTTP runs creates a request for an http call.
|
||||
func HTTP(t testing.TB, method, addr string, body io.Reader) *Req {
|
||||
return &Req{
|
||||
t: t,
|
||||
req: httptest.NewRequest(method, addr, body),
|
||||
}
|
||||
}
|
||||
|
||||
// Delete creates a DELETE request.
|
||||
func Delete(t testing.TB, addr string) *Req {
|
||||
return HTTP(t, http.MethodDelete, addr, nil)
|
||||
}
|
||||
|
||||
// Get creates a GET request.
|
||||
func Get(t testing.TB, addr string) *Req {
|
||||
return HTTP(t, http.MethodGet, addr, nil)
|
||||
}
|
||||
|
||||
// Patch creates a PATCH request.
|
||||
func Patch(t testing.TB, addr string, body io.Reader) *Req {
|
||||
return HTTP(t, http.MethodPatch, addr, body)
|
||||
}
|
||||
|
||||
// PatchJSON creates a PATCH request with a json encoded body.
|
||||
func PatchJSON(t testing.TB, addr string, v interface{}) *Req {
|
||||
return HTTP(t, http.MethodPatch, addr, mustEncodeJSON(t, v))
|
||||
}
|
||||
|
||||
// Post creates a POST request.
|
||||
func Post(t testing.TB, addr string, body io.Reader) *Req {
|
||||
return HTTP(t, http.MethodPost, addr, body)
|
||||
}
|
||||
|
||||
// PostJSON returns a POST request with a json encoded body.
|
||||
func PostJSON(t testing.TB, addr string, v interface{}) *Req {
|
||||
return Post(t, addr, mustEncodeJSON(t, v))
|
||||
}
|
||||
|
||||
// Put creates a PUT request.
|
||||
func Put(t testing.TB, addr string, body io.Reader) *Req {
|
||||
return HTTP(t, http.MethodPut, addr, body)
|
||||
}
|
||||
|
||||
// PutJSON creates a PUT request with a json encoded body.
|
||||
func PutJSON(t testing.TB, addr string, v interface{}) *Req {
|
||||
return HTTP(t, http.MethodPut, addr, mustEncodeJSON(t, v))
|
||||
}
|
||||
|
||||
// Do runs the request against the provided handler.
|
||||
func (r *Req) Do(handler http.Handler) *Resp {
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, r.req)
|
||||
|
||||
return &Resp{
|
||||
t: r.t,
|
||||
debug: true,
|
||||
Req: r.req,
|
||||
Rec: rec,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Req) SetFormValue(k, v string) *Req {
|
||||
if r.req.Form == nil {
|
||||
r.req.Form = make(url.Values)
|
||||
}
|
||||
r.req.Form.Set(k, v)
|
||||
return r
|
||||
}
|
||||
|
||||
// Headers allows the user to set headers on the http request.
|
||||
func (r *Req) Headers(k, v string, rest ...string) *Req {
|
||||
headers := append(rest, k, v)
|
||||
for i := 0; i < len(headers); i += 2 {
|
||||
if i+1 >= len(headers) {
|
||||
break
|
||||
}
|
||||
k, v := headers[i], headers[i+1]
|
||||
r.req.Header.Add(k, v)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// WithCtx sets the ctx on the request.
|
||||
func (r *Req) WithCtx(ctx context.Context) *Req {
|
||||
r.req = r.req.WithContext(ctx)
|
||||
return r
|
||||
}
|
||||
|
||||
// WrapCtx provides means to wrap a request context. This is useful for stuffing in the
|
||||
// auth stuffs that are required at times.
|
||||
func (r *Req) WrapCtx(fn func(ctx context.Context) context.Context) *Req {
|
||||
return r.WithCtx(fn(r.req.Context()))
|
||||
}
|
||||
|
||||
// Resp is a http recorder wrapper.
|
||||
type Resp struct {
|
||||
t testing.TB
|
||||
|
||||
debug bool
|
||||
|
||||
Req *http.Request
|
||||
Rec *httptest.ResponseRecorder
|
||||
}
|
||||
|
||||
// Debug sets the debugger. If true, the debugger will print the body of the response
|
||||
// when the expected status is not received.
|
||||
func (r *Resp) Debug(b bool) *Resp {
|
||||
r.debug = b
|
||||
return r
|
||||
}
|
||||
|
||||
// Expect allows the assertions against the raw Resp.
|
||||
func (r *Resp) Expect(fn func(*Resp)) *Resp {
|
||||
fn(r)
|
||||
return r
|
||||
}
|
||||
|
||||
// ExpectStatus compares the expected status code against the recorded status code.
|
||||
func (r *Resp) ExpectStatus(code int) *Resp {
|
||||
r.t.Helper()
|
||||
|
||||
if r.Rec.Code != code {
|
||||
r.t.Errorf("unexpected status code: expected=%d got=%d", code, r.Rec.Code)
|
||||
if r.debug {
|
||||
r.t.Logf("body: %v", r.Rec.Body.String())
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// ExpectBody provides an assertion against the recorder body.
|
||||
func (r *Resp) ExpectBody(fn func(body *bytes.Buffer)) *Resp {
|
||||
fn(r.Rec.Body)
|
||||
return r
|
||||
}
|
||||
|
||||
// ExpectHeaders asserts that multiple headers with values exist in the recorder.
|
||||
func (r *Resp) ExpectHeaders(h map[string]string) *Resp {
|
||||
for k, v := range h {
|
||||
r.ExpectHeader(k, v)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// ExpectHeader asserts that the header is in the recorder.
|
||||
func (r *Resp) ExpectHeader(k, v string) *Resp {
|
||||
r.t.Helper()
|
||||
|
||||
vals, ok := r.Rec.Header()[k]
|
||||
if !ok {
|
||||
r.t.Errorf("did not find expected header: %q", k)
|
||||
return r
|
||||
}
|
||||
|
||||
for _, vv := range vals {
|
||||
if vv == v {
|
||||
return r
|
||||
}
|
||||
}
|
||||
r.t.Errorf("did not find expected value for header %q; got: %v", k, vals)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func mustEncodeJSON(t testing.TB, v interface{}) *bytes.Buffer {
|
||||
t.Helper()
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &buf
|
||||
}
|
|
@ -0,0 +1,174 @@
|
|||
package testttp_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/influxdata/influxdb/pkg/testttp"
|
||||
)
|
||||
|
||||
func TestHTTP(t *testing.T) {
|
||||
svr := newMux()
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
testttp.
|
||||
Delete(t, "/").
|
||||
Do(svr).
|
||||
ExpectStatus(http.StatusNoContent)
|
||||
})
|
||||
|
||||
t.Run("Get", func(t *testing.T) {
|
||||
testttp.
|
||||
Get(t, "/").
|
||||
Do(svr).
|
||||
ExpectStatus(http.StatusOK).
|
||||
ExpectBody(assertBody(t, http.MethodGet))
|
||||
})
|
||||
|
||||
t.Run("Patch", func(t *testing.T) {
|
||||
testttp.
|
||||
Patch(t, "/", nil).
|
||||
Do(svr).
|
||||
ExpectStatus(http.StatusPartialContent).
|
||||
ExpectBody(assertBody(t, http.MethodPatch))
|
||||
})
|
||||
|
||||
t.Run("PatchJSON", func(t *testing.T) {
|
||||
testttp.
|
||||
PatchJSON(t, "/", map[string]string{"k": "t"}).
|
||||
Do(svr).
|
||||
ExpectStatus(http.StatusPartialContent).
|
||||
ExpectBody(assertBody(t, http.MethodPatch))
|
||||
})
|
||||
|
||||
t.Run("Post", func(t *testing.T) {
|
||||
t.Run("basic", func(t *testing.T) {
|
||||
testttp.
|
||||
Post(t, "/", nil).
|
||||
Do(svr).
|
||||
ExpectStatus(http.StatusCreated).
|
||||
ExpectBody(assertBody(t, http.MethodPost))
|
||||
})
|
||||
|
||||
t.Run("with form values", func(t *testing.T) {
|
||||
svr := http.NewServeMux()
|
||||
svr.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(r.FormValue("key")))
|
||||
}))
|
||||
|
||||
testttp.
|
||||
Post(t, "/", nil).
|
||||
SetFormValue("key", "val").
|
||||
Do(svr).
|
||||
ExpectStatus(http.StatusOK).
|
||||
ExpectBody(func(body *bytes.Buffer) {
|
||||
if expected, got := "val", body.String(); expected != got {
|
||||
t.Fatalf("did not get form value; expected=%q got=%q", expected, got)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("PostJSON", func(t *testing.T) {
|
||||
testttp.
|
||||
PostJSON(t, "/", map[string]string{"k": "v"}).
|
||||
Do(svr).
|
||||
ExpectStatus(http.StatusCreated).
|
||||
ExpectBody(assertBody(t, http.MethodPost))
|
||||
})
|
||||
|
||||
t.Run("Put", func(t *testing.T) {
|
||||
testttp.
|
||||
Put(t, "/", nil).
|
||||
Do(svr).
|
||||
ExpectStatus(http.StatusAccepted).
|
||||
ExpectBody(assertBody(t, http.MethodPut))
|
||||
})
|
||||
|
||||
t.Run("PutJSON", func(t *testing.T) {
|
||||
testttp.
|
||||
PutJSON(t, "/", map[string]string{"k": "t"}).
|
||||
Do(svr).
|
||||
ExpectStatus(http.StatusAccepted).
|
||||
ExpectBody(assertBody(t, http.MethodPut))
|
||||
})
|
||||
|
||||
t.Run("Headers", func(t *testing.T) {
|
||||
testttp.
|
||||
Post(t, "/", strings.NewReader(`a: foo`)).
|
||||
Headers("Content-Type", "text/yml").
|
||||
Do(svr).
|
||||
Expect(func(resp *testttp.Resp) {
|
||||
equals(t, "text/yml", resp.Req.Header.Get("Content-Type"))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
type foo struct {
|
||||
Name, Thing, Method string
|
||||
}
|
||||
|
||||
func newMux() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||
switch req.Method {
|
||||
case http.MethodGet:
|
||||
writeFn(w, req.Method, http.StatusOK)
|
||||
case http.MethodPost:
|
||||
writeFn(w, req.Method, http.StatusCreated)
|
||||
case http.MethodPut:
|
||||
writeFn(w, req.Method, http.StatusAccepted)
|
||||
case http.MethodPatch:
|
||||
writeFn(w, req.Method, http.StatusPartialContent)
|
||||
case http.MethodDelete:
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
})
|
||||
return mux
|
||||
}
|
||||
|
||||
func assertBody(t *testing.T, method string) func(*bytes.Buffer) {
|
||||
return func(buf *bytes.Buffer) {
|
||||
var f foo
|
||||
if err := json.NewDecoder(buf).Decode(&f); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expected := foo{Name: "name", Thing: "thing", Method: method}
|
||||
equals(t, expected, f)
|
||||
}
|
||||
}
|
||||
|
||||
func writeFn(w http.ResponseWriter, method string, statusCode int) {
|
||||
f := foo{Name: "name", Thing: "thing", Method: method}
|
||||
r, err := encodeBuf(f)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(statusCode)
|
||||
if _, err := io.Copy(w, r); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func equals(t *testing.T, expected, actual interface{}) {
|
||||
t.Helper()
|
||||
if expected == actual {
|
||||
return
|
||||
}
|
||||
t.Errorf("expected: %v\tactual: %v", expected, actual)
|
||||
}
|
||||
|
||||
func encodeBuf(v interface{}) (io.Reader, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &buf, nil
|
||||
}
|
Loading…
Reference in New Issue