chore: Add kit (#21086)

* chore: pull in unchanged kit from v2

* chore: remove v2 from import paths

* chore: update module paths and go.mod for kit

* chore: remove kit/cli again, not needed in 1.x
pull/21103/head
Sam Arnold 2021-03-30 14:09:04 -03:00 committed by GitHub
parent fbfd4b4651
commit 78724e5c20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
51 changed files with 7073 additions and 2 deletions

6
go.mod
View File

@ -12,10 +12,12 @@ require (
github.com/davecgh/go-spew v1.1.1
github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1
github.com/dgryski/go-bitstream v0.0.0-20180413035011-3522498ce2c8
github.com/go-chi/chi v4.1.0+incompatible
github.com/gogo/protobuf v1.3.1
github.com/golang/snappy v0.0.1
github.com/google/go-cmp v0.5.0
github.com/influxdata/flux v0.111.0
github.com/influxdata/httprouter v1.3.1-0.20191122104820-ee83e2772f69
github.com/influxdata/influxql v1.1.1-0.20210223160523-b6ab99450c93
github.com/influxdata/pkg-config v0.2.7
github.com/influxdata/roaring v0.4.13-0.20180809181101-fc520f41fab6
@ -24,16 +26,20 @@ require (
github.com/jwilder/encoding v0.0.0-20170811194829-b4e1701a28ef
github.com/klauspost/pgzip v1.0.2-0.20170402124221-0bf5dcad4ada
github.com/mattn/go-isatty v0.0.12
github.com/mileusna/useragent v0.0.0-20190129205925-3e331f0949a5
github.com/opentracing/opentracing-go v1.1.0
github.com/peterh/liner v1.0.1-0.20180619022028-8c1271fcf47f
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.5.1
github.com/prometheus/client_model v0.2.0
github.com/prometheus/common v0.9.1
github.com/prometheus/prometheus v0.0.0-20200609090129-a6600f564e3c
github.com/retailnext/hllpp v1.0.1-0.20180308014038-101a6d2f8b52
github.com/spf13/cast v1.3.0
github.com/spf13/cobra v0.0.3
github.com/stretchr/testify v1.5.1
github.com/tinylib/msgp v1.1.0
github.com/uber/jaeger-client-go v2.23.0+incompatible
github.com/willf/bitset v1.1.9 // indirect
github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6
go.uber.org/zap v1.14.1

16
go.sum
View File

@ -152,6 +152,7 @@ github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec/go.mod h1:jMjuTZXRI4
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8=
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd h1:qMd81Ts1T2OTKmB4acZcyKaMtRnY5Y44NuXGX2GFJ1w=
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI=
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
@ -199,6 +200,7 @@ github.com/foxcpp/go-mockdns v0.0.0-20201212160233-ede2f9158d15 h1:nLPjjvpUAODOR
github.com/foxcpp/go-mockdns v0.0.0-20201212160233-ede2f9158d15/go.mod h1:tPg4cp4nseejPd+UKxtCVQ2hUxNTZ7qQZJa7CLriIeo=
github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4=
github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20=
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
@ -208,6 +210,8 @@ github.com/glycerine/go-unsnap-stream v0.0.0-20180323001048-9f0cb55181dd h1:r04M
github.com/glycerine/go-unsnap-stream v0.0.0-20180323001048-9f0cb55181dd/go.mod h1:/20jfyN9Y5QPEAprSgKAUr+glWDY39ZiUEAYOEv5dsE=
github.com/glycerine/goconvey v0.0.0-20190410193231-58a59202ab31 h1:gclg6gY70GLy3PbkQ1AERPfmLMMagS60DKF78eWwLn8=
github.com/glycerine/goconvey v0.0.0-20190410193231-58a59202ab31/go.mod h1:Ogl1Tioa0aV7gstGFO7KhffUsb9M4ydbEbbxpcEDc24=
github.com/go-chi/chi v4.1.0+incompatible h1:ETj3cggsVIY2Xao5ExCu6YhEh5MD6JTfcBzS37R260w=
github.com/go-chi/chi v4.1.0+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
@ -434,6 +438,8 @@ github.com/influxdata/flux v0.65.0 h1:57tk1Oo4gpGIDbV12vUAPCMtLtThhaXzub1XRIuqv6
github.com/influxdata/flux v0.65.0/go.mod h1:BwN2XG2lMszOoquQaFdPET8FRQfrXiZsWmcMO9rkaVY=
github.com/influxdata/flux v0.111.0 h1:27CNz0SbEofD9NzdwcdxRwGmuVSDSisVq4dOceB/KF0=
github.com/influxdata/flux v0.111.0/go.mod h1:3TJtvbm/Kwuo5/PEo5P6HUzwVg4bXWkb2wPQHPtQdlU=
github.com/influxdata/httprouter v1.3.1-0.20191122104820-ee83e2772f69 h1:WQsmW0fXO4ZE/lFGIE84G6rIV5SJN3P3sjIXAP1a8eU=
github.com/influxdata/httprouter v1.3.1-0.20191122104820-ee83e2772f69/go.mod h1:pwymjR6SrP3gD3pRj9RJwdl1j5s3doEEV8gS4X9qSzA=
github.com/influxdata/influxdb v1.8.0/go.mod h1:SIzcnsjaHRFpmlxpJ4S3NT64qtEKYweNTUMb/vh0OMQ=
github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo=
github.com/influxdata/influxql v1.1.0/go.mod h1:KpVI7okXjK6PRi3Z5B+mtKZli+R1DnZgb3N+tzevNgo=
@ -447,7 +453,6 @@ github.com/influxdata/pkg-config v0.2.7/go.mod h1:EMS7Ll0S4qkzDk53XS3Z72/egBsPIn
github.com/influxdata/promql/v2 v2.12.0/go.mod h1:fxOPu+DY0bqCTCECchSRtWfc+0X19ybifQhZoQNF5D8=
github.com/influxdata/roaring v0.4.13-0.20180809181101-fc520f41fab6 h1:UzJnB7VRL4PSkUJHwsyzseGOmrO/r4yA+AuxGJxiZmA=
github.com/influxdata/roaring v0.4.13-0.20180809181101-fc520f41fab6/go.mod h1:bSgUQ7q5ZLSO+bKBGqJiCBGAl+9DxyW63zLTujjUlOE=
github.com/influxdata/tdigest v0.0.0-20181121200506-bf2b5ad3c0a9 h1:MHTrDWmQpHq/hkq+7cw9oYAt2PqUw52TZazRA0N7PGE=
github.com/influxdata/tdigest v0.0.0-20181121200506-bf2b5ad3c0a9/go.mod h1:Js0mqiSBE6Ffsg94weZZ2c+v/ciT8QRHFOap7EKDrR0=
github.com/influxdata/tdigest v0.0.2-0.20210216194612-fc98d27c9e8b h1:i44CesU68ZBRvtCjBi3QSosCIKrjmMbYlQMFAwVLds4=
github.com/influxdata/tdigest v0.0.2-0.20210216194612-fc98d27c9e8b/go.mod h1:Z0kXnxzbTC2qrx4NaIzYkE1k66+6oEDQTvL95hQFh5Y=
@ -540,6 +545,8 @@ github.com/miekg/dns v1.1.22/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKju
github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso=
github.com/miekg/dns v1.1.29 h1:xHBEhR+t5RzcFJjBLJlax2daXOrTYtr9z4WdKEfWFzg=
github.com/miekg/dns v1.1.29/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
github.com/mileusna/useragent v0.0.0-20190129205925-3e331f0949a5 h1:pXqZHmHOz6LN+zbbUgqyGgAWRnnZEI40IzG3tMsXcSI=
github.com/mileusna/useragent v0.0.0-20190129205925-3e331f0949a5/go.mod h1:JWhYAp2EXqUtsxTKdeGlY8Wp44M7VxThC9FEoNGi2IE=
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
@ -550,6 +557,7 @@ github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS4
github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY=
github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/mitchellh/mapstructure v1.2.2 h1:dxe5oCinTXiTIcfgmZecdCzPmAJKd46KsCWc35r0TV4=
github.com/mitchellh/mapstructure v1.2.2/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -596,9 +604,9 @@ github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnh
github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM=
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
github.com/paulbellamy/ratecounter v0.2.0 h1:2L/RhJq+HA8gBQImDXtLPrDXK5qAj6ozWVK/zFXVJGs=
github.com/paulbellamy/ratecounter v0.2.0/go.mod h1:Hfx1hDpSGoqxkVVpBi/IlYD7kChlfo5C6hzIHwPqfFE=
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
github.com/pelletier/go-toml v1.4.0 h1:u3Z1r+oOXJIkxqw34zVhyPgjBsm6X2wn21NWs/HfSeg=
github.com/pelletier/go-toml v1.4.0/go.mod h1:PN7xzY2wHTK0K9p34ErDQMlFxa51Fk0OUruD3k1mMwo=
github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac=
github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU=
@ -690,6 +698,7 @@ github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4k
github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spf13/afero v1.2.2 h1:5jhuqJyZCZf2JRofRvN/nIFgIWNzPa3/Vz8mYylgbWc=
github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk=
github.com/spf13/cast v1.3.0 h1:oget//CVOEoFewqQxwr0Ej5yjygnqGkvggSE/gB35Q8=
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
@ -722,7 +731,9 @@ github.com/uber-go/tally v3.3.15+incompatible h1:9hLSgNBP28CjIaDmAuRTq9qV+UZY+9P
github.com/uber-go/tally v3.3.15+incompatible/go.mod h1:YDTIBxdXyOU/sCWilKB4bgyufu1cEi0jdVnRdxvjnmU=
github.com/uber/athenadriver v1.1.4 h1:k6k0RBeXjR7oZ8NO557MsRw3eX1cc/9B0GNx+W9eHiQ=
github.com/uber/athenadriver v1.1.4/go.mod h1:tQjho4NzXw55LGfSZEcETuYydpY1vtmixUabHkC1K/E=
github.com/uber/jaeger-client-go v2.23.0+incompatible h1:o2g11IUBdEsSZVzF3k7+bahLmxRP/dbOoW4zQ30UlKE=
github.com/uber/jaeger-client-go v2.23.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk=
github.com/uber/jaeger-lib v2.2.0+incompatible h1:MxZXOiR2JuoANZ3J6DE/U0kSFv/eJ/GfSYVCjK7dyaw=
github.com/uber/jaeger-lib v2.2.0+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U=
github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
@ -1101,6 +1112,7 @@ gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

238
kit/check/check.go Normal file
View File

@ -0,0 +1,238 @@
// Package check standardizes /health and /ready endpoints.
// This allows you to easily know when your server is ready and healthy.
package check
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sort"
"sync"
)
// Status string to indicate the overall status of the check.
type Status string
const (
// StatusFail indicates a specific check has failed.
StatusFail Status = "fail"
// StatusPass indicates a specific check has passed.
StatusPass Status = "pass"
// DefaultCheckName is the name of the default checker.
DefaultCheckName = "internal"
)
// Check wraps a map of service names to status checkers.
type Check struct {
healthChecks []Checker
readyChecks []Checker
healthOverride override
readyOverride override
passthroughHandler http.Handler
}
// Checker indicates a service whose health can be checked.
type Checker interface {
Check(ctx context.Context) Response
}
// NewCheck returns a Health with a default checker.
func NewCheck() *Check {
ch := &Check{}
ch.healthOverride.disable()
ch.readyOverride.disable()
return ch
}
// AddHealthCheck adds the check to the list of ready checks.
// If c is a NamedChecker, the name will be added.
func (c *Check) AddHealthCheck(check Checker) {
if nc, ok := check.(NamedChecker); ok {
c.healthChecks = append(c.healthChecks, Named(nc.CheckName(), nc))
} else {
c.healthChecks = append(c.healthChecks, check)
}
}
// AddReadyCheck adds the check to the list of ready checks.
// If c is a NamedChecker, the name will be added.
func (c *Check) AddReadyCheck(check Checker) {
if nc, ok := check.(NamedChecker); ok {
c.readyChecks = append(c.readyChecks, Named(nc.CheckName(), nc))
} else {
c.readyChecks = append(c.readyChecks, check)
}
}
// CheckHealth evaluates c's set of health checks and returns a populated Response.
func (c *Check) CheckHealth(ctx context.Context) Response {
response := Response{
Name: "Health",
Status: StatusPass,
Checks: make(Responses, len(c.healthChecks)),
}
status, overriding := c.healthOverride.get()
if overriding {
response.Status = status
overrideResponse := Response{
Name: "manual-override",
Message: "health manually overridden",
}
response.Checks = append(response.Checks, overrideResponse)
}
for i, ch := range c.healthChecks {
resp := ch.Check(ctx)
if resp.Status != StatusPass && !overriding {
response.Status = resp.Status
}
response.Checks[i] = resp
}
sort.Sort(response.Checks)
return response
}
// CheckReady evaluates c's set of ready checks and returns a populated Response.
func (c *Check) CheckReady(ctx context.Context) Response {
response := Response{
Name: "Ready",
Status: StatusPass,
Checks: make(Responses, len(c.readyChecks)),
}
status, overriding := c.readyOverride.get()
if overriding {
response.Status = status
overrideResponse := Response{
Name: "manual-override",
Message: "ready manually overridden",
}
response.Checks = append(response.Checks, overrideResponse)
}
for i, c := range c.readyChecks {
resp := c.Check(ctx)
if resp.Status != StatusPass && !overriding {
response.Status = resp.Status
}
response.Checks[i] = resp
}
sort.Sort(response.Checks)
return response
}
// SetPassthrough allows you to set a handler to use if the request is not a ready or health check.
// This can be useful if you intend to use this as a middleware.
func (c *Check) SetPassthrough(h http.Handler) {
c.passthroughHandler = h
}
// ServeHTTP serves /ready and /health requests with the respective checks.
func (c *Check) ServeHTTP(w http.ResponseWriter, r *http.Request) {
const (
pathReady = "/ready"
pathHealth = "/health"
queryForce = "force"
)
path := r.URL.Path
// Allow requests not intended for checks to pass through.
if path != pathReady && path != pathHealth {
if c.passthroughHandler != nil {
c.passthroughHandler.ServeHTTP(w, r)
return
}
// We can't handle this request.
w.WriteHeader(http.StatusNotFound)
return
}
ctx := r.Context()
query := r.URL.Query()
switch path {
case pathReady:
switch query.Get(queryForce) {
case "true":
switch query.Get("ready") {
case "true":
c.readyOverride.enable(StatusPass)
case "false":
c.readyOverride.enable(StatusFail)
}
case "false":
c.readyOverride.disable()
}
writeResponse(w, c.CheckReady(ctx))
case pathHealth:
switch query.Get(queryForce) {
case "true":
switch query.Get("healthy") {
case "true":
c.healthOverride.enable(StatusPass)
case "false":
c.healthOverride.enable(StatusFail)
}
case "false":
c.healthOverride.disable()
}
writeResponse(w, c.CheckHealth(ctx))
}
}
// writeResponse writes a Response to the wire as JSON. The HTTP status code
// accompanying the payload is the primary means for signaling the status of the
// checks. The possible status codes are:
//
// - 200 OK: All checks pass.
// - 503 Service Unavailable: Some checks are failing.
// - 500 Internal Server Error: There was a problem serializing the Response.
func writeResponse(w http.ResponseWriter, resp Response) {
status := http.StatusOK
if resp.Status == StatusFail {
status = http.StatusServiceUnavailable
}
msg, err := json.MarshalIndent(resp, "", " ")
if err != nil {
msg = []byte(`{"message": "error marshaling response", "status": "fail"}`)
status = http.StatusInternalServerError
}
w.WriteHeader(status)
fmt.Fprintln(w, string(msg))
}
// override is a manual override for an entire group of checks.
type override struct {
mtx sync.Mutex
status Status
active bool
}
// get returns the Status of an override as well as whether or not an override
// is currently active.
func (m *override) get() (Status, bool) {
m.mtx.Lock()
defer m.mtx.Unlock()
return m.status, m.active
}
// disable disables the override.
func (m *override) disable() {
m.mtx.Lock()
m.active = false
m.status = StatusFail
m.mtx.Unlock()
}
// enable turns on the override and establishes a specific Status for which to.
func (m *override) enable(s Status) {
m.mtx.Lock()
m.active = true
m.status = s
m.mtx.Unlock()
}

535
kit/check/check_test.go Normal file
View File

@ -0,0 +1,535 @@
package check
import (
"context"
"encoding/json"
"errors"
"io"
"net"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
)
func TestEmptyCheck(t *testing.T) {
c := NewCheck()
resp := c.CheckReady(context.Background())
if len(resp.Checks) > 0 {
t.Errorf("no checks added but %d returned", len(resp.Checks))
}
if resp.Name != "Ready" {
t.Errorf("expected: \"Ready\", got: %q", resp.Name)
}
if resp.Status != StatusPass {
t.Errorf("expected: %q, got: %q", StatusPass, resp.Status)
}
}
func TestAddHealthCheck(t *testing.T) {
h := NewCheck()
h.AddHealthCheck(Named("awesome", ErrCheck(func() error {
return nil
})))
r := h.CheckHealth(context.Background())
if r.Status != StatusPass {
t.Error("Health should fail because one of the check is unhealthy")
}
if len(r.Checks) != 1 {
t.Fatalf("check not in results: %+v", r.Checks)
}
v := r.Checks[0]
if v.Status != StatusPass {
t.Errorf("the added check should be pass not %q.", v.Status)
}
}
func TestAddUnHealthyCheck(t *testing.T) {
h := NewCheck()
h.AddHealthCheck(Named("failure", ErrCheck(func() error {
return errors.New("Oops! I am sorry")
})))
r := h.CheckHealth(context.Background())
if r.Status != StatusFail {
t.Error("Health should fail because one of the check is unhealthy")
}
if len(r.Checks) != 1 {
t.Fatal("check not in results")
}
v := r.Checks[0]
if v.Status != StatusFail {
t.Errorf("the added check should be fail not %s.", v.Status)
}
if v.Message != "Oops! I am sorry" {
t.Errorf(
"the error should be 'Oops! I am sorry' not %s.",
v.Message,
)
}
}
func buildCheckWithServer() (*Check, *httptest.Server) {
c := NewCheck()
return c, httptest.NewServer(c)
}
type mockCheck struct {
status Status
name string
}
func (m mockCheck) Check(_ context.Context) Response {
return Response{
Name: m.name,
Status: m.status,
}
}
func mockPass(name string) Checker {
return mockCheck{status: StatusPass, name: name}
}
func mockFail(name string) Checker {
return mockCheck{status: StatusFail, name: name}
}
func respBuilder(body io.ReadCloser) (*Response, error) {
defer body.Close()
d := json.NewDecoder(body)
r := &Response{}
return r, d.Decode(r)
}
func TestBasicHTTPHandler(t *testing.T) {
_, ts := buildCheckWithServer()
defer ts.Close()
resp, err := http.Get(ts.URL + "/ready")
if err != nil {
t.Fatal(err)
}
actual, err := respBuilder(resp.Body)
if err != nil {
t.Fatal(err)
}
expected := &Response{
Name: "Ready",
Status: StatusPass,
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
}
}
func TestHealthSorting(t *testing.T) {
c, ts := buildCheckWithServer()
defer ts.Close()
c.AddHealthCheck(mockPass("a"))
c.AddHealthCheck(mockPass("c"))
c.AddHealthCheck(mockPass("b"))
c.AddHealthCheck(mockFail("k"))
c.AddHealthCheck(mockFail("b"))
resp, err := http.Get(ts.URL + "/health")
if err != nil {
t.Fatal(err)
}
actual, err := respBuilder(resp.Body)
if err != nil {
t.Fatal(err)
}
expected := &Response{
Name: "Health",
Status: "fail",
Checks: Responses{
Response{Name: "b", Status: "fail"},
Response{Name: "k", Status: "fail"},
Response{Name: "a", Status: "pass"},
Response{Name: "b", Status: "pass"},
Response{Name: "c", Status: "pass"},
},
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
}
}
func TestForceHealthy(t *testing.T) {
c, ts := buildCheckWithServer()
defer ts.Close()
c.AddHealthCheck(mockFail("a"))
_, err := http.Get(ts.URL + "/health?force=true&healthy=true")
if err != nil {
t.Fatal(err)
}
resp, err := http.Get(ts.URL + "/health")
if err != nil {
t.Fatal(err)
}
actual, err := respBuilder(resp.Body)
if err != nil {
t.Fatal(err)
}
expected := &Response{
Name: "Health",
Status: "pass",
Checks: Responses{
Response{Name: "manual-override", Message: "health manually overridden"},
Response{Name: "a", Status: "fail"},
},
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
}
_, err = http.Get(ts.URL + "/health?force=false")
if err != nil {
t.Fatal(err)
}
expected = &Response{
Name: "Health",
Status: "fail",
Checks: Responses{
Response{Name: "a", Status: "fail"},
},
}
resp, err = http.Get(ts.URL + "/health")
if err != nil {
t.Fatal(err)
}
actual, err = respBuilder(resp.Body)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
}
}
func TestForceUnhealthy(t *testing.T) {
c, ts := buildCheckWithServer()
defer ts.Close()
c.AddHealthCheck(mockPass("a"))
_, err := http.Get(ts.URL + "/health?force=true&healthy=false")
if err != nil {
t.Fatal(err)
}
resp, err := http.Get(ts.URL + "/health")
if err != nil {
t.Fatal(err)
}
actual, err := respBuilder(resp.Body)
if err != nil {
t.Fatal(err)
}
expected := &Response{
Name: "Health",
Status: "fail",
Checks: Responses{
Response{Name: "manual-override", Message: "health manually overridden"},
Response{Name: "a", Status: "pass"},
},
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
}
_, err = http.Get(ts.URL + "/health?force=false")
if err != nil {
t.Fatal(err)
}
expected = &Response{
Name: "Health",
Status: "pass",
Checks: Responses{
Response{Name: "a", Status: "pass"},
},
}
resp, err = http.Get(ts.URL + "/health")
if err != nil {
t.Fatal(err)
}
actual, err = respBuilder(resp.Body)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
}
}
func TestForceReady(t *testing.T) {
c, ts := buildCheckWithServer()
defer ts.Close()
c.AddReadyCheck(mockFail("a"))
_, err := http.Get(ts.URL + "/ready?force=true&ready=true")
if err != nil {
t.Fatal(err)
}
resp, err := http.Get(ts.URL + "/ready")
if err != nil {
t.Fatal(err)
}
actual, err := respBuilder(resp.Body)
if err != nil {
t.Fatal(err)
}
expected := &Response{
Name: "Ready",
Status: "pass",
Checks: Responses{
Response{Name: "manual-override", Message: "ready manually overridden"},
Response{Name: "a", Status: "fail"},
},
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
}
_, err = http.Get(ts.URL + "/ready?force=false")
if err != nil {
t.Fatal(err)
}
expected = &Response{
Name: "Ready",
Status: "fail",
Checks: Responses{
Response{Name: "a", Status: "fail"},
},
}
resp, err = http.Get(ts.URL + "/ready")
if err != nil {
t.Fatal(err)
}
actual, err = respBuilder(resp.Body)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
}
}
func TestForceNotReady(t *testing.T) {
c, ts := buildCheckWithServer()
defer ts.Close()
c.AddReadyCheck(mockPass("a"))
_, err := http.Get(ts.URL + "/ready?force=true&ready=false")
if err != nil {
t.Fatal(err)
}
resp, err := http.Get(ts.URL + "/ready")
if err != nil {
t.Fatal(err)
}
actual, err := respBuilder(resp.Body)
if err != nil {
t.Fatal(err)
}
expected := &Response{
Name: "Ready",
Status: "fail",
Checks: Responses{
Response{Name: "manual-override", Message: "ready manually overridden"},
Response{Name: "a", Status: "pass"},
},
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
}
_, err = http.Get(ts.URL + "/ready?force=false")
if err != nil {
t.Fatal(err)
}
expected = &Response{
Name: "Ready",
Status: "pass",
Checks: Responses{
Response{Name: "a", Status: "pass"},
},
}
resp, err = http.Get(ts.URL + "/ready")
if err != nil {
t.Fatal(err)
}
actual, err = respBuilder(resp.Body)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
}
}
func TestNoCrossOver(t *testing.T) {
c, ts := buildCheckWithServer()
defer ts.Close()
c.AddHealthCheck(mockPass("a"))
c.AddHealthCheck(mockPass("c"))
c.AddReadyCheck(mockPass("b"))
c.AddReadyCheck(mockFail("k"))
c.AddHealthCheck(mockFail("b"))
resp, err := http.Get(ts.URL + "/health")
if err != nil {
t.Fatal(err)
}
actual, err := respBuilder(resp.Body)
if err != nil {
t.Fatal(err)
}
expected := &Response{
Name: "Health",
Status: "fail",
Checks: Responses{
Response{Name: "b", Status: "fail"},
Response{Name: "a", Status: "pass"},
Response{Name: "c", Status: "pass"},
},
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
}
resp, err = http.Get(ts.URL + "/ready")
if err != nil {
t.Fatal(err)
}
actual, err = respBuilder(resp.Body)
if err != nil {
t.Fatal(err)
}
expected = &Response{
Name: "Ready",
Status: "fail",
Checks: Responses{
Response{Name: "k", Status: "fail"},
Response{Name: "b", Status: "pass"},
},
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("unexpected response. expected %v, actual %v", expected, actual)
}
}
func TestPassthrough(t *testing.T) {
c, ts := buildCheckWithServer()
defer ts.Close()
resp, err := http.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != 404 {
t.Fatalf("failed to error when no passthrough is present, status: %d", resp.StatusCode)
}
used := false
s := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
used = true
w.Write([]byte("hi"))
})
c.SetPassthrough(s)
resp, err = http.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != 200 {
t.Fatalf("bad response code from passthrough, status: %d", resp.StatusCode)
}
if !used {
t.Fatal("passthrough server not used")
}
}
func ExampleNewCheck() {
// Run the default healthcheck. it always return 200. It is good if you
// have a service without any dependency
h := NewCheck()
h.CheckHealth(context.Background())
}
func ExampleCheck_CheckHealth() {
h := NewCheck()
h.AddHealthCheck(Named("google", CheckerFunc(func(ctx context.Context) Response {
var r net.Resolver
_, err := r.LookupHost(ctx, "google.com")
if err != nil {
return Error(err)
}
return Pass()
})))
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
h.CheckHealth(ctx)
}
func ExampleCheck_ServeHTTP() {
c := NewCheck()
http.ListenAndServe(":6060", c)
}
func ExampleCheck_SetPassthrough() {
c := NewCheck()
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello friends!"))
})
c.SetPassthrough(http.DefaultServeMux)
http.ListenAndServe(":6060", c)
}

73
kit/check/helpers.go Normal file
View File

@ -0,0 +1,73 @@
package check
import (
"context"
"fmt"
)
// NamedChecker is a superset of Checker that also indicates the name of the service.
// Prefer to implement NamedChecker if your service has a fixed name,
// as opposed to calling *Health.AddNamed.
type NamedChecker interface {
Checker
CheckName() string
}
// CheckerFunc is an adapter of a plain func() error to the Checker interface.
type CheckerFunc func(ctx context.Context) Response
// Check implements Checker.
func (f CheckerFunc) Check(ctx context.Context) Response {
return f(ctx)
}
// Named returns a Checker that will attach a name to the Response from the check.
// This way, it is possible to augment a Response with a human-readable name, but not have to encode
// that logic in the actual check itself.
func Named(name string, checker Checker) Checker {
return CheckerFunc(func(ctx context.Context) Response {
resp := checker.Check(ctx)
resp.Name = name
return resp
})
}
// NamedFunc is the same as Named except it takes a CheckerFunc.
func NamedFunc(name string, fn CheckerFunc) Checker {
return Named(name, fn)
}
// ErrCheck will create a health checker that executes a function. If the function returns an error,
// it will return an unhealthy response. Otherwise, it will be as if the Ok function was called.
// Note: it is better to use CheckFunc, because with Check, the context is ignored.
func ErrCheck(fn func() error) Checker {
return CheckerFunc(func(_ context.Context) Response {
if err := fn(); err != nil {
return Error(err)
}
return Pass()
})
}
// Pass is a utility function to generate a passing status response with the default parameters.
func Pass() Response {
return Response{
Status: StatusPass,
}
}
// Info is a utility function to generate a healthy status with a printf message.
func Info(msg string, args ...interface{}) Response {
return Response{
Status: StatusPass,
Message: fmt.Sprintf(msg, args...),
}
}
// Error is a utility function for creating a response from an error message.
func Error(err error) Response {
return Response{
Status: StatusFail,
Message: err.Error(),
}
}

39
kit/check/response.go Normal file
View File

@ -0,0 +1,39 @@
package check
// Response is a result of a collection of health checks.
type Response struct {
Name string `json:"name"`
Status Status `json:"status"`
Message string `json:"message,omitempty"`
Checks Responses `json:"checks,omitempty"`
}
// HasCheck verifies whether the receiving Response has a check with the given name or not.
func (r *Response) HasCheck(name string) bool {
found := false
for _, check := range r.Checks {
if check.Name == name {
found = true
break
}
}
return found
}
// Responses is a sortable collection of Response objects.
type Responses []Response
func (r Responses) Len() int { return len(r) }
// Less defines the order in which responses are sorted.
//
// Failing responses are always sorted before passing responses. Responses with
// the same status are then sorted according to the name of the check.
func (r Responses) Less(i, j int) bool {
if r[i].Status == r[j].Status {
return r[i].Name < r[j].Name
}
return r[i].Status < r[j].Status
}
func (r Responses) Swap(i, j int) { r[i], r[j] = r[j], r[i] }

107
kit/errors/errors.go Normal file
View File

@ -0,0 +1,107 @@
package errors
import (
"fmt"
"net/http"
)
// TODO: move to base directory
const (
// InternalError indicates an unexpected error condition.
InternalError = 1
// MalformedData indicates malformed input, such as unparsable JSON.
MalformedData = 2
// InvalidData indicates that data is well-formed, but invalid.
InvalidData = 3
// Forbidden indicates a forbidden operation.
Forbidden = 4
// NotFound indicates a resource was not found.
NotFound = 5
)
// Error indicates an error with a reference code and an HTTP status code.
type Error struct {
Reference int `json:"referenceCode"`
Code int `json:"statusCode"`
Err string `json:"err"`
}
// Error implements the error interface.
func (e Error) Error() string {
return e.Err
}
// Errorf constructs an Error with the given reference code and format.
func Errorf(ref int, format string, i ...interface{}) error {
return Error{
Reference: ref,
Err: fmt.Sprintf(format, i...),
}
}
// New creates a new error with a message and error code.
func New(msg string, ref ...int) error {
refCode := InternalError
if len(ref) == 1 {
refCode = ref[0]
}
return Error{
Reference: refCode,
Err: msg,
}
}
func Wrap(err error, msg string, ref ...int) error {
if err == nil {
return nil
}
e, ok := err.(Error)
if ok {
refCode := e.Reference
if len(ref) == 1 {
refCode = ref[0]
}
return Error{
Reference: refCode,
Code: e.Code,
Err: fmt.Sprintf("%s: %s", msg, e.Err),
}
}
refCode := InternalError
if len(ref) == 1 {
refCode = ref[0]
}
return Error{
Reference: refCode,
Err: fmt.Sprintf("%s: %s", msg, err.Error()),
}
}
// InternalErrorf constructs an InternalError with the given format.
func InternalErrorf(format string, i ...interface{}) error {
return Errorf(InternalError, format, i...)
}
// MalformedDataf constructs a MalformedData error with the given format.
func MalformedDataf(format string, i ...interface{}) error {
return Errorf(MalformedData, format, i...)
}
// InvalidDataf constructs an InvalidData error with the given format.
func InvalidDataf(format string, i ...interface{}) error {
return Errorf(InvalidData, format, i...)
}
// Forbiddenf constructs a Forbidden error with the given format.
func Forbiddenf(format string, i ...interface{}) error {
return Errorf(Forbidden, format, i...)
}
func BadRequestError(msg string) error {
return Error{
Reference: InvalidData,
Code: http.StatusBadRequest,
Err: msg,
}
}

59
kit/errors/list.go Normal file
View File

@ -0,0 +1,59 @@
package errors
import (
"errors"
"strings"
)
// List represents a list of errors.
type List struct {
errs []error
err error // cached error
}
// Append adds err to the errors list.
func (l *List) Append(err error) {
l.errs = append(l.errs, err)
l.err = nil
}
// AppendString adds a new error that formats as the given text.
func (l *List) AppendString(text string) {
l.errs = append(l.errs, errors.New(text))
l.err = nil
}
// Clear removes all the previously appended errors from the list.
func (l *List) Clear() {
for i := range l.errs {
l.errs[i] = nil
}
l.errs = l.errs[:0]
l.err = nil
}
// Err returns an error composed of the list of errors, separated by a new line, or nil if no errors
// were appended.
func (l *List) Err() error {
if len(l.errs) == 0 {
return nil
}
if l.err != nil {
switch len(l.errs) {
case 1:
l.err = l.errs[0]
default:
var sb strings.Builder
sb.WriteString(l.errs[0].Error())
for _, err := range l.errs[1:] {
sb.WriteRune('\n')
sb.WriteString(err.Error())
}
l.err = errors.New(sb.String())
}
}
return l.err
}

View File

@ -0,0 +1,271 @@
package main
import (
"bytes"
"flag"
"fmt"
"go/format"
"io/ioutil"
"os"
"strings"
"text/template"
"github.com/Masterminds/sprig"
"github.com/influxdata/influxdb/kit/feature"
yaml "gopkg.in/yaml.v2"
)
const tmpl = `// Code generated by the feature package; DO NOT EDIT.
package feature
{{ .Qualify | import }}
{{ range $_, $flag := .Flags }}
var {{ $flag.Key }} = {{ $.Qualify | package }}{{ $flag.Default | maker }}(
{{ $flag.Name | quote }},
{{ $flag.Key | quote }},
{{ $flag.Contact | quote }},
{{ $flag.Default | conditionalQuote }},
{{ $.Qualify | package }}{{ $flag.Lifetime | lifetime }},
{{ $flag.Expose }},
)
// {{ $flag.Name | replace " " "_" | camelcase }} - {{ $flag.Description }}
func {{ $flag.Name | replace " " "_" | camelcase }}() {{ $.Qualify | package }}{{ $flag.Default | flagType }} {
return {{ $flag.Key }}
}
{{ end }}
var all = []{{ .Qualify | package }}Flag{
{{ range $_, $flag := .Flags }} {{ $flag.Key }},
{{ end }}}
var byKey = map[string]{{ $.Qualify | package }}Flag{
{{ range $_, $flag := .Flags }} {{ $flag.Key | quote }}: {{ $flag.Key }},
{{ end }}}
`
type flagConfig struct {
Name string
Description string
Key string
Default interface{}
Contact string
Lifetime feature.Lifetime
Expose bool
}
func (f flagConfig) Valid() error {
var problems []string
if f.Key == "" {
problems = append(problems, "missing key")
}
if f.Contact == "" {
problems = append(problems, "missing contact")
}
if f.Default == nil {
problems = append(problems, "missing default")
}
if f.Description == "" {
problems = append(problems, "missing description")
}
if len(problems) > 0 {
name := f.Name
if name == "" {
if f.Key != "" {
name = f.Key
} else {
name = "anonymous"
}
}
// e.g. "my flag: missing key; missing default"
return fmt.Errorf("%s: %s\n", name, strings.Join(problems, "; "))
}
return nil
}
type flagValidationError struct {
errs []error
}
func newFlagValidationError(errs []error) *flagValidationError {
if len(errs) == 0 {
return nil
}
return &flagValidationError{errs}
}
func (e *flagValidationError) Error() string {
var s strings.Builder
s.WriteString("flag validation error: \n")
for _, err := range e.errs {
s.WriteString(err.Error())
}
return s.String()
}
func validate(flags []flagConfig) error {
var (
errs []error
seen = make(map[string]bool, len(flags))
)
for _, flag := range flags {
if err := flag.Valid(); err != nil {
errs = append(errs, err)
} else if _, repeated := seen[flag.Key]; repeated {
errs = append(errs, fmt.Errorf("duplicate flag key '%s'\n", flag.Key))
}
seen[flag.Key] = true
}
if len(errs) != 0 {
return newFlagValidationError(errs)
}
return nil
}
var argv = struct {
in, out *string
qualify *bool
}{
in: flag.String("in", "", "flag configuration path"),
out: flag.String("out", "", "flag generation destination path"),
qualify: flag.Bool("qualify", false, "qualify types with imported package name"),
}
func main() {
if err := run(); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
os.Exit(0)
}
func run() error {
flag.Parse()
in, err := os.Open(*argv.in)
if err != nil {
return err
}
defer in.Close()
configuration, err := ioutil.ReadAll(in)
if err != nil {
return err
}
var flags []flagConfig
err = yaml.Unmarshal(configuration, &flags)
if err != nil {
return err
}
err = validate(flags)
if err != nil {
return err
}
t, err := template.New("flags").Funcs(templateFunctions()).Parse(tmpl)
if err != nil {
return err
}
out, err := os.Create(*argv.out)
if err != nil {
return err
}
defer out.Close()
var (
buf = new(bytes.Buffer)
vars = struct {
Qualify bool
Flags []flagConfig
}{
Qualify: *argv.qualify,
Flags: flags,
}
)
if err := t.Execute(buf, vars); err != nil {
return err
}
raw, err := ioutil.ReadAll(buf)
if err != nil {
return err
}
formatted, err := format.Source(raw)
if err != nil {
return err
}
_, err = out.Write(formatted)
return err
}
func templateFunctions() template.FuncMap {
functions := sprig.TxtFuncMap()
functions["lifetime"] = func(t interface{}) string {
switch t {
case feature.Permanent:
return "Permanent"
default:
return "Temporary"
}
}
functions["conditionalQuote"] = func(t interface{}) string {
switch t.(type) {
case string:
return fmt.Sprintf("%q", t)
default:
return fmt.Sprintf("%v", t)
}
}
functions["flagType"] = func(t interface{}) string {
switch t.(type) {
case bool:
return "BoolFlag"
case float64:
return "FloatFlag"
case int:
return "IntFlag"
default:
return "StringFlag"
}
}
functions["maker"] = func(t interface{}) string {
switch t.(type) {
case bool:
return "MakeBoolFlag"
case float64:
return "MakeFloatFlag"
case int:
return "MakeIntFlag"
default:
return "MakeStringFlag"
}
}
functions["package"] = func(t interface{}) string {
if t.(bool) {
return "feature."
}
return ""
}
functions["import"] = func(t interface{}) string {
if t.(bool) {
return "import \"github.com/influxdata/influxdb/kit/feature\""
}
return ""
}
return functions
}

75
kit/feature/doc.go Normal file
View File

@ -0,0 +1,75 @@
// Package feature provides feature flagging capabilities for InfluxDB servers.
// This document describes this package and how it is used to control
// experimental features in `influxd`.
//
// Flags are configured in `flags.yml` at the top of this repository.
// Running `make flags` generates Go code based on this configuration
// to programmatically test flag values in a given request context.
// Boolean flags are the most common case, but integers, floats and
// strings are supported for more complicated experiments.
//
// The `Flagger` interface is the crux of this package.
// It computes a map of feature flag values for a given request context.
// The default implementation always returns the flag default configured
// in `flags.yml`. The override implementation allows an operator to
// override feature flag defaults at startup. Changing these overrides
// requires a restart.
//
// In `influxd`, a `Flagger` instance is provided to a `Handler` middleware
// configured to intercept all API requests and annotate their request context
// with a map of feature flags.
//
// A flag can opt in to be exposed externally in `flags.yml`. If exposed,
// this flag will be included in the response from the `/api/v2/flags`
// endpoint. This allows the UI and other API clients to control their
// behavior according to the flag in addition to the server itself.
//
// A concrete example to illustrate the above:
//
// I have a feature called "My Feature" that will involve turning on new code
// in both the UI and the server.
//
// First, I add an entry to `flags.yml`.
//
// ```yaml
// - name: My Feature
// description: My feature is awesome
// key: myFeature
// default: false
// expose: true
// contact: My Name
// ```
//
// My flag type is inferred to be boolean by my default of `false` when I run
// `make flags` and the `feature` package now includes `func MyFeature() BoolFlag`.
//
// I use this to control my backend code with
//
// ```go
// if feature.MyFeature.Enabled(ctx) {
// // new code...
// } else {
// // new code...
// }
// ```
//
// and the `/api/v2/flags` response provides the same information to the frontend.
//
// ```json
// {
// "myFeature": false
// }
// ```
//
// While `false` by default, I can turn on my experimental feature by starting
// my server with a flag override.
//
// ```
// env INFLUXD_FEATURE_FLAGS="{\"flag1\":\value1\",\"key2\":\"value2\"}" influxd
// ```
//
// ```
// influxd --feature-flags flag1=value1,flag2=value2
// ```
//
package feature

140
kit/feature/feature.go Normal file
View File

@ -0,0 +1,140 @@
package feature
import (
"context"
"strings"
"github.com/opentracing/opentracing-go"
)
type contextKey string
const featureContextKey contextKey = "influx/feature/v1"
// Flagger returns flag values.
type Flagger interface {
// Flags returns a map of flag keys to flag values.
//
// If an authorization is present on the context, it may be used to compute flag
// values according to the affiliated user ID and its organization and other mappings.
// Otherwise, they should be computed generally or return a default.
//
// One or more flags may be provided to restrict the results.
// Otherwise, all flags should be computed.
Flags(context.Context, ...Flag) (map[string]interface{}, error)
}
// Annotate the context with a map computed of computed flags.
func Annotate(ctx context.Context, f Flagger, flags ...Flag) (context.Context, error) {
computed, err := f.Flags(ctx, flags...)
if err != nil {
return nil, err
}
span := opentracing.SpanFromContext(ctx)
if span != nil {
for k, v := range computed {
span.LogKV(k, v)
}
}
return context.WithValue(ctx, featureContextKey, computed), nil
}
// FlagsFromContext returns the map of flags attached to the context
// by Annotate, or nil if none is found.
func FlagsFromContext(ctx context.Context) map[string]interface{} {
v, ok := ctx.Value(featureContextKey).(map[string]interface{})
if !ok {
return nil
}
return v
}
type ByKeyFn func(string) (Flag, bool)
// ExposedFlagsFromContext returns the filtered map of exposed flags attached
// to the context by Annotate, or nil if none is found.
func ExposedFlagsFromContext(ctx context.Context, byKey ByKeyFn) map[string]interface{} {
m := FlagsFromContext(ctx)
if m == nil {
return nil
}
filtered := make(map[string]interface{})
for k, v := range m {
if flag, found := byKey(k); found && flag.Expose() {
filtered[k] = v
}
}
return filtered
}
// Lifetime represents the intended lifetime of the feature flag.
//
// The zero value is Temporary, the most common case, but Permanent
// is included to mark special cases where a flag is not intended
// to be removed, e.g. enabling debug tracing for an organization.
//
// TODO(gavincabbage): This may become a stale date, which can then
// be used to trigger a notification to the contact when the flag
// has become stale, to encourage flag cleanup.
type Lifetime int
const (
// Temporary indicates a flag is intended to be removed after a feature is no longer in development.
Temporary Lifetime = iota
// Permanent indicates a flag is not intended to be removed.
Permanent
)
// UnmarshalYAML implements yaml.Unmarshaler and interprets a case-insensitive text
// representation as a lifetime constant.
func (l *Lifetime) UnmarshalYAML(unmarshal func(interface{}) error) error {
var s string
if err := unmarshal(&s); err != nil {
return err
}
switch strings.ToLower(s) {
case "permanent":
*l = Permanent
default:
*l = Temporary
}
return nil
}
type defaultFlagger struct{}
// DefaultFlagger returns a flagger that always returns default values.
func DefaultFlagger() Flagger {
return &defaultFlagger{}
}
// Flags returns a map of default values. It never returns an error.
func (*defaultFlagger) Flags(_ context.Context, flags ...Flag) (map[string]interface{}, error) {
if len(flags) == 0 {
flags = Flags()
}
m := make(map[string]interface{}, len(flags))
for _, flag := range flags {
m[flag.Key()] = flag.Default()
}
return m, nil
}
// Flags returns all feature flags.
func Flags() []Flag {
return all
}
// ByKey returns the Flag corresponding to the given key.
func ByKey(k string) (Flag, bool) {
v, found := byKey[k]
return v, found
}

185
kit/feature/feature_test.go Normal file
View File

@ -0,0 +1,185 @@
package feature_test
import (
"context"
"testing"
"github.com/influxdata/influxdb/kit/feature"
)
func Test_feature(t *testing.T) {
cases := []struct {
name string
flag feature.Flag
err error
values map[string]interface{}
ctx context.Context
expected interface{}
}{
{
name: "bool happy path",
flag: newFlag("test", false),
values: map[string]interface{}{
"test": true,
},
expected: true,
},
{
name: "int happy path",
flag: newFlag("test", 0),
values: map[string]interface{}{
"test": int32(42),
},
expected: int32(42),
},
{
name: "float happy path",
flag: newFlag("test", 0.0),
values: map[string]interface{}{
"test": 42.42,
},
expected: 42.42,
},
{
name: "string happy path",
flag: newFlag("test", ""),
values: map[string]interface{}{
"test": "restaurantattheendoftheuniverse",
},
expected: "restaurantattheendoftheuniverse",
},
{
name: "bool missing use default",
flag: newFlag("test", false),
expected: false,
},
{
name: "bool missing use default true",
flag: newFlag("test", true),
expected: true,
},
{
name: "int missing use default",
flag: newFlag("test", 65),
expected: int32(65),
},
{
name: "float missing use default",
flag: newFlag("test", 65.65),
expected: 65.65,
},
{
name: "string missing use default",
flag: newFlag("test", "mydefault"),
expected: "mydefault",
},
{
name: "bool invalid use default",
flag: newFlag("test", true),
values: map[string]interface{}{
"test": "notabool",
},
expected: true,
},
{
name: "int invalid use default",
flag: newFlag("test", 42),
values: map[string]interface{}{
"test": 99.99,
},
expected: int32(42),
},
{
name: "float invalid use default",
flag: newFlag("test", 42.42),
values: map[string]interface{}{
"test": 99,
},
expected: 42.42,
},
{
name: "string invalid use default",
flag: newFlag("test", "restaurantattheendoftheuniverse"),
values: map[string]interface{}{
"test": true,
},
expected: "restaurantattheendoftheuniverse",
},
}
for _, test := range cases {
t.Run("flagger "+test.name, func(t *testing.T) {
flagger := testFlagsFlagger{
m: test.values,
err: test.err,
}
var actual interface{}
switch flag := test.flag.(type) {
case feature.BoolFlag:
actual = flag.Enabled(test.ctx, flagger)
case feature.FloatFlag:
actual = flag.Float(test.ctx, flagger)
case feature.IntFlag:
actual = flag.Int(test.ctx, flagger)
case feature.StringFlag:
actual = flag.String(test.ctx, flagger)
default:
t.Errorf("unknown flag type %T (%#v)", flag, flag)
}
if actual != test.expected {
t.Errorf("unexpected flag value: got %v, want %v", actual, test.expected)
}
})
t.Run("annotate "+test.name, func(t *testing.T) {
flagger := testFlagsFlagger{
m: test.values,
err: test.err,
}
ctx, err := feature.Annotate(context.Background(), flagger)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
var actual interface{}
switch flag := test.flag.(type) {
case feature.BoolFlag:
actual = flag.Enabled(ctx)
case feature.FloatFlag:
actual = flag.Float(ctx)
case feature.IntFlag:
actual = flag.Int(ctx)
case feature.StringFlag:
actual = flag.String(ctx)
default:
t.Errorf("unknown flag type %T (%#v)", flag, flag)
}
if actual != test.expected {
t.Errorf("unexpected flag value: got %v, want %v", actual, test.expected)
}
})
}
}
type testFlagsFlagger struct {
m map[string]interface{}
err error
}
func (f testFlagsFlagger) Flags(ctx context.Context, flags ...feature.Flag) (map[string]interface{}, error) {
if f.err != nil {
return nil, f.err
}
return f.m, nil
}
func newFlag(key string, defaultValue interface{}) feature.Flag {
return feature.MakeFlag(key, key, "", defaultValue, feature.Temporary, false)
}

216
kit/feature/flag.go Normal file
View File

@ -0,0 +1,216 @@
//go:generate go run ./_codegen/main.go --in ../../flags.yml --out ./list.go
package feature
import (
"context"
"fmt"
)
// Flag represents a generic feature flag with a key and a default.
type Flag interface {
// Key returns the programmatic backend identifier for the flag.
Key() string
// Default returns the type-agnostic zero value for the flag.
// Type-specific flag implementations may expose a typed default
// (e.g. BoolFlag includes a boolean Default field).
Default() interface{}
// Expose the flag.
Expose() bool
}
// MakeFlag constructs a Flag. The concrete implementation is inferred from the provided default.
func MakeFlag(name, key, owner string, defaultValue interface{}, lifetime Lifetime, expose bool) Flag {
b := MakeBase(name, key, owner, defaultValue, lifetime, expose)
switch v := defaultValue.(type) {
case bool:
return BoolFlag{b, v}
case float64:
return FloatFlag{b, v}
case int32:
return IntFlag{b, v}
case int:
return IntFlag{b, int32(v)}
case string:
return StringFlag{b, v}
default:
return StringFlag{b, fmt.Sprintf("%v", v)}
}
}
// flag base type.
type Base struct {
// name of the flag.
name string
// key is the programmatic backend identifier for the flag.
key string
// defaultValue for the flag.
defaultValue interface{}
// owner is an individual or team responsible for the flag.
owner string
// lifetime of the feature flag.
lifetime Lifetime
// expose the flag.
expose bool
}
var _ Flag = Base{}
// MakeBase constructs a flag flag.
func MakeBase(name, key, owner string, defaultValue interface{}, lifetime Lifetime, expose bool) Base {
return Base{
name: name,
key: key,
owner: owner,
defaultValue: defaultValue,
lifetime: lifetime,
expose: expose,
}
}
// Key returns the programmatic backend identifier for the flag.
func (f Base) Key() string {
return f.key
}
// Default returns the type-agnostic zero value for the flag.
func (f Base) Default() interface{} {
return f.defaultValue
}
// Expose the flag.
func (f Base) Expose() bool {
return f.expose
}
func (f Base) value(ctx context.Context, flagger ...Flagger) (interface{}, bool) {
var (
m map[string]interface{}
ok bool
)
if len(flagger) < 1 {
m, ok = ctx.Value(featureContextKey).(map[string]interface{})
} else {
var err error
m, err = flagger[0].Flags(ctx, f)
ok = err == nil
}
if !ok {
return nil, false
}
v, ok := m[f.Key()]
if !ok {
return nil, false
}
return v, true
}
// StringFlag implements Flag for string values.
type StringFlag struct {
Base
defaultString string
}
var _ Flag = StringFlag{}
// MakeStringFlag returns a string flag with the given Base and default.
func MakeStringFlag(name, key, owner string, defaultValue string, lifetime Lifetime, expose bool) StringFlag {
b := MakeBase(name, key, owner, defaultValue, lifetime, expose)
return StringFlag{b, defaultValue}
}
// String value of the flag on the request context.
func (f StringFlag) String(ctx context.Context, flagger ...Flagger) string {
i, ok := f.value(ctx, flagger...)
if !ok {
return f.defaultString
}
s, ok := i.(string)
if !ok {
return f.defaultString
}
return s
}
// FloatFlag implements Flag for float values.
type FloatFlag struct {
Base
defaultFloat float64
}
var _ Flag = FloatFlag{}
// MakeFloatFlag returns a string flag with the given Base and default.
func MakeFloatFlag(name, key, owner string, defaultValue float64, lifetime Lifetime, expose bool) FloatFlag {
b := MakeBase(name, key, owner, defaultValue, lifetime, expose)
return FloatFlag{b, defaultValue}
}
// Float value of the flag on the request context.
func (f FloatFlag) Float(ctx context.Context, flagger ...Flagger) float64 {
i, ok := f.value(ctx, flagger...)
if !ok {
return f.defaultFloat
}
v, ok := i.(float64)
if !ok {
return f.defaultFloat
}
return v
}
// IntFlag implements Flag for integer values.
type IntFlag struct {
Base
defaultInt int32
}
var _ Flag = IntFlag{}
// MakeIntFlag returns a string flag with the given Base and default.
func MakeIntFlag(name, key, owner string, defaultValue int32, lifetime Lifetime, expose bool) IntFlag {
b := MakeBase(name, key, owner, defaultValue, lifetime, expose)
return IntFlag{b, defaultValue}
}
// Int value of the flag on the request context.
func (f IntFlag) Int(ctx context.Context, flagger ...Flagger) int32 {
i, ok := f.value(ctx, flagger...)
if !ok {
return f.defaultInt
}
v, ok := i.(int32)
if !ok {
return f.defaultInt
}
return v
}
// BoolFlag implements Flag for boolean values.
type BoolFlag struct {
Base
defaultBool bool
}
var _ Flag = BoolFlag{}
// MakeBoolFlag returns a string flag with the given Base and default.
func MakeBoolFlag(name, key, owner string, defaultValue bool, lifetime Lifetime, expose bool) BoolFlag {
b := MakeBase(name, key, owner, defaultValue, lifetime, expose)
return BoolFlag{b, defaultValue}
}
// Enabled indicates whether flag is true or false on the request context.
func (f BoolFlag) Enabled(ctx context.Context, flagger ...Flagger) bool {
i, ok := f.value(ctx, flagger...)
if !ok {
return f.defaultBool
}
v, ok := i.(bool)
if !ok {
return f.defaultBool
}
return v
}

73
kit/feature/http_proxy.go Normal file
View File

@ -0,0 +1,73 @@
package feature
import (
"context"
"net/http"
"net/http/httputil"
"net/url"
"go.uber.org/zap"
)
// HTTPProxy is an HTTP proxy that's guided by a feature flag. If the feature flag
// presented to it is enabled, it will perform the proxying behavior. Otherwise
// it will be a no-op.
type HTTPProxy struct {
proxy *httputil.ReverseProxy
logger *zap.Logger
enabler ProxyEnabler
}
// NewHTTPProxy returns a new Proxy.
func NewHTTPProxy(dest *url.URL, logger *zap.Logger, enabler ProxyEnabler) *HTTPProxy {
return &HTTPProxy{
proxy: newReverseProxy(dest, enabler.Key()),
logger: logger,
enabler: enabler,
}
}
// Do performs the proxying. It returns whether or not the request was proxied.
func (p *HTTPProxy) Do(w http.ResponseWriter, r *http.Request) bool {
if p.enabler.Enabled(r.Context()) {
p.proxy.ServeHTTP(w, r)
return true
}
return false
}
const (
// headerProxyFlag is the HTTP header for enriching the request and response
// with the feature flag key that precipitated the proxying behavior.
headerProxyFlag = "X-Platform-Proxy-Flag"
)
// newReverseProxy creates a new single-host reverse proxy.
func newReverseProxy(dest *url.URL, enablerKey string) *httputil.ReverseProxy {
proxy := httputil.NewSingleHostReverseProxy(dest)
defaultDirector := proxy.Director
proxy.Director = func(r *http.Request) {
defaultDirector(r)
r.Header.Set(headerProxyFlag, enablerKey)
// Override r.Host to prevent us sending this request back to ourselves.
// A bug in the stdlib causes this value to be preferred over the
// r.URL.Host (which is set in the default Director) if r.Host isn't
// empty (which it isn't).
// https://github.com/golang/go/issues/28168
r.Host = dest.Host
}
proxy.ModifyResponse = func(r *http.Response) error {
r.Header.Set(headerProxyFlag, enablerKey)
return nil
}
return proxy
}
// ProxyEnabler is a boolean feature flag.
type ProxyEnabler interface {
Key() string
Enabled(ctx context.Context, fs ...Flagger) bool
}

View File

@ -0,0 +1,137 @@
package feature
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"go.uber.org/zap"
"go.uber.org/zap/zaptest"
)
const (
destBody = "hello from destination"
srcBody = "hello from source"
flagKey = "fancy-feature"
)
func TestHTTPProxy_Proxying(t *testing.T) {
en := enabler{key: flagKey, state: true}
logger := zaptest.NewLogger(t)
resp, err := testHTTPProxy(logger, en)
if err != nil {
t.Error(err)
}
proxyFlag := resp.Header.Get("X-Platform-Proxy-Flag")
if proxyFlag != flagKey {
t.Error("X-Platform-Proxy-Flag header not populated")
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Error(err)
}
bodyStr := string(body)
if bodyStr != destBody {
t.Errorf("expected body of destination handler, but got: %q", bodyStr)
}
}
func TestHTTPProxy_DefaultBehavior(t *testing.T) {
en := enabler{key: flagKey, state: false}
logger := zaptest.NewLogger(t)
resp, err := testHTTPProxy(logger, en)
if err != nil {
t.Error(err)
}
proxyFlag := resp.Header.Get("X-Platform-Proxy-Flag")
if proxyFlag != "" {
t.Error("X-Platform-Proxy-Flag header populated")
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Error(err)
}
bodyStr := string(body)
if bodyStr != srcBody {
t.Errorf("expected body of source handler, but got: %q", bodyStr)
}
}
func TestHTTPProxy_RequestHeader(t *testing.T) {
h := func(w http.ResponseWriter, r *http.Request) {
proxyFlag := r.Header.Get("X-Platform-Proxy-Flag")
if proxyFlag != flagKey {
t.Error("expected X-Proxy-Flag to contain feature flag key")
}
}
s := httptest.NewServer(http.HandlerFunc(h))
defer s.Close()
sURL, err := url.Parse(s.URL)
if err != nil {
t.Error(err)
}
logger := zaptest.NewLogger(t)
en := enabler{key: flagKey, state: true}
proxy := NewHTTPProxy(sURL, logger, en)
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "http://example.com/foo", nil)
srcHandler(proxy)(w, r)
}
func testHTTPProxy(logger *zap.Logger, enabler ProxyEnabler) (*http.Response, error) {
s := httptest.NewServer(http.HandlerFunc(destHandler))
defer s.Close()
sURL, err := url.Parse(s.URL)
if err != nil {
return nil, err
}
proxy := NewHTTPProxy(sURL, logger, enabler)
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "http://example.com/foo", nil)
srcHandler(proxy)(w, r)
return w.Result(), nil
}
func destHandler(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, destBody)
}
func srcHandler(proxy *HTTPProxy) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if proxy.Do(w, r) {
return
}
fmt.Fprint(w, srcBody)
}
}
type enabler struct {
key string
state bool
}
func (e enabler) Key() string {
return e.key
}
func (e enabler) Enabled(context.Context, ...Flagger) bool {
return e.state
}

249
kit/feature/list.go Normal file
View File

@ -0,0 +1,249 @@
// Code generated by the feature package; DO NOT EDIT.
package feature
var appMetrics = MakeBoolFlag(
"App Metrics",
"appMetrics",
"Bucky, Monitoring Team",
false,
Permanent,
true,
)
// AppMetrics - Send UI Telementry to Tools cluster - should always be false in OSS
func AppMetrics() BoolFlag {
return appMetrics
}
var backendExample = MakeBoolFlag(
"Backend Example",
"backendExample",
"Gavin Cabbage",
false,
Permanent,
false,
)
// BackendExample - A permanent backend example boolean flag
func BackendExample() BoolFlag {
return backendExample
}
var communityTemplates = MakeBoolFlag(
"Community Templates",
"communityTemplates",
"Bucky",
true,
Permanent,
true,
)
// CommunityTemplates - Replace current template uploading functionality with community driven templates
func CommunityTemplates() BoolFlag {
return communityTemplates
}
var frontendExample = MakeIntFlag(
"Frontend Example",
"frontendExample",
"Gavin Cabbage",
42,
Temporary,
true,
)
// FrontendExample - A temporary frontend example integer flag
func FrontendExample() IntFlag {
return frontendExample
}
var groupWindowAggregateTranspose = MakeBoolFlag(
"Group Window Aggregate Transpose",
"groupWindowAggregateTranspose",
"Query Team",
false,
Temporary,
false,
)
// GroupWindowAggregateTranspose - Enables the GroupWindowAggregateTransposeRule for all enabled window aggregates
func GroupWindowAggregateTranspose() BoolFlag {
return groupWindowAggregateTranspose
}
var newLabels = MakeBoolFlag(
"New Label Package",
"newLabels",
"Alirie Gray",
false,
Temporary,
false,
)
// NewLabelPackage - Enables the refactored labels api
func NewLabelPackage() BoolFlag {
return newLabels
}
var memoryOptimizedFill = MakeBoolFlag(
"Memory Optimized Fill",
"memoryOptimizedFill",
"Query Team",
false,
Temporary,
false,
)
// MemoryOptimizedFill - Enable the memory optimized fill()
func MemoryOptimizedFill() BoolFlag {
return memoryOptimizedFill
}
var memoryOptimizedSchemaMutation = MakeBoolFlag(
"Memory Optimized Schema Mutation",
"memoryOptimizedSchemaMutation",
"Query Team",
false,
Temporary,
false,
)
// MemoryOptimizedSchemaMutation - Enable the memory optimized schema mutation functions
func MemoryOptimizedSchemaMutation() BoolFlag {
return memoryOptimizedSchemaMutation
}
var queryTracing = MakeBoolFlag(
"Query Tracing",
"queryTracing",
"Query Team",
false,
Permanent,
false,
)
// QueryTracing - Turn on query tracing for queries that are sampled
func QueryTracing() BoolFlag {
return queryTracing
}
var bandPlotType = MakeBoolFlag(
"Band Plot Type",
"bandPlotType",
"Monitoring Team",
false,
Temporary,
true,
)
// BandPlotType - Enables the creation of a band plot in Dashboards
func BandPlotType() BoolFlag {
return bandPlotType
}
var mosaicGraphType = MakeBoolFlag(
"Mosaic Graph Type",
"mosaicGraphType",
"Monitoring Team",
false,
Temporary,
true,
)
// MosaicGraphType - Enables the creation of a mosaic graph in Dashboards
func MosaicGraphType() BoolFlag {
return mosaicGraphType
}
var notebooks = MakeBoolFlag(
"Notebooks",
"notebooks",
"Monitoring Team",
false,
Temporary,
true,
)
// Notebooks - Determine if the notebook feature's route and navbar icon are visible to the user
func Notebooks() BoolFlag {
return notebooks
}
var injectLatestSuccessTime = MakeBoolFlag(
"Inject Latest Success Time",
"injectLatestSuccessTime",
"Compute Team",
false,
Temporary,
false,
)
// InjectLatestSuccessTime - Inject the latest successful task run timestamp into a Task query extern when executing.
func InjectLatestSuccessTime() BoolFlag {
return injectLatestSuccessTime
}
var enforceOrgDashboardLimits = MakeBoolFlag(
"Enforce Organization Dashboard Limits",
"enforceOrgDashboardLimits",
"Compute Team",
false,
Temporary,
false,
)
// EnforceOrganizationDashboardLimits - Enforces the default limit params for the dashboards api when orgs are set
func EnforceOrganizationDashboardLimits() BoolFlag {
return enforceOrgDashboardLimits
}
var timeFilterFlags = MakeBoolFlag(
"Time Filter Flags",
"timeFilterFlags",
"Compute Team",
false,
Temporary,
true,
)
// TimeFilterFlags - Filter task run list based on before and after flags
func TimeFilterFlags() BoolFlag {
return timeFilterFlags
}
var all = []Flag{
appMetrics,
backendExample,
communityTemplates,
frontendExample,
groupWindowAggregateTranspose,
newLabels,
memoryOptimizedFill,
memoryOptimizedSchemaMutation,
queryTracing,
bandPlotType,
mosaicGraphType,
notebooks,
injectLatestSuccessTime,
enforceOrgDashboardLimits,
timeFilterFlags,
}
var byKey = map[string]Flag{
"appMetrics": appMetrics,
"backendExample": backendExample,
"communityTemplates": communityTemplates,
"frontendExample": frontendExample,
"groupWindowAggregateTranspose": groupWindowAggregateTranspose,
"newLabels": newLabels,
"memoryOptimizedFill": memoryOptimizedFill,
"memoryOptimizedSchemaMutation": memoryOptimizedSchemaMutation,
"queryTracing": queryTracing,
"bandPlotType": bandPlotType,
"mosaicGraphType": mosaicGraphType,
"notebooks": notebooks,
"injectLatestSuccessTime": injectLatestSuccessTime,
"enforceOrgDashboardLimits": enforceOrgDashboardLimits,
"timeFilterFlags": timeFilterFlags,
}

70
kit/feature/middleware.go Normal file
View File

@ -0,0 +1,70 @@
package feature
import (
"context"
"encoding/json"
"net/http"
"go.uber.org/zap"
)
// Handler is a middleware that annotates the context with a map of computed feature flags.
// To accurately compute identity-scoped flags, this middleware should be executed after any
// authorization middleware has annotated the request context with an authorizer.
type Handler struct {
log *zap.Logger
next http.Handler
flagger Flagger
flags []Flag
}
// NewHandler returns a configured feature flag middleware that will annotate request context
// with a computed map of the given flags using the provided Flagger.
func NewHandler(log *zap.Logger, flagger Flagger, flags []Flag, next http.Handler) http.Handler {
return &Handler{
log: log,
next: next,
flagger: flagger,
flags: flags,
}
}
// ServeHTTP annotates the request context with a map of computed feature flags before
// continuing to serve the request.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, err := Annotate(r.Context(), h.flagger, h.flags...)
if err != nil {
h.log.Warn("Unable to annotate context with feature flags", zap.Error(err))
} else {
r = r.WithContext(ctx)
}
if h.next != nil {
h.next.ServeHTTP(w, r)
}
}
// HTTPErrorHandler is an influxdb.HTTPErrorHandler. It's defined here instead
// of referencing the other interface type, because we want to try our best to
// avoid cyclical dependencies when feature package is used throughout the
// codebase.
type HTTPErrorHandler interface {
HandleHTTPError(ctx context.Context, err error, w http.ResponseWriter)
}
// NewFlagsHandler returns a handler that returns the map of computed feature flags on the request context.
func NewFlagsHandler(errorHandler HTTPErrorHandler, byKey ByKeyFn) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusOK)
var (
ctx = r.Context()
flags = ExposedFlagsFromContext(ctx, byKey)
)
if err := json.NewEncoder(w).Encode(flags); err != nil {
errorHandler.HandleHTTPError(ctx, err, w)
}
}
return http.HandlerFunc(fn)
}

View File

@ -0,0 +1,47 @@
package feature_test
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/influxdata/influxdb/kit/feature"
"go.uber.org/zap/zaptest"
)
func Test_Handler(t *testing.T) {
var (
w = &httptest.ResponseRecorder{}
r = httptest.NewRequest(http.MethodGet, "http://nowhere.test", new(bytes.Buffer)).
WithContext(context.Background())
original = r.Context()
)
handler := &checkHandler{t: t, f: func(t *testing.T, r *http.Request) {
if r.Context() == original {
t.Error("expected annotated context")
}
}}
subject := feature.NewHandler(zaptest.NewLogger(t), feature.DefaultFlagger(), feature.Flags(), handler)
subject.ServeHTTP(w, r)
if !handler.called {
t.Error("expected handler to be called")
}
}
type checkHandler struct {
t *testing.T
f func(t *testing.T, r *http.Request)
called bool
}
func (h *checkHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) {
h.called = true
h.f(h.t, r)
}

View File

@ -0,0 +1,83 @@
package override
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/influxdata/influxdb/kit/feature"
)
// Flagger can override default flag values.
type Flagger struct {
overrides map[string]string
byKey feature.ByKeyFn
}
// Make a Flagger that returns defaults with any overrides parsed from the string.
func Make(overrides map[string]string, byKey feature.ByKeyFn) (Flagger, error) {
if byKey == nil {
byKey = feature.ByKey
}
// Check all provided override keys correspond to an existing Flag.
var missing []string
for k := range overrides {
if _, found := byKey(k); !found {
missing = append(missing, k)
}
}
if len(missing) > 0 {
return Flagger{}, fmt.Errorf("configured overrides for non-existent flags: %s", strings.Join(missing, ","))
}
return Flagger{
overrides: overrides,
byKey: byKey,
}, nil
}
// Flags returns a map of default values with overrides applied. It never returns an error.
func (f Flagger) Flags(_ context.Context, flags ...feature.Flag) (map[string]interface{}, error) {
if len(flags) == 0 {
flags = feature.Flags()
}
m := make(map[string]interface{}, len(flags))
for _, flag := range flags {
if s, overridden := f.overrides[flag.Key()]; overridden {
iface, err := f.coerce(s, flag)
if err != nil {
return nil, err
}
m[flag.Key()] = iface
} else {
m[flag.Key()] = flag.Default()
}
}
return m, nil
}
func (f Flagger) coerce(s string, flag feature.Flag) (iface interface{}, err error) {
if base, ok := flag.(feature.Base); ok {
flag, _ = f.byKey(base.Key())
}
switch flag.(type) {
case feature.BoolFlag:
iface, err = strconv.ParseBool(s)
case feature.IntFlag:
iface, err = strconv.Atoi(s)
case feature.FloatFlag:
iface, err = strconv.ParseFloat(s, 64)
default:
iface = s
}
if err != nil {
return nil, fmt.Errorf("coercing string %q based on flag type %T: %v", s, flag, err)
}
return
}

View File

@ -0,0 +1,183 @@
package override
import (
"context"
"testing"
"github.com/influxdata/influxdb/kit/feature"
)
func TestFlagger(t *testing.T) {
cases := []struct {
name string
env map[string]string
defaults []feature.Flag
expected map[string]interface{}
expectMakeErr bool
expectFlagsErr bool
byKey feature.ByKeyFn
}{
{
name: "enabled happy path filtering",
env: map[string]string{
"flag1": "new1",
"flag3": "new3",
},
defaults: []feature.Flag{
newFlag("flag0", "original0"),
newFlag("flag1", "original1"),
newFlag("flag2", "original2"),
newFlag("flag3", "original3"),
newFlag("flag4", "original4"),
},
byKey: newByKey(map[string]feature.Flag{
"flag0": newFlag("flag0", "original0"),
"flag1": newFlag("flag1", "original1"),
"flag2": newFlag("flag2", "original2"),
"flag3": newFlag("flag3", "original3"),
"flag4": newFlag("flag4", "original4"),
}),
expected: map[string]interface{}{
"flag0": "original0",
"flag1": "new1",
"flag2": "original2",
"flag3": "new3",
"flag4": "original4",
},
},
{
name: "enabled happy path types",
env: map[string]string{
"intflag": "43",
"floatflag": "43.43",
"boolflag": "true",
},
defaults: []feature.Flag{
newFlag("intflag", 42),
newFlag("floatflag", 42.42),
newFlag("boolflag", false),
},
byKey: newByKey(map[string]feature.Flag{
"intflag": newFlag("intflag", 42),
"floatflag": newFlag("floatflag", 43.43),
"boolflag": newFlag("boolflag", false),
}),
expected: map[string]interface{}{
"intflag": 43,
"floatflag": 43.43,
"boolflag": true,
},
},
{
name: "type coerce error",
env: map[string]string{
"key": "not_an_int",
},
defaults: []feature.Flag{
newFlag("key", 42),
},
byKey: newByKey(map[string]feature.Flag{
"key": newFlag("key", 42),
}),
expectFlagsErr: true,
},
{
name: "typed base flags",
env: map[string]string{
"flag1": "411",
"flag2": "new2",
"flag3": "true",
},
defaults: []feature.Flag{
newBaseFlag("flag0", "original0"),
newBaseFlag("flag1", 41),
newBaseFlag("flag2", "original2"),
newBaseFlag("flag3", false),
newBaseFlag("flag4", "original4"),
},
byKey: newByKey(map[string]feature.Flag{
"flag0": newFlag("flag0", "original0"),
"flag1": newFlag("flag1", 41),
"flag2": newFlag("flag2", "original2"),
"flag3": newFlag("flag3", false),
"flag4": newFlag("flag4", "original4"),
}),
expected: map[string]interface{}{
"flag0": "original0",
"flag1": 411,
"flag2": "new2",
"flag3": true,
"flag4": "original4",
},
},
{
name: "override for non-existent flag",
env: map[string]string{
"dne": "foobar",
},
defaults: []feature.Flag{
newBaseFlag("key", "value"),
},
byKey: newByKey(map[string]feature.Flag{
"key": newFlag("key", "value"),
}),
expectMakeErr: true,
},
}
for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
subject, err := Make(test.env, test.byKey)
if err != nil {
if test.expectMakeErr {
return
}
t.Fatalf("unexpected error making Flagger: %v", err)
}
computed, err := subject.Flags(context.Background(), test.defaults...)
if err != nil {
if test.expectFlagsErr {
return
}
t.Fatalf("unexpected error calling Flags: %v", err)
}
if len(computed) != len(test.expected) {
t.Fatalf("incorrect number of flags computed: expected %d, got %d", len(test.expected), len(computed))
}
// check for extra or incorrect keys
for k, v := range computed {
if xv, found := test.expected[k]; !found {
t.Errorf("unexpected key %s", k)
} else if v != xv {
t.Errorf("incorrect value for key %s: expected %v [%T], got %v [%T]", k, xv, xv, v, v)
}
}
// check for missing keys
for k := range test.expected {
if _, found := computed[k]; !found {
t.Errorf("missing expected key %s", k)
}
}
})
}
}
func newFlag(key string, defaultValue interface{}) feature.Flag {
return feature.MakeFlag(key, key, "", defaultValue, feature.Temporary, false)
}
func newBaseFlag(key string, defaultValue interface{}) feature.Base {
return feature.MakeBase(key, key, "", defaultValue, feature.Temporary, false)
}
func newByKey(m map[string]feature.Flag) feature.ByKeyFn {
return func(k string) (feature.Flag, bool) {
v, found := m[k]
return v, found
}
}

View File

@ -0,0 +1,59 @@
package io
import (
"errors"
"io"
)
var ErrReadLimitExceeded = errors.New("read limit exceeded")
// LimitedReadCloser wraps an io.ReadCloser in limiting behavior using
// io.LimitedReader. It allows us to obtain the limit error at the time of close
// instead of just when writing.
type LimitedReadCloser struct {
R io.ReadCloser // underlying reader
N int64 // max bytes remaining
err error
closed bool
limitExceeded bool
}
// NewLimitedReadCloser returns a new LimitedReadCloser.
func NewLimitedReadCloser(r io.ReadCloser, n int64) *LimitedReadCloser {
return &LimitedReadCloser{
R: r,
N: n,
}
}
func (l *LimitedReadCloser) Read(p []byte) (n int, err error) {
if l.N <= 0 {
l.limitExceeded = true
return 0, io.EOF
}
if int64(len(p)) > l.N {
p = p[0:l.N]
}
n, err = l.R.Read(p)
l.N -= int64(n)
return
}
// Close returns an ErrReadLimitExceeded when the wrapped reader exceeds the set
// limit for number of bytes. This is safe to call more than once but not
// concurrently.
func (l *LimitedReadCloser) Close() (err error) {
if l.limitExceeded {
l.err = ErrReadLimitExceeded
}
if l.closed {
// Close has already been called.
return l.err
}
if err := l.R.Close(); err != nil && l.err == nil {
l.err = err
}
// Prevent l.closer.Close from being called again.
l.closed = true
return l.err
}

View File

@ -0,0 +1,87 @@
package io
import (
"bytes"
"errors"
"io"
"io/ioutil"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLimitedReadCloser_Exceeded(t *testing.T) {
b := &closer{Reader: bytes.NewBufferString("howdy")}
rc := NewLimitedReadCloser(b, 3)
out, err := ioutil.ReadAll(rc)
require.NoError(t, err)
assert.Equal(t, []byte("how"), out)
assert.Equal(t, ErrReadLimitExceeded, rc.Close())
}
func TestLimitedReadCloser_Happy(t *testing.T) {
b := &closer{Reader: bytes.NewBufferString("ho")}
rc := NewLimitedReadCloser(b, 2)
out, err := ioutil.ReadAll(rc)
require.NoError(t, err)
assert.Equal(t, []byte("ho"), out)
assert.Nil(t, err)
}
func TestLimitedReadCloseWithErrorAndLimitExceeded(t *testing.T) {
b := &closer{
Reader: bytes.NewBufferString("howdy"),
err: errors.New("some error"),
}
rc := NewLimitedReadCloser(b, 3)
out, err := ioutil.ReadAll(rc)
require.NoError(t, err)
assert.Equal(t, []byte("how"), out)
// LimitExceeded error trumps the close error.
assert.Equal(t, ErrReadLimitExceeded, rc.Close())
}
func TestLimitedReadCloseWithError(t *testing.T) {
closeErr := errors.New("some error")
b := &closer{
Reader: bytes.NewBufferString("howdy"),
err: closeErr,
}
rc := NewLimitedReadCloser(b, 10)
out, err := ioutil.ReadAll(rc)
require.NoError(t, err)
assert.Equal(t, []byte("howdy"), out)
assert.Equal(t, closeErr, rc.Close())
}
func TestMultipleCloseOnlyClosesOnce(t *testing.T) {
closeErr := errors.New("some error")
b := &closer{
Reader: bytes.NewBufferString("howdy"),
err: closeErr,
}
rc := NewLimitedReadCloser(b, 10)
out, err := ioutil.ReadAll(rc)
require.NoError(t, err)
assert.Equal(t, []byte("howdy"), out)
assert.Equal(t, closeErr, rc.Close())
assert.Equal(t, closeErr, rc.Close())
assert.Equal(t, 1, b.closeCount)
}
type closer struct {
io.Reader
err error
closeCount int
}
func (c *closer) Close() error {
c.closeCount++
return c.err
}

152
kit/metric/client.go Normal file
View File

@ -0,0 +1,152 @@
package metric
import (
"time"
"github.com/influxdata/influxdb/kit/platform/errors"
"github.com/prometheus/client_golang/prometheus"
)
// REDClient is a metrics client for collection RED metrics.
type REDClient struct {
metrics []metricCollector
}
// New creates a new REDClient.
func New(reg prometheus.Registerer, service string, opts ...ClientOptFn) *REDClient {
opt := metricOpts{
namespace: "service",
service: service,
counterMetrics: map[string]VecOpts{
"call_total": {
Help: "Number of calls",
LabelNames: []string{"method"},
CounterFn: func(vec *prometheus.CounterVec, o CollectFnOpts) {
vec.With(prometheus.Labels{"method": o.Method}).Inc()
},
},
"error_total": {
Help: "Number of errors encountered",
LabelNames: []string{"method", "code"},
CounterFn: func(vec *prometheus.CounterVec, o CollectFnOpts) {
if o.Err != nil {
vec.With(prometheus.Labels{
"method": o.Method,
"code": errors.ErrorCode(o.Err),
}).Inc()
}
},
},
},
histogramMetrics: map[string]VecOpts{
"duration": {
Help: "Duration of calls",
LabelNames: []string{"method"},
HistogramFn: func(vec *prometheus.HistogramVec, o CollectFnOpts) {
vec.
With(prometheus.Labels{"method": o.Method}).
Observe(time.Since(o.Start).Seconds())
},
},
},
}
for _, o := range opts {
o(&opt)
}
client := new(REDClient)
for metricName, vecOpts := range opt.counterMetrics {
client.metrics = append(client.metrics, &counter{
fn: vecOpts.CounterFn,
CounterVec: prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: opt.namespace,
Subsystem: opt.serviceName(),
Name: metricName,
Help: vecOpts.Help,
}, vecOpts.LabelNames),
})
}
for metricName, vecOpts := range opt.histogramMetrics {
client.metrics = append(client.metrics, &histogram{
fn: vecOpts.HistogramFn,
HistogramVec: prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: opt.namespace,
Subsystem: opt.serviceName(),
Name: metricName,
Help: vecOpts.Help,
}, vecOpts.LabelNames),
})
}
reg.MustRegister(client.collectors()...)
return client
}
type RecordFn func(err error, opts ...func(opts *CollectFnOpts)) error
// RecordAdditional provides an extension to the base method, err data provided
// to the metrics.
func RecordAdditional(props map[string]interface{}) func(opts *CollectFnOpts) {
return func(opts *CollectFnOpts) {
opts.AdditionalProps = props
}
}
// Record returns a record fn that is called on any given return err. If an error is encountered
// it will register the err metric. The err is never altered.
func (c *REDClient) Record(method string) RecordFn {
start := time.Now()
return func(err error, opts ...func(opts *CollectFnOpts)) error {
opt := CollectFnOpts{
Method: method,
Start: start,
Err: err,
}
for _, o := range opts {
o(&opt)
}
for _, metric := range c.metrics {
metric.collect(opt)
}
return err
}
}
func (c *REDClient) collectors() []prometheus.Collector {
var collectors []prometheus.Collector
for _, metric := range c.metrics {
collectors = append(collectors, metric)
}
return collectors
}
type metricCollector interface {
prometheus.Collector
collect(o CollectFnOpts)
}
type counter struct {
*prometheus.CounterVec
fn CounterFn
}
func (c *counter) collect(o CollectFnOpts) {
c.fn(c.CounterVec, o)
}
type histogram struct {
*prometheus.HistogramVec
fn HistogramFn
}
func (h *histogram) collect(o CollectFnOpts) {
h.fn(h.HistogramVec, o)
}

View File

@ -0,0 +1,84 @@
package metric
import (
"fmt"
"time"
"github.com/prometheus/client_golang/prometheus"
)
type (
// CollectFnOpts provides arguments to the collect operation of a metric.
CollectFnOpts struct {
Method string
Start time.Time
Err error
AdditionalProps map[string]interface{}
}
CounterFn func(vec *prometheus.CounterVec, o CollectFnOpts)
HistogramFn func(vec *prometheus.HistogramVec, o CollectFnOpts)
// VecOpts expands on the
VecOpts struct {
Name string
Help string
LabelNames []string
CounterFn CounterFn
HistogramFn HistogramFn
}
)
type metricOpts struct {
namespace string
service string
serviceSuffix string
counterMetrics map[string]VecOpts
histogramMetrics map[string]VecOpts
}
func (o metricOpts) serviceName() string {
if o.serviceSuffix != "" {
return fmt.Sprintf("%s_%s", o.service, o.serviceSuffix)
}
return o.service
}
// ClientOptFn is an option used by a metric middleware.
type ClientOptFn func(*metricOpts)
// WithVec sets a new counter vector to be collected.
func WithVec(opts VecOpts) ClientOptFn {
return func(o *metricOpts) {
if opts.CounterFn != nil {
if o.counterMetrics == nil {
o.counterMetrics = make(map[string]VecOpts)
}
o.counterMetrics[opts.Name] = opts
}
}
}
// WithSuffix returns a metric option that applies a suffix to the service name of the metric.
func WithSuffix(suffix string) ClientOptFn {
return func(opts *metricOpts) {
opts.serviceSuffix = suffix
}
}
func ApplyMetricOpts(opts ...ClientOptFn) *metricOpts {
o := metricOpts{}
for _, opt := range opts {
opt(&o)
}
return &o
}
func (o *metricOpts) ApplySuffix(prefix string) string {
if o.serviceSuffix != "" {
return fmt.Sprintf("%s_%s", prefix, o.serviceSuffix)
}
return prefix
}

View File

@ -0,0 +1,9 @@
package errors
// ChronografError is a domain error encountered while processing chronograf requests.
type ChronografError string
// ChronografError returns the string of an error.
func (e ChronografError) Error() string {
return string(e)
}

View File

@ -0,0 +1,264 @@
package errors
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
)
// Some error code constant, ideally we want define common platform codes here
// projects on use platform's error, should have their own central place like this.
// Any time this set of constants changes, you must also update the swagger for Error.properties.code.enum.
const (
EInternal = "internal error"
ENotImplemented = "not implemented"
ENotFound = "not found"
EConflict = "conflict" // action cannot be performed
EInvalid = "invalid" // validation failed
EUnprocessableEntity = "unprocessable entity" // data type is correct, but out of range
EEmptyValue = "empty value"
EUnavailable = "unavailable"
EForbidden = "forbidden"
ETooManyRequests = "too many requests"
EUnauthorized = "unauthorized"
EMethodNotAllowed = "method not allowed"
ETooLarge = "request too large"
)
// Error is the error struct of platform.
//
// Errors may have error codes, human-readable messages,
// and a logical stack trace.
//
// The Code targets automated handlers so that recovery can occur.
// Msg is used by the system operator to help diagnose and fix the problem.
// Op and Err chain errors together in a logical stack trace to
// further help operators.
//
// To create a simple error,
// &Error{
// Code:ENotFound,
// }
// To show where the error happens, add Op.
// &Error{
// Code: ENotFound,
// Op: "bolt.FindUserByID"
// }
// To show an error with a unpredictable value, add the value in Msg.
// &Error{
// Code: EConflict,
// Message: fmt.Sprintf("organization with name %s already exist", aName),
// }
// To show an error wrapped with another error.
// &Error{
// Code:EInternal,
// Err: err,
// }.
type Error struct {
Code string
Msg string
Op string
Err error
}
// NewError returns an instance of an error.
func NewError(options ...func(*Error)) *Error {
err := &Error{}
for _, o := range options {
o(err)
}
return err
}
// WithErrorErr sets the err on the error.
func WithErrorErr(err error) func(*Error) {
return func(e *Error) {
e.Err = err
}
}
// WithErrorCode sets the code on the error.
func WithErrorCode(code string) func(*Error) {
return func(e *Error) {
e.Code = code
}
}
// WithErrorMsg sets the message on the error.
func WithErrorMsg(msg string) func(*Error) {
return func(e *Error) {
e.Msg = msg
}
}
// WithErrorOp sets the message on the error.
func WithErrorOp(op string) func(*Error) {
return func(e *Error) {
e.Op = op
}
}
// Error implements the error interface by writing out the recursive messages.
func (e *Error) Error() string {
if e.Msg != "" && e.Err != nil {
var b strings.Builder
b.WriteString(e.Msg)
b.WriteString(": ")
b.WriteString(e.Err.Error())
return b.String()
} else if e.Msg != "" {
return e.Msg
} else if e.Err != nil {
return e.Err.Error()
}
return fmt.Sprintf("<%s>", e.Code)
}
// ErrorCode returns the code of the root error, if available; otherwise returns EINTERNAL.
func ErrorCode(err error) string {
if err == nil {
return ""
}
e, ok := err.(*Error)
if !ok {
return EInternal
}
if e == nil {
return ""
}
if e.Code != "" {
return e.Code
}
if e.Err != nil {
return ErrorCode(e.Err)
}
return EInternal
}
// ErrorOp returns the op of the error, if available; otherwise return empty string.
func ErrorOp(err error) string {
if err == nil {
return ""
}
e, ok := err.(*Error)
if !ok {
return ""
}
if e == nil {
return ""
}
if e.Op != "" {
return e.Op
}
if e.Err != nil {
return ErrorOp(e.Err)
}
return ""
}
// ErrorMessage returns the human-readable message of the error, if available.
// Otherwise returns a generic error message.
func ErrorMessage(err error) string {
if err == nil {
return ""
}
e, ok := err.(*Error)
if !ok {
return "An internal error has occurred."
}
if e == nil {
return ""
}
if e.Msg != "" {
return e.Msg
}
if e.Err != nil {
return ErrorMessage(e.Err)
}
return "An internal error has occurred."
}
// errEncode an JSON encoding helper that is needed to handle the recursive stack of errors.
type errEncode struct {
Code string `json:"code"` // Code is the machine-readable error code.
Msg string `json:"message,omitempty"` // Msg is a human-readable message.
Op string `json:"op,omitempty"` // Op describes the logical code operation during error.
Err interface{} `json:"error,omitempty"` // Err is a stack of additional errors.
}
// MarshalJSON recursively marshals the stack of Err.
func (e *Error) MarshalJSON() (result []byte, err error) {
ee := errEncode{
Code: e.Code,
Msg: e.Msg,
Op: e.Op,
}
if e.Err != nil {
if _, ok := e.Err.(*Error); ok {
_, err := e.Err.(*Error).MarshalJSON()
if err != nil {
return result, err
}
ee.Err = e.Err
} else {
ee.Err = e.Err.Error()
}
}
return json.Marshal(ee)
}
// UnmarshalJSON recursively unmarshals the error stack.
func (e *Error) UnmarshalJSON(b []byte) (err error) {
ee := new(errEncode)
err = json.Unmarshal(b, ee)
e.Code = ee.Code
e.Msg = ee.Msg
e.Op = ee.Op
e.Err = decodeInternalError(ee.Err)
return err
}
func decodeInternalError(target interface{}) error {
if errStr, ok := target.(string); ok {
return errors.New(errStr)
}
if internalErrMap, ok := target.(map[string]interface{}); ok {
internalErr := new(Error)
if code, ok := internalErrMap["code"].(string); ok {
internalErr.Code = code
}
if msg, ok := internalErrMap["message"].(string); ok {
internalErr.Msg = msg
}
if op, ok := internalErrMap["op"].(string); ok {
internalErr.Op = op
}
internalErr.Err = decodeInternalError(internalErrMap["error"])
return internalErr
}
return nil
}
// HTTPErrorHandler is the interface to handle http error.
type HTTPErrorHandler interface {
HandleHTTPError(ctx context.Context, err error, w http.ResponseWriter)
}

View File

@ -0,0 +1,86 @@
# errors.go
This is inspired from Ben Johnson's blog post [Failure is Your Domain](https://middlemost.com/failure-is-your-domain/)
## The Error struct
```go
type Error struct {
Code string
Msg string
Op string
Err error
}
```
* Code is the machine readable code, for reference purpose. All the codes should be a constant string. For example. `const ENotFound = "source not found"`.
* Msg is the human readable message for end user. For example, `Your credit card is declined.`
* Op is the logical Operator, should be a constant defined inside the function. For example: "bolt.UserCreate".
* Err is the embed error. You may embed either a third party error or and platform.Error.
## Use Case Example
We implement the following interface
```go
type OrganizationService interface {
FindOrganizationByID(ctx context.Context, id ID) (*Organization, error)
}
func (c *Client)FindOrganizationByID(ctx context.Context, id platform.ID) (*platform.Organization, error) {
var o *platform.Organization
const op = "bolt.FindOrganizationByID"
err := c.db.View(func(tx *bolt.Tx) error {
org, err := c.findOrganizationByID(ctx, tx, id)
if err != nil {
return err
}
o = org
return nil
})
if err != nil {
return nil, &platform.Error{
Code: platform.ENotFound,
Op: op,
Err: err,
}
}
return o, nil
}
```
To check the error code
```go
if platform.ErrorCode(err) == platform.ENotFound {
...
}
```
To serialize the error
```go
b, err := json.Marshal(err)
```
To deserialize the error
```go
e := new(platform.Error)
err := json.Unmarshal(b, e)
```

View File

@ -0,0 +1,300 @@
package errors_test
import (
"encoding/json"
"errors"
"fmt"
"testing"
errors2 "github.com/influxdata/influxdb/kit/platform/errors"
)
const EFailedToGetStorageHost = "failed to get the storage host"
func TestErrorMsg(t *testing.T) {
cases := []struct {
name string
err error
msg string
}{
{
name: "simple error",
err: &errors2.Error{Code: errors2.ENotFound},
msg: "<not found>",
},
{
name: "with message",
err: &errors2.Error{
Code: errors2.ENotFound,
Msg: "bucket not found",
},
msg: "bucket not found",
},
{
name: "with a third party error and no message",
err: &errors2.Error{
Code: EFailedToGetStorageHost,
Err: errors.New("empty value"),
},
msg: "empty value",
},
{
name: "with a third party error and a message",
err: &errors2.Error{
Code: EFailedToGetStorageHost,
Msg: "failed to get storage hosts",
Err: errors.New("empty value"),
},
msg: "failed to get storage hosts: empty value",
},
{
name: "with an internal error and no message",
err: &errors2.Error{
Code: EFailedToGetStorageHost,
Err: &errors2.Error{
Code: errors2.EEmptyValue,
Msg: "empty value",
},
},
msg: "empty value",
},
{
name: "with an internal error and a message",
err: &errors2.Error{
Code: EFailedToGetStorageHost,
Msg: "failed to get storage hosts",
Err: &errors2.Error{
Code: errors2.EEmptyValue,
Msg: "empty value",
},
},
msg: "failed to get storage hosts: empty value",
},
}
for _, c := range cases {
if c.msg != c.err.Error() {
t.Errorf("%s failed, want %s, got %s", c.name, c.msg, c.err.Error())
}
}
}
func TestErrorMessage(t *testing.T) {
cases := []struct {
name string
err error
want string
}{
{
name: "nil error",
},
{
name: "nil error of type *platform.Error",
err: (*errors2.Error)(nil),
},
{
name: "simple error",
err: &errors2.Error{Msg: "simple error"},
want: "simple error",
},
{
name: "embedded error",
err: &errors2.Error{Err: &errors2.Error{Msg: "embedded error"}},
want: "embedded error",
},
{
name: "default error",
err: errors.New("s"),
want: "An internal error has occurred.",
},
}
for _, c := range cases {
if result := errors2.ErrorMessage(c.err); c.want != result {
t.Errorf("%s failed, want %s, got %s", c.name, c.want, result)
}
}
}
func TestErrorOp(t *testing.T) {
cases := []struct {
name string
err error
want string
}{
{
name: "nil error",
},
{
name: "nil error of type *platform.Error",
err: (*errors2.Error)(nil),
},
{
name: "simple error",
err: &errors2.Error{Op: "op1"},
want: "op1",
},
{
name: "embedded error",
err: &errors2.Error{Op: "op1", Err: &errors2.Error{Code: errors2.EInvalid}},
want: "op1",
},
{
name: "embedded error without op in root level",
err: &errors2.Error{Err: &errors2.Error{Code: errors2.EInvalid, Op: "op2"}},
want: "op2",
},
{
name: "default error",
err: errors.New("s"),
want: "",
},
}
for _, c := range cases {
if result := errors2.ErrorOp(c.err); c.want != result {
t.Errorf("%s failed, want %s, got %s", c.name, c.want, result)
}
}
}
func TestErrorCode(t *testing.T) {
cases := []struct {
name string
err error
want string
}{
{
name: "nil error",
},
{
name: "nil error of type *platform.Error",
err: (*errors2.Error)(nil),
},
{
name: "simple error",
err: &errors2.Error{Code: errors2.ENotFound},
want: errors2.ENotFound,
},
{
name: "embedded error",
err: &errors2.Error{Code: errors2.ENotFound, Err: &errors2.Error{Code: errors2.EInvalid}},
want: errors2.ENotFound,
},
{
name: "embedded error with root level code",
err: &errors2.Error{Err: &errors2.Error{Code: errors2.EInvalid}},
want: errors2.EInvalid,
},
{
name: "default error",
err: errors.New("s"),
want: errors2.EInternal,
},
}
for _, c := range cases {
if result := errors2.ErrorCode(c.err); c.want != result {
t.Errorf("%s failed, want %s, got %s", c.name, c.want, result)
}
}
}
func TestJSON(t *testing.T) {
cases := []struct {
name string
err *errors2.Error
encoded string
}{
{
name: "simple error",
err: &errors2.Error{Code: errors2.ENotFound},
encoded: `{"code":"not found"}`,
},
{
name: "with op",
err: &errors2.Error{
Code: errors2.ENotFound,
Op: "bolt.FindAuthorizationByID",
},
encoded: `{"code":"not found","op":"bolt.FindAuthorizationByID"}`,
},
{
name: "with op and value",
err: &errors2.Error{
Code: errors2.ENotFound,
Op: "bolt/FindAuthorizationByID",
Msg: fmt.Sprintf("with ID %d", 323),
},
encoded: `{"code":"not found","message":"with ID 323","op":"bolt/FindAuthorizationByID"}`,
},
{
name: "with a third party error",
err: &errors2.Error{
Code: EFailedToGetStorageHost,
Op: "cmd/fluxd.injectDeps",
Err: errors.New("empty value"),
},
encoded: `{"code":"failed to get the storage host","op":"cmd/fluxd.injectDeps","error":"empty value"}`,
},
{
name: "with a internal error",
err: &errors2.Error{
Code: EFailedToGetStorageHost,
Op: "cmd/fluxd.injectDeps",
Err: &errors2.Error{Code: errors2.EEmptyValue, Op: "cmd/fluxd.getStrList"},
},
encoded: `{"code":"failed to get the storage host","op":"cmd/fluxd.injectDeps","error":{"code":"empty value","op":"cmd/fluxd.getStrList"}}`,
},
{
name: "with a deep internal error",
err: &errors2.Error{
Code: EFailedToGetStorageHost,
Op: "cmd/fluxd.injectDeps",
Err: &errors2.Error{
Code: errors2.EInvalid,
Op: "cmd/fluxd.getStrList",
Err: &errors2.Error{
Code: errors2.EEmptyValue,
Err: errors.New("an err"),
},
},
},
encoded: `{"code":"failed to get the storage host","op":"cmd/fluxd.injectDeps","error":{"code":"invalid","op":"cmd/fluxd.getStrList","error":{"code":"empty value","error":"an err"}}}`,
},
}
for _, c := range cases {
result, err := json.Marshal(c.err)
// encode testing
if err != nil {
t.Errorf("%s encode failed, want err: %v, should be nil", c.name, err)
}
if string(result) != c.encoded {
t.Errorf("%s encode failed, want result: %s, got %s", c.name, c.encoded, string(result))
}
// decode testing
got := new(errors2.Error)
err = json.Unmarshal(result, got)
if err != nil {
t.Errorf("%s decode failed, want err: %v, should be nil", c.name, err)
}
decodeEqual(t, c.err, got, "decode: "+c.name)
}
}
func decodeEqual(t *testing.T, want, result *errors2.Error, caseName string) {
if want.Code != result.Code {
t.Errorf("%s code failed, want %s, got %s", caseName, want.Code, result.Code)
}
if want.Op != result.Op {
t.Errorf("%s op failed, want %s, got %s", caseName, want.Op, result.Op)
}
if want.Msg != result.Msg {
t.Errorf("%s msg failed, want %s, got %s", caseName, want.Msg, result.Msg)
}
if want.Err != nil {
if _, ok := want.Err.(*errors2.Error); ok {
decodeEqual(t, want.Err.(*errors2.Error), result.Err.(*errors2.Error), caseName)
} else {
if want.Err.Error() != result.Err.Error() {
t.Errorf("%s Err failed, want %s, got %s", caseName, want.Err.Error(), result.Err.Error())
}
}
}
}

145
kit/platform/id.go Normal file
View File

@ -0,0 +1,145 @@
package platform
import (
"encoding/binary"
"encoding/hex"
"strconv"
"unsafe"
"github.com/influxdata/influxdb/kit/platform/errors"
)
// IDLength is the exact length a string (or a byte slice representing it) must have in order to be decoded into a valid ID.
const IDLength = 16
var (
// ErrInvalidID signifies invalid IDs.
ErrInvalidID = &errors.Error{
Code: errors.EInvalid,
Msg: "invalid ID",
}
// ErrInvalidIDLength is returned when an ID has the incorrect number of bytes.
ErrInvalidIDLength = &errors.Error{
Code: errors.EInvalid,
Msg: "id must have a length of 16 bytes",
}
)
// ErrCorruptID means the ID stored in the Store is corrupt.
func ErrCorruptID(err error) *errors.Error {
return &errors.Error{
Code: errors.EInvalid,
Msg: "corrupt ID provided",
Err: err,
}
}
// ID is a unique identifier.
//
// Its zero value is not a valid ID.
type ID uint64
// IDGenerator represents a generator for IDs.
type IDGenerator interface {
// ID creates unique byte slice ID.
ID() ID
}
// IDFromString creates an ID from a given string.
//
// It errors if the input string does not match a valid ID.
func IDFromString(str string) (*ID, error) {
var id ID
err := id.DecodeFromString(str)
if err != nil {
return nil, err
}
return &id, nil
}
// InvalidID returns a zero ID.
func InvalidID() ID {
return 0
}
// Decode parses b as a hex-encoded byte-slice-string.
//
// It errors if the input byte slice does not have the correct length
// or if it contains all zeros.
func (i *ID) Decode(b []byte) error {
if len(b) != IDLength {
return ErrInvalidIDLength
}
res, err := strconv.ParseUint(unsafeBytesToString(b), 16, 64)
if err != nil {
return ErrInvalidID
}
if *i = ID(res); !i.Valid() {
return ErrInvalidID
}
return nil
}
func unsafeBytesToString(in []byte) string {
return *(*string)(unsafe.Pointer(&in))
}
// DecodeFromString parses s as a hex-encoded string.
func (i *ID) DecodeFromString(s string) error {
return i.Decode([]byte(s))
}
// Encode converts ID to a hex-encoded byte-slice-string.
//
// It errors if the receiving ID holds its zero value.
func (i ID) Encode() ([]byte, error) {
if !i.Valid() {
return nil, ErrInvalidID
}
b := make([]byte, hex.DecodedLen(IDLength))
binary.BigEndian.PutUint64(b, uint64(i))
dst := make([]byte, hex.EncodedLen(len(b)))
hex.Encode(dst, b)
return dst, nil
}
// Valid checks whether the receiving ID is a valid one or not.
func (i ID) Valid() bool {
return i != 0
}
// String returns the ID as a hex encoded string.
//
// Returns an empty string in the case the ID is invalid.
func (i ID) String() string {
enc, _ := i.Encode()
return string(enc)
}
// GoString formats the ID the same as the String method.
// Without this, when using the %#v verb, an ID would be printed as a uint64,
// so you would see e.g. 0x2def021097c6000 instead of 02def021097c6000
// (note the leading 0x, which means the former doesn't show up in searches for the latter).
func (i ID) GoString() string {
return `"` + i.String() + `"`
}
// MarshalText encodes i as text.
// Providing this method is a fallback for json.Marshal,
// with the added benefit that IDs encoded as map keys will be the expected string encoding,
// rather than the effective fmt.Sprintf("%d", i) that json.Marshal uses by default for integer types.
func (i ID) MarshalText() ([]byte, error) {
return i.Encode()
}
// UnmarshalText decodes i from a byte slice.
// Providing this method is also a fallback for json.Unmarshal,
// also relevant when IDs are used as map keys.
func (i *ID) UnmarshalText(b []byte) error {
return i.Decode(b)
}

259
kit/platform/id_test.go Normal file
View File

@ -0,0 +1,259 @@
package platform_test
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"testing"
"github.com/influxdata/influxdb/kit/platform"
)
func MustIDBase16(s string) platform.ID {
id, err := platform.IDFromString(s)
if err != nil {
panic(err)
}
return *id
}
func TestIDFromString(t *testing.T) {
tests := []struct {
name string
id string
want platform.ID
wantErr bool
err string
}{
{
name: "Should be able to decode an all zeros ID",
id: "0000000000000000",
wantErr: true,
err: platform.ErrInvalidID.Error(),
},
{
name: "Should be able to decode an all f ID",
id: "ffffffffffffffff",
want: MustIDBase16("ffffffffffffffff"),
},
{
name: "Should be able to decode an ID",
id: "020f755c3c082000",
want: MustIDBase16("020f755c3c082000"),
},
{
name: "Should not be able to decode a non hex ID",
id: "gggggggggggggggg",
wantErr: true,
err: platform.ErrInvalidID.Error(),
},
{
name: "Should not be able to decode inputs with length less than 16 bytes",
id: "abc",
wantErr: true,
err: platform.ErrInvalidIDLength.Error(),
},
{
name: "Should not be able to decode inputs with length greater than 16 bytes",
id: "abcdabcdabcdabcd0",
wantErr: true,
err: platform.ErrInvalidIDLength.Error(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := platform.IDFromString(tt.id)
// Check negative test cases
if (err != nil) && tt.wantErr {
if tt.err != err.Error() {
t.Errorf("IDFromString() errors out \"%s\", want \"%s\"", err, tt.err)
}
return
}
// Check positive test cases
if !reflect.DeepEqual(*got, tt.want) && !tt.wantErr {
t.Errorf("IDFromString() outputs %v, want %v", got, tt.want)
}
})
}
}
func TestDecodeFromString(t *testing.T) {
var id platform.ID
err := id.DecodeFromString("020f755c3c082000")
if err != nil {
t.Errorf(err.Error())
}
want := []byte{48, 50, 48, 102, 55, 53, 53, 99, 51, 99, 48, 56, 50, 48, 48, 48}
got, _ := id.Encode()
if !bytes.Equal(want, got) {
t.Errorf("got %s not equal to wanted %s", string(got), string(want))
}
if id.String() != "020f755c3c082000" {
t.Errorf("expecting string representation to contain the right value")
}
if !id.Valid() {
t.Errorf("expecting ID to be a valid one")
}
}
func TestEncode(t *testing.T) {
var id platform.ID
if _, err := id.Encode(); err == nil {
t.Errorf("encoding an invalid ID should not be possible")
}
id.DecodeFromString("5ca1ab1eba5eba11")
want := []byte{53, 99, 97, 49, 97, 98, 49, 101, 98, 97, 53, 101, 98, 97, 49, 49}
got, _ := id.Encode()
if !bytes.Equal(want, got) {
t.Errorf("encoding error")
}
if id.String() != "5ca1ab1eba5eba11" {
t.Errorf("expecting string representation to contain the right value")
}
if !id.Valid() {
t.Errorf("expecting ID to be a valid one")
}
}
func TestDecodeFromAllZeros(t *testing.T) {
var id platform.ID
err := id.Decode(make([]byte, platform.IDLength))
if err == nil {
t.Errorf("expecting all zeros ID to not be a valid ID")
}
}
func TestDecodeFromShorterString(t *testing.T) {
var id platform.ID
err := id.DecodeFromString("020f75")
if err == nil {
t.Errorf("expecting shorter inputs to error")
}
if id.String() != "" {
t.Errorf("expecting invalid ID to be serialized into empty string")
}
}
func TestDecodeFromLongerString(t *testing.T) {
var id platform.ID
err := id.DecodeFromString("020f755c3c082000aaa")
if err == nil {
t.Errorf("expecting shorter inputs to error")
}
if id.String() != "" {
t.Errorf("expecting invalid ID to be serialized into empty string")
}
}
func TestDecodeFromEmptyString(t *testing.T) {
var id platform.ID
err := id.DecodeFromString("")
if err == nil {
t.Errorf("expecting empty inputs to error")
}
if id.String() != "" {
t.Errorf("expecting invalid ID to be serialized into empty string")
}
}
func TestMarshalling(t *testing.T) {
var id0 platform.ID
_, err := json.Marshal(id0)
if err == nil {
t.Errorf("expecting empty ID to not be a valid one")
}
init := "ca55e77eca55e77e"
id1, err := platform.IDFromString(init)
if err != nil {
t.Errorf(err.Error())
}
serialized, err := json.Marshal(id1)
if err != nil {
t.Errorf(err.Error())
}
var id2 platform.ID
json.Unmarshal(serialized, &id2)
bytes1, _ := id1.Encode()
bytes2, _ := id2.Encode()
if !bytes.Equal(bytes1, bytes2) {
t.Errorf("error marshalling/unmarshalling ID")
}
// When used as a map key, IDs must use their string encoding.
// If you only implement json.Marshaller, they will be encoded with Go's default integer encoding.
b, err := json.Marshal(map[platform.ID]int{0x1234: 5678})
if err != nil {
t.Error(err)
}
const exp = `{"0000000000001234":5678}`
if string(b) != exp {
t.Errorf("expected map to json.Marshal as %s; got %s", exp, string(b))
}
var idMap map[platform.ID]int
if err := json.Unmarshal(b, &idMap); err != nil {
t.Error(err)
}
if len(idMap) != 1 {
t.Errorf("expected length 1, got %d", len(idMap))
}
if idMap[0x1234] != 5678 {
t.Errorf("unmarshalled incorrectly; exp 0x1234:5678, got %v", idMap)
}
}
func TestValid(t *testing.T) {
var id platform.ID
if id.Valid() {
t.Errorf("expecting initial ID to be invalid")
}
if platform.InvalidID() != 0 {
t.Errorf("expecting invalid ID to return a zero ID, thus invalid")
}
}
func TestID_GoString(t *testing.T) {
type idGoStringTester struct {
ID platform.ID
}
var x idGoStringTester
const idString = "02def021097c6000"
if err := x.ID.DecodeFromString(idString); err != nil {
t.Fatal(err)
}
sharpV := fmt.Sprintf("%#v", x)
want := `platform_test.idGoStringTester{ID:"` + idString + `"}`
if sharpV != want {
t.Fatalf("bad GoString: got %q, want %q", sharpV, want)
}
}
func BenchmarkIDEncode(b *testing.B) {
var id platform.ID
id.DecodeFromString("5ca1ab1eba5eba11")
b.ResetTimer()
for i := 0; i < b.N; i++ {
b, _ := id.Encode()
_ = b
}
}
func BenchmarkIDDecode(b *testing.B) {
for i := 0; i < b.N; i++ {
var id platform.ID
id.DecodeFromString("5ca1ab1eba5eba11")
}
}

87
kit/prom/example_test.go Normal file
View File

@ -0,0 +1,87 @@
package prom_test
import (
"fmt"
"io"
"math/rand"
"net/http"
"time"
"github.com/influxdata/influxdb/kit/prom"
"github.com/prometheus/client_golang/prometheus"
"go.uber.org/zap"
)
// RandomHandler implements an HTTP endpoint that prints a random float,
// and it tracks prometheus metrics about the numbers it returns.
type RandomHandler struct {
// Cumulative sum of values served.
valueCounter prometheus.Counter
// Total times page served.
serveCounter prometheus.Counter
}
var (
_ http.Handler = (*RandomHandler)(nil)
_ prom.PrometheusCollector = (*RandomHandler)(nil)
)
func NewRandomHandler() *RandomHandler {
return &RandomHandler{
valueCounter: prometheus.NewCounter(prometheus.CounterOpts{
Name: "value_counter",
Help: "Cumulative sum of values served.",
}),
serveCounter: prometheus.NewCounter(prometheus.CounterOpts{
Name: "serve_counter",
Help: "Counter of times page has been served.",
}),
}
}
// ServeHTTP serves a random float value and updates rh's internal metrics.
func (rh *RandomHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Increment serveCounter every time we serve a page.
rh.serveCounter.Inc()
n := rand.Float64()
// Track the cumulative values served.
rh.valueCounter.Add(n)
fmt.Fprintf(w, "%v", n)
}
// PrometheusCollectors implements prom.PrometheusCollector.
func (rh *RandomHandler) PrometheusCollectors() []prometheus.Collector {
return []prometheus.Collector{rh.valueCounter, rh.serveCounter}
}
func Example() {
// A collection of endpoints and http.Handlers.
handlers := map[string]http.Handler{
"/random": NewRandomHandler(),
"/time": http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, time.Now().String())
}),
}
// Use a local registry, not the global registry in the prometheus package.
reg := prom.NewRegistry(zap.NewNop())
// Build the mux out of handlers from above.
mux := http.NewServeMux()
for path, h := range handlers {
mux.Handle(path, h)
// Only register those handlers which implement prom.PrometheusCollector.
if pc, ok := h.(prom.PrometheusCollector); ok {
reg.MustRegister(pc.PrometheusCollectors()...)
}
}
// Add metrics to registry.
mux.Handle("/metrics", reg.HTTPHandler())
http.ListenAndServe("localhost:8080", mux)
}

View File

@ -0,0 +1,146 @@
// Package promtest provides helpers for parsing and extracting prometheus metrics.
// These functions are only intended to be called from test files,
// as there is a dependency on the standard library testing package.
package promtest
import (
"fmt"
"io"
"net/http"
"strings"
"testing"
"github.com/prometheus/client_golang/prometheus"
dto "github.com/prometheus/client_model/go"
"github.com/prometheus/common/expfmt"
)
// FromHTTPResponse parses the prometheus metrics from the given *http.Response.
// It relies on properly set response headers to correctly parse.
// It will unconditionally close the response body.
//
// This is particularly helpful when testing the output of the /metrics endpoint of a service.
// However, for comprehensive testing of metrics, it usually makes more sense to
// add collectors to a registry and call Registry.Gather to get the metrics without involving HTTP.
func FromHTTPResponse(r *http.Response) ([]*dto.MetricFamily, error) {
defer r.Body.Close()
dec := expfmt.NewDecoder(r.Body, expfmt.ResponseFormat(r.Header))
var mfs []*dto.MetricFamily
for {
mf := new(dto.MetricFamily)
if err := dec.Decode(mf); err != nil {
if err == io.EOF {
break
} else {
return nil, err
}
}
mfs = append(mfs, mf)
}
return mfs, nil
}
// FindMetric iterates through mfs to find the first metric family matching name.
// If a metric family matches, then the metrics inside the family are searched,
// and the first metric whose labels match the given labels are returned.
// If no matches are found, FindMetric returns nil.
//
// FindMetric assumes that the labels on the metric family are well formed,
// i.e. there are no duplicate label names, and the label values are not empty strings.
func FindMetric(mfs []*dto.MetricFamily, name string, labels map[string]string) *dto.Metric {
_, m := findMetric(mfs, name, labels)
return m
}
// MustFindMetric returns the matching metric, or if no matching metric could be found,
// it calls tb.Log with helpful output of what was actually available, before calling tb.FailNow.
func MustFindMetric(tb testing.TB, mfs []*dto.MetricFamily, name string, labels map[string]string) *dto.Metric {
tb.Helper()
fam, m := findMetric(mfs, name, labels)
if fam == nil {
tb.Logf("metric family with name %q not found", name)
tb.Log("available names:")
for _, mf := range mfs {
tb.Logf("\t%s", mf.GetName())
}
tb.FailNow()
return nil // Need an explicit return here for test.
}
if m == nil {
tb.Logf("found metric family with name %q, but metric with labels %v not found", name, labels)
tb.Logf("available labels on metric family %q:", name)
for _, m := range fam.Metric {
pairs := make([]string, len(m.Label))
for i, l := range m.Label {
pairs[i] = fmt.Sprintf("%q: %q", l.GetName(), l.GetValue())
}
tb.Logf("\t%s", strings.Join(pairs, ", "))
}
tb.FailNow()
return nil // Need an explicit return here for test.
}
return m
}
// findMetric is a helper that returns the matching family and the matching metric.
// The exported FindMetric function specifically only finds the metric, not the family,
// but for test it is more helpful to identify whether the family was matched.
func findMetric(mfs []*dto.MetricFamily, name string, labels map[string]string) (*dto.MetricFamily, *dto.Metric) {
var fam *dto.MetricFamily
for _, mf := range mfs {
if mf.GetName() == name {
fam = mf
break
}
}
if fam == nil {
// No family matching the name.
return nil, nil
}
for _, m := range fam.Metric {
if len(m.Label) != len(labels) {
continue
}
match := true
for _, l := range m.Label {
if labels[l.GetName()] != l.GetValue() {
match = false
break
}
}
if !match {
continue
}
// All labels matched.
return fam, m
}
// Didn't find a metric whose labels all matched.
return fam, nil
}
// MustGather calls g.Gather and calls tb.Fatal if there was an error.
func MustGather(tb testing.TB, g prometheus.Gatherer) []*dto.MetricFamily {
tb.Helper()
mfs, err := g.Gather()
if err != nil {
tb.Fatalf("error while gathering metrics: %v", err)
return nil // Need an explicit return here for test.
}
return mfs
}

View File

@ -0,0 +1,210 @@
package promtest_test
import (
"bytes"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/influxdata/influxdb/kit/prom"
"github.com/influxdata/influxdb/kit/prom/promtest"
"github.com/prometheus/client_golang/prometheus"
dto "github.com/prometheus/client_model/go"
"go.uber.org/zap/zaptest"
)
func helperCollectors() []prometheus.Collector {
myCounter := prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "my",
Subsystem: "random",
Name: "counter",
Help: "Just a random counter.",
})
myGaugeVec := prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: "my",
Subsystem: "random",
Name: "gaugevec",
Help: "Just a random gauge vector.",
}, []string{"label1", "label2"})
myCounter.Inc()
myGaugeVec.WithLabelValues("one", "two").Set(3)
return []prometheus.Collector{myCounter, myGaugeVec}
}
func TestFindMetric(t *testing.T) {
reg := prom.NewRegistry(zaptest.NewLogger(t))
reg.MustRegister(helperCollectors()...)
mfs, err := reg.Gather()
if err != nil {
t.Fatal(err)
}
c := promtest.MustFindMetric(t, mfs, "my_random_counter", nil)
if got := c.GetCounter().GetValue(); got != 1 {
t.Fatalf("expected counter to be 1, got %v", got)
}
g := promtest.MustFindMetric(t, mfs, "my_random_gaugevec", map[string]string{"label1": "one", "label2": "two"})
if got := g.GetGauge().GetValue(); got != 3 {
t.Fatalf("expected gauge to be 3, got %v", got)
}
}
// fakeT helps us to assert that MustFindMetric calls FailNow when the metric isn't found.
type fakeT struct {
// Embed a T so we don't have to reimplement everything.
// It's fine to leave T nil - fakeT will panic if calling a method we haven't implemented.
*testing.T
logBuf bytes.Buffer
failed bool
}
func (t *fakeT) Helper() {}
func (t *fakeT) Log(args ...interface{}) {
fmt.Fprint(&t.logBuf, args...)
t.logBuf.WriteString("\n")
}
func (t *fakeT) Logf(format string, args ...interface{}) {
fmt.Fprintf(&t.logBuf, format, args...)
t.logBuf.WriteString("\n")
}
func (t *fakeT) FailNow() {
t.failed = true
}
func (t *fakeT) Fatalf(format string, args ...interface{}) {
t.Logf(format, args...)
t.FailNow()
}
func TestMustFindMetric(t *testing.T) {
reg := prom.NewRegistry(zaptest.NewLogger(t))
reg.MustRegister(helperCollectors()...)
mfs, err := reg.Gather()
if err != nil {
t.Fatal(err)
}
ft := new(fakeT)
// Doesn't log when metric is found.
_ = promtest.MustFindMetric(ft, mfs, "my_random_counter", nil)
if ft.failed {
t.Fatalf("MustFindMetric failed when it should not have. message: %s", ft.logBuf.String())
}
// Logs and fails when family name not found.
ft = new(fakeT)
_ = promtest.MustFindMetric(ft, mfs, "missing_name", nil)
if !ft.failed {
t.Fatal("MustFindMetric should have failed but didn't")
}
logged := ft.logBuf.String()
if !strings.Contains(logged, `name "missing_name" not found`) {
t.Fatalf("did not log the looked up name which was not found. message: %s", logged)
}
if !strings.Contains(logged, "my_random_counter") || !strings.Contains(logged, "my_random_gaugevec") {
t.Fatalf("did not log the available metric names. message: %s", logged)
}
// Logs and fails when family name found but metric labels mismatch.
ft = new(fakeT)
_ = promtest.MustFindMetric(ft, mfs, "my_random_counter", map[string]string{"unknown": "label"})
if !ft.failed {
t.Fatal("MustFindMetric should have failed but didn't")
}
ft = new(fakeT)
_ = promtest.MustFindMetric(ft, mfs, "my_random_gaugevec", map[string]string{"unknown": "label"})
if !ft.failed {
t.Fatal("MustFindMetric should have failed but didn't")
}
logged = ft.logBuf.String()
if !strings.Contains(logged, `"label1": "one"`) || !strings.Contains(logged, `"label2": "two"`) {
t.Fatalf("did not log the available label names. message: %s", logged)
}
ft = new(fakeT)
_ = promtest.MustFindMetric(ft, mfs, "my_random_gaugevec", map[string]string{"label1": "one", "label2": "two", "label3": "imaginary"})
if !ft.failed {
t.Fatal("MustFindMetric should have failed but didn't")
}
logged = ft.logBuf.String()
if !strings.Contains(logged, `"label1": "one"`) || !strings.Contains(logged, `"label2": "two"`) {
t.Fatalf("did not log the available label names. message: %s", logged)
}
}
func TestMustGather(t *testing.T) {
expErr := errors.New("failed to gather")
g := prometheus.GathererFunc(func() ([]*dto.MetricFamily, error) {
return nil, expErr
})
ft := new(fakeT)
_ = promtest.MustGather(ft, g)
if !ft.failed {
t.Fatal("MustGather should have failed but didn't")
}
logged := ft.logBuf.String()
if !strings.HasPrefix(logged, "error while gathering metrics:") || !strings.Contains(logged, expErr.Error()) {
t.Fatalf("did not log the expected error message: %s", logged)
}
expMF := []*dto.MetricFamily{} // Use a non-nil, zero-length slice for a simple-ish check.
g = prometheus.GathererFunc(func() ([]*dto.MetricFamily, error) {
return expMF, nil
})
ft = new(fakeT)
gotMF := promtest.MustGather(ft, g)
if ft.failed {
t.Fatalf("MustGather should not have failed")
}
if gotMF == nil || len(gotMF) != 0 {
t.Fatalf("exp: %v, got: %v", expMF, gotMF)
}
}
func TestFromHTTPResponse(t *testing.T) {
reg := prom.NewRegistry(zaptest.NewLogger(t))
reg.MustRegister(helperCollectors()...)
s := httptest.NewServer(reg.HTTPHandler())
defer s.Close()
resp, err := http.Get(s.URL) // Didn't specify a path for the handler, so any path should be fine.
if err != nil {
t.Fatal(err)
}
mfs, err := promtest.FromHTTPResponse(resp)
if err != nil {
t.Fatal(err)
}
if len(mfs) != 2 {
t.Fatalf("expected 2 metrics but got %d", len(mfs))
}
c := promtest.MustFindMetric(t, mfs, "my_random_counter", nil)
if got := c.GetCounter().GetValue(); got != 1 {
t.Fatalf("expected counter to be 1, got %v", got)
}
g := promtest.MustFindMetric(t, mfs, "my_random_gaugevec", map[string]string{"label1": "one", "label2": "two"})
if got := g.GetGauge().GetValue(); got != 3 {
t.Fatalf("expected gauge to be 3, got %v", got)
}
}

59
kit/prom/registry.go Normal file
View File

@ -0,0 +1,59 @@
// Package prom provides a wrapper around a prometheus metrics registry
// so that all services are unified in how they expose prometheus metrics.
package prom
import (
"net/http"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.uber.org/zap"
)
// PrometheusCollector is the interface for a type to expose prometheus metrics.
// This interface is provided as a convention, so that you can optionally check
// if a type implements it and then pass its collectors to (*Registry).MustRegister.
type PrometheusCollector interface {
// PrometheusCollectors returns a slice of prometheus collectors
// containing metrics for the underlying instance.
PrometheusCollectors() []prometheus.Collector
}
// Registry embeds a prometheus registry and adds a couple convenience methods.
type Registry struct {
*prometheus.Registry
log *zap.Logger
}
// NewRegistry returns a new registry.
func NewRegistry(log *zap.Logger) *Registry {
return &Registry{
Registry: prometheus.NewRegistry(),
log: log,
}
}
// HTTPHandler returns an http.Handler for the registry,
// so that the /metrics HTTP handler is uniformly configured across all apps in the platform.
func (r *Registry) HTTPHandler() http.Handler {
opts := promhttp.HandlerOpts{
ErrorLog: promLogger{r: r},
// TODO(mr): decide if we want to set MaxRequestsInFlight or Timeout.
}
return promhttp.HandlerFor(r.Registry, opts)
}
// promLogger satisfies the promhttp.logger interface with the registry.
// Because normal usage is that WithLogger is called after HTTPHandler,
// we refer to the Registry rather than its logger.
type promLogger struct {
r *Registry
}
var _ promhttp.Logger = (*promLogger)(nil)
// Println implements promhttp.logger.
func (pl promLogger) Println(v ...interface{}) {
pl.r.log.Sugar().Info(v...)
}

60
kit/prom/registry_test.go Normal file
View File

@ -0,0 +1,60 @@
package prom_test
import (
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/influxdata/influxdb/kit/prom"
"github.com/prometheus/client_golang/prometheus"
"go.uber.org/zap"
"go.uber.org/zap/zaptest/observer"
)
func TestRegistry_Logger(t *testing.T) {
core, logs := observer.New(zap.DebugLevel)
reg := prom.NewRegistry(zap.New(core))
// Normal use: HTTP handler is created immediately...
s := httptest.NewServer(reg.HTTPHandler())
defer s.Close()
// Force an error with a fake collector.
reg.MustRegister(errorCollector{})
resp, err := http.Get(s.URL)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
foundLog := false
for _, le := range logs.All() {
if strings.Contains(le.Message, "invalid metric from errorCollector") {
foundLog = true
break
}
}
if !foundLog {
t.Fatalf("registry logger did not log error from metric collection")
}
}
type errorCollector struct{}
var _ prometheus.Collector = errorCollector{}
var ecDesc = prometheus.NewDesc("error_collector_desc", "A required description for the error collector", nil, nil)
func (errorCollector) Describe(ch chan<- *prometheus.Desc) {
ch <- ecDesc
}
func (errorCollector) Collect(ch chan<- prometheus.Metric) {
ch <- prometheus.NewInvalidMetric(
ecDesc,
errors.New("invalid metric from errorCollector"),
)
}

31
kit/signals/context.go Normal file
View File

@ -0,0 +1,31 @@
package signals
import (
"context"
"os"
"os/signal"
"syscall"
)
// WithSignals returns a context that is canceled with any signal in sigs.
func WithSignals(ctx context.Context, sigs ...os.Signal) context.Context {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, sigs...)
ctx, cancel := context.WithCancel(ctx)
go func() {
defer cancel()
select {
case <-ctx.Done():
return
case <-sigCh:
return
}
}()
return ctx
}
// WithStandardSignals cancels the context on os.Interrupt, syscall.SIGTERM.
func WithStandardSignals(ctx context.Context) context.Context {
return WithSignals(ctx, os.Interrupt, syscall.SIGTERM)
}

View File

@ -0,0 +1,79 @@
package signals
import (
"context"
"fmt"
"os"
"syscall"
"testing"
"time"
)
func ExampleWithSignals() {
ctx := WithSignals(context.Background(), syscall.SIGUSR1)
go func() {
time.Sleep(500 * time.Millisecond) // after some time SIGUSR1 is sent
// mimicking a signal from the outside
syscall.Kill(syscall.Getpid(), syscall.SIGUSR1)
}()
<-ctx.Done()
fmt.Println("finished")
// Output:
// finished
}
func Example_withUnregisteredSignals() {
dctx, cancel := context.WithTimeout(context.TODO(), time.Millisecond*100)
defer cancel()
ctx := WithSignals(dctx, syscall.SIGUSR1)
go func() {
time.Sleep(10 * time.Millisecond) // after some time SIGUSR2 is sent
// mimicking a signal from the outside, WithSignals will not handle it
syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
}()
<-ctx.Done()
fmt.Println("finished")
// Output:
// finished
}
func TestWithSignals(t *testing.T) {
tests := []struct {
name string
ctx context.Context
sigs []os.Signal
wantSignal bool
}{
{
name: "sending signal SIGUSR2 should exit context.",
ctx: context.Background(),
sigs: []os.Signal{syscall.SIGUSR2},
wantSignal: true,
},
{
name: "sending signal SIGUSR2 should NOT exit context.",
ctx: context.Background(),
sigs: []os.Signal{syscall.SIGUSR1},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := WithSignals(tt.ctx, tt.sigs...)
syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
timer := time.NewTimer(500 * time.Millisecond)
select {
case <-ctx.Done():
if !tt.wantSignal {
t.Errorf("unexpected exit with signal")
}
case <-timer.C:
if tt.wantSignal {
t.Errorf("expected to exit with signal but did not")
}
}
})
}
}

View File

@ -0,0 +1,24 @@
package testing
import (
"github.com/opentracing/opentracing-go"
"github.com/uber/jaeger-client-go"
)
// SetupInMemoryTracing sets the global tracer to an in memory Jaeger instance for testing.
// The returned function should be deferred by the caller to tear down this setup after testing is complete.
func SetupInMemoryTracing(name string) func() {
var (
old = opentracing.GlobalTracer()
tracer, closer = jaeger.NewTracer(name,
jaeger.NewConstSampler(true),
jaeger.NewInMemoryReporter(),
)
)
opentracing.SetGlobalTracer(tracer)
return func() {
_ = closer.Close()
opentracing.SetGlobalTracer(old)
}
}

225
kit/tracing/tracing.go Normal file
View File

@ -0,0 +1,225 @@
package tracing
import (
"context"
"errors"
"net/http"
"runtime"
"strings"
"time"
"github.com/go-chi/chi"
"github.com/prometheus/client_golang/prometheus"
"github.com/uber/jaeger-client-go"
"github.com/influxdata/httprouter"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"github.com/opentracing/opentracing-go/log"
)
// LogError adds a span log for an error.
// Returns unchanged error, so useful to wrap as in:
// return 0, tracing.LogError(err)
func LogError(span opentracing.Span, err error) error {
if err == nil {
return nil
}
// Get caller frame.
var pcs [1]uintptr
n := runtime.Callers(2, pcs[:])
if n < 1 {
span.LogFields(log.Error(err))
span.LogFields(log.Error(errors.New("runtime.Callers failed")))
return err
}
file, line := runtime.FuncForPC(pcs[0]).FileLine(pcs[0])
span.LogFields(log.String("filename", file), log.Int("line", line), log.Error(err))
return err
}
// InjectToHTTPRequest adds tracing headers to an HTTP request.
// Easier than adding this boilerplate everywhere.
func InjectToHTTPRequest(span opentracing.Span, req *http.Request) {
err := opentracing.GlobalTracer().Inject(span.Context(), opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header))
if err != nil {
LogError(span, err)
}
}
// ExtractFromHTTPRequest gets a child span of the parent referenced in HTTP request headers.
// Returns the request with updated tracing context.
// Easier than adding this boilerplate everywhere.
func ExtractFromHTTPRequest(req *http.Request, handlerName string) (opentracing.Span, *http.Request) {
spanContext, err := opentracing.GlobalTracer().Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header))
if err != nil {
span, ctx := opentracing.StartSpanFromContext(req.Context(), "request")
annotateSpan(span, handlerName, req)
_ = LogError(span, err)
return span, req.WithContext(ctx)
}
span := opentracing.StartSpan("request", opentracing.ChildOf(spanContext), ext.RPCServerOption(spanContext))
annotateSpan(span, handlerName, req)
return span, req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
}
func annotateSpan(span opentracing.Span, handlerName string, req *http.Request) {
if route := httprouter.MatchedRouteFromContext(req.Context()); route != "" {
span.SetTag("route", route)
}
span.SetTag("method", req.Method)
if ctx := chi.RouteContext(req.Context()); ctx != nil {
span.SetTag("route", ctx.RoutePath)
}
span.SetTag("handler", handlerName)
span.LogKV("path", req.URL.Path)
}
// span is a simple wrapper around opentracing.Span in order to
// get access to the duration of the span for metrics reporting.
type Span struct {
opentracing.Span
start time.Time
Duration time.Duration
hist prometheus.Observer
gauge prometheus.Gauge
}
func StartSpanFromContextWithPromMetrics(ctx context.Context, operationName string, hist prometheus.Observer, gauge prometheus.Gauge, opts ...opentracing.StartSpanOption) (*Span, context.Context) {
start := time.Now()
s, sctx := StartSpanFromContextWithOperationName(ctx, operationName, opentracing.StartTime(start))
gauge.Inc()
return &Span{s, start, 0, hist, gauge}, sctx
}
func (s *Span) Finish() {
finish := time.Now()
s.Duration = finish.Sub(s.start)
s.Span.FinishWithOptions(opentracing.FinishOptions{
FinishTime: finish,
})
s.hist.Observe(s.Duration.Seconds())
s.gauge.Dec()
}
// StartSpanFromContext is an improved opentracing.StartSpanFromContext.
// Uses the calling function as the operation name, and logs the filename and line number.
//
// Passing nil context induces panic.
// Context without parent span reference triggers root span construction.
// This function never returns nil values.
//
// Performance
//
// This function incurs a small performance penalty, roughly 1000 ns/op, 376 B/op, 6 allocs/op.
// Jaeger timestamp and duration precision is only µs, so this is pretty negligible.
//
// Alternatives
//
// If this performance penalty is too much, try these, which are also demonstrated in benchmark tests:
// // Create a root span
// span := opentracing.StartSpan("operation name")
// ctx := opentracing.ContextWithSpan(context.Background(), span)
//
// // Create a child span
// span := opentracing.StartSpan("operation name", opentracing.ChildOf(sc))
// ctx := opentracing.ContextWithSpan(context.Background(), span)
//
// // Sugar to create a child span
// span, ctx := opentracing.StartSpanFromContext(ctx, "operation name")
func StartSpanFromContext(ctx context.Context, opts ...opentracing.StartSpanOption) (opentracing.Span, context.Context) {
if ctx == nil {
panic("StartSpanFromContext called with nil context")
}
// Get caller frame.
var pcs [1]uintptr
n := runtime.Callers(2, pcs[:])
if n < 1 {
span, ctx := opentracing.StartSpanFromContext(ctx, "unknown", opts...)
span.LogFields(log.Error(errors.New("runtime.Callers failed")))
return span, ctx
}
fn := runtime.FuncForPC(pcs[0])
name := fn.Name()
if lastSlash := strings.LastIndexByte(name, '/'); lastSlash > 0 {
name = name[lastSlash+1:]
}
var span opentracing.Span
if parentSpan := opentracing.SpanFromContext(ctx); parentSpan != nil {
// Create a child span.
opts = append(opts, opentracing.ChildOf(parentSpan.Context()))
span = opentracing.StartSpan(name, opts...)
} else {
// Create a root span.
span = opentracing.StartSpan(name)
}
// New context references this span, not the parent (if there was one).
ctx = opentracing.ContextWithSpan(ctx, span)
file, line := fn.FileLine(pcs[0])
span.LogFields(log.String("filename", file), log.Int("line", line))
return span, ctx
}
// StartSpanFromContextWithOperationName is like StartSpanFromContext, but the caller determines the operation name.
func StartSpanFromContextWithOperationName(ctx context.Context, operationName string, opts ...opentracing.StartSpanOption) (opentracing.Span, context.Context) {
if ctx == nil {
panic("StartSpanFromContextWithOperationName called with nil context")
}
// Get caller frame.
var pcs [1]uintptr
n := runtime.Callers(2, pcs[:])
if n < 1 {
span, ctx := opentracing.StartSpanFromContext(ctx, operationName, opts...)
span.LogFields(log.Error(errors.New("runtime.Callers failed")))
return span, ctx
}
file, line := runtime.FuncForPC(pcs[0]).FileLine(pcs[0])
var span opentracing.Span
if parentSpan := opentracing.SpanFromContext(ctx); parentSpan != nil {
opts = append(opts, opentracing.ChildOf(parentSpan.Context()))
// Create a child span.
span = opentracing.StartSpan(operationName, opts...)
} else {
// Create a root span.
span = opentracing.StartSpan(operationName, opts...)
}
// New context references this span, not the parent (if there was one).
ctx = opentracing.ContextWithSpan(ctx, span)
span.LogFields(log.String("filename", file), log.Int("line", line))
return span, ctx
}
// InfoFromSpan returns the traceID and if it was sampled from the span, given it is a jaeger span.
// It returns whether a span associated to the context has been found.
func InfoFromSpan(span opentracing.Span) (traceID string, sampled bool, found bool) {
if spanContext, ok := span.Context().(jaeger.SpanContext); ok {
traceID = spanContext.TraceID().String()
sampled = spanContext.IsSampled()
return traceID, sampled, true
}
return "", false, false
}
// InfoFromContext returns the traceID and if it was sampled from the Jaeger span
// found in the given context. It returns whether a span associated to the context has been found.
func InfoFromContext(ctx context.Context) (traceID string, sampled bool, found bool) {
if span := opentracing.SpanFromContext(ctx); span != nil {
return InfoFromSpan(span)
}
return "", false, false
}

333
kit/tracing/tracing_test.go Normal file
View File

@ -0,0 +1,333 @@
package tracing
import (
"context"
"fmt"
"net/http"
"net/url"
"runtime"
"testing"
"github.com/go-chi/chi"
"github.com/influxdata/httprouter"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/mocktracer"
)
func TestInjectAndExtractHTTPRequest(t *testing.T) {
tracer := mocktracer.New()
oldTracer := opentracing.GlobalTracer()
opentracing.SetGlobalTracer(tracer)
defer opentracing.SetGlobalTracer(oldTracer)
request, err := http.NewRequest(http.MethodPost, "http://localhost/", nil)
if err != nil {
t.Fatal(err)
}
span := tracer.StartSpan("operation name")
InjectToHTTPRequest(span, request)
gotSpan, _ := ExtractFromHTTPRequest(request, "MyStruct")
if span.(*mocktracer.MockSpan).SpanContext.TraceID != gotSpan.(*mocktracer.MockSpan).SpanContext.TraceID {
t.Error("injected and extracted traceIDs not equal")
}
if span.(*mocktracer.MockSpan).SpanContext.SpanID != gotSpan.(*mocktracer.MockSpan).ParentID {
t.Error("injected span ID does not match extracted span parent ID")
}
}
func TestExtractHTTPRequest(t *testing.T) {
var (
tracer = mocktracer.New()
oldTracer = opentracing.GlobalTracer()
ctx = context.Background()
)
opentracing.SetGlobalTracer(tracer)
defer opentracing.SetGlobalTracer(oldTracer)
for _, test := range []struct {
name string
handlerName string
path string
ctx context.Context
tags map[string]interface{}
method string
}{
{
name: "happy path",
handlerName: "WriteHandler",
ctx: context.WithValue(ctx, httprouter.MatchedRouteKey, "/api/v2/write"),
method: http.MethodGet,
path: "/api/v2/write",
tags: map[string]interface{}{
"route": "/api/v2/write",
"handler": "WriteHandler",
},
},
{
name: "happy path bucket handler",
handlerName: "BucketHandler",
ctx: context.WithValue(ctx, httprouter.MatchedRouteKey, "/api/v2/buckets/:bucket_id"),
path: "/api/v2/buckets/12345",
method: http.MethodGet,
tags: map[string]interface{}{
"route": "/api/v2/buckets/:bucket_id",
"handler": "BucketHandler",
},
},
{
name: "happy path bucket handler (chi)",
handlerName: "BucketHandler",
ctx: context.WithValue(
ctx,
chi.RouteCtxKey,
&chi.Context{RoutePath: "/api/v2/buckets/:bucket_id", RouteMethod: "GET"},
),
path: "/api/v2/buckets/12345",
method: http.MethodGet,
tags: map[string]interface{}{
"route": "/api/v2/buckets/:bucket_id",
"method": "GET",
"handler": "BucketHandler",
},
},
{
name: "empty path",
handlerName: "Home",
ctx: ctx,
method: http.MethodGet,
tags: map[string]interface{}{
"handler": "Home",
},
},
} {
t.Run(test.name, func(t *testing.T) {
request, err := http.NewRequest(test.method, "http://localhost"+test.path, nil)
if err != nil {
t.Fatal(err)
}
span := tracer.StartSpan("operation name")
InjectToHTTPRequest(span, request)
gotSpan, _ := ExtractFromHTTPRequest(request.WithContext(test.ctx), test.handlerName)
if op := gotSpan.(*mocktracer.MockSpan).OperationName; op != "request" {
t.Fatalf("operation name %q != request", op)
}
tags := gotSpan.(*mocktracer.MockSpan).Tags()
for k, v := range test.tags {
found, ok := tags[k]
if !ok {
t.Errorf("tag not found in span %q", k)
continue
}
if found != v {
t.Errorf("expected %v, found %v for tag %q", v, found, k)
}
}
})
}
}
func TestStartSpanFromContext(t *testing.T) {
tracer := mocktracer.New()
oldTracer := opentracing.GlobalTracer()
opentracing.SetGlobalTracer(tracer)
defer opentracing.SetGlobalTracer(oldTracer)
type testCase struct {
ctx context.Context
expectPanic bool
expectParent bool
}
var testCases []testCase
testCases = append(testCases,
testCase{
ctx: nil,
expectPanic: true,
expectParent: false,
},
testCase{
ctx: context.Background(),
expectPanic: false,
expectParent: false,
})
parentSpan := opentracing.StartSpan("parent operation name")
testCases = append(testCases, testCase{
ctx: opentracing.ContextWithSpan(context.Background(), parentSpan),
expectPanic: false,
expectParent: true,
})
for i, tc := range testCases {
t.Run(fmt.Sprint(i), func(t *testing.T) {
var span opentracing.Span
var ctx context.Context
var gotPanic bool
func(inputCtx context.Context) {
defer func() {
if recover() != nil {
gotPanic = true
}
}()
span, ctx = StartSpanFromContext(inputCtx)
}(tc.ctx)
if tc.expectPanic != gotPanic {
t.Errorf("panic: expect %v got %v", tc.expectPanic, gotPanic)
}
if tc.expectPanic {
// No other valid checks if panic.
return
}
if ctx == nil {
t.Error("never expect non-nil ctx")
}
if span == nil {
t.Error("never expect non-nil Span")
}
foundParent := span.(*mocktracer.MockSpan).ParentID != 0
if tc.expectParent != foundParent {
t.Errorf("parent: expect %v got %v", tc.expectParent, foundParent)
}
if ctx == tc.ctx {
t.Errorf("always expect fresh context")
}
})
}
}
func TestLogErrorNil(t *testing.T) {
tracer := mocktracer.New()
span := tracer.StartSpan("test").(*mocktracer.MockSpan)
var err error
if err2 := LogError(span, err); err2 != nil {
t.Errorf("expected nil err, got '%s'", err2.Error())
}
if len(span.Logs()) > 0 {
t.Errorf("expected zero new span logs, got %d", len(span.Logs()))
println(span.Logs()[0].Fields[0].Key)
}
}
/*
BenchmarkLocal_StartSpanFromContext-8 2000000 681 ns/op 224 B/op 4 allocs/op
BenchmarkLocal_StartSpanFromContext_runtimeCaller-8 3000000 534 ns/op
BenchmarkLocal_StartSpanFromContext_runtimeCallers-8 10000000 196 ns/op
BenchmarkLocal_StartSpanFromContext_runtimeFuncForPC-8 200000000 7.28 ns/op
BenchmarkLocal_StartSpanFromContext_runtimeCallersFrames-8 10000000 234 ns/op
BenchmarkLocal_StartSpanFromContext_runtimeFuncFileLine-8 20000000 103 ns/op
BenchmarkOpentracing_StartSpanFromContext-8 10000000 155 ns/op 96 B/op 3 allocs/op
BenchmarkOpentracing_StartSpan_root-8 200000000 7.68 ns/op 0 B/op 0 allocs/op
BenchmarkOpentracing_StartSpan_child-8 20000000 71.2 ns/op 48 B/op 2 allocs/op
*/
func BenchmarkLocal_StartSpanFromContext(b *testing.B) {
b.ReportAllocs()
parentSpan := opentracing.StartSpan("parent operation name")
ctx := opentracing.ContextWithSpan(context.Background(), parentSpan)
for n := 0; n < b.N; n++ {
StartSpanFromContext(ctx)
}
}
func BenchmarkLocal_StartSpanFromContext_runtimeCaller(b *testing.B) {
for n := 0; n < b.N; n++ {
_, _, _, _ = runtime.Caller(1)
}
}
func BenchmarkLocal_StartSpanFromContext_runtimeCallers(b *testing.B) {
var pcs [1]uintptr
for n := 0; n < b.N; n++ {
_ = runtime.Callers(2, pcs[:])
}
}
func BenchmarkLocal_StartSpanFromContext_runtimeFuncForPC(b *testing.B) {
var pcs [1]uintptr
_ = runtime.Callers(2, pcs[:])
for n := 0; n < b.N; n++ {
_ = runtime.FuncForPC(pcs[0])
}
}
func BenchmarkLocal_StartSpanFromContext_runtimeCallersFrames(b *testing.B) {
pc, _, _, ok := runtime.Caller(1)
if !ok {
b.Fatal("runtime.Caller failed")
}
for n := 0; n < b.N; n++ {
_, _ = runtime.CallersFrames([]uintptr{pc}).Next()
}
}
func BenchmarkLocal_StartSpanFromContext_runtimeFuncFileLine(b *testing.B) {
var pcs [1]uintptr
_ = runtime.Callers(2, pcs[:])
fn := runtime.FuncForPC(pcs[0])
for n := 0; n < b.N; n++ {
_, _ = fn.FileLine(pcs[0])
}
}
func BenchmarkOpentracing_StartSpanFromContext(b *testing.B) {
b.ReportAllocs()
parentSpan := opentracing.StartSpan("parent operation name")
ctx := opentracing.ContextWithSpan(context.Background(), parentSpan)
for n := 0; n < b.N; n++ {
_, _ = opentracing.StartSpanFromContext(ctx, "operation name")
}
}
func BenchmarkOpentracing_StartSpan_root(b *testing.B) {
b.ReportAllocs()
for n := 0; n < b.N; n++ {
_ = opentracing.StartSpan("operation name")
}
}
func BenchmarkOpentracing_StartSpan_child(b *testing.B) {
b.ReportAllocs()
parentSpan := opentracing.StartSpan("parent operation name")
for n := 0; n < b.N; n++ {
_ = opentracing.StartSpan("operation name", opentracing.ChildOf(parentSpan.Context()))
}
}
func BenchmarkOpentracing_ExtractFromHTTPRequest(b *testing.B) {
b.ReportAllocs()
req := &http.Request{
URL: &url.URL{Path: "/api/v2/organization/12345"},
}
for n := 0; n < b.N; n++ {
_, _ = ExtractFromHTTPRequest(req, "OrganizationHandler")
}
}

267
kit/transport/http/api.go Normal file
View File

@ -0,0 +1,267 @@
package http
import (
"compress/gzip"
"context"
"encoding/gob"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/influxdata/influxdb/kit/platform/errors"
"go.uber.org/zap"
)
// PlatformErrorCodeHeader shows the error code of platform error.
const PlatformErrorCodeHeader = "X-Platform-Error-Code"
// API provides a consolidated means for handling API interface concerns.
// Concerns such as decoding/encoding request and response bodies as well
// as adding headers for content type and content encoding.
type API struct {
logger *zap.Logger
prettyJSON bool
encodeGZIP bool
unmarshalErrFn func(encoding string, err error) error
okErrFn func(err error) error
errFn func(ctx context.Context, err error) (interface{}, int, error)
}
// APIOptFn is a functional option for setting fields on the API type.
type APIOptFn func(*API)
// WithLog sets the logger.
func WithLog(logger *zap.Logger) APIOptFn {
return func(api *API) {
api.logger = logger
}
}
// WithErrFn sets the err handling func for issues when writing to the response body.
func WithErrFn(fn func(ctx context.Context, err error) (interface{}, int, error)) APIOptFn {
return func(api *API) {
api.errFn = fn
}
}
// WithOKErrFn is an error handler for failing validation for request bodies.
func WithOKErrFn(fn func(err error) error) APIOptFn {
return func(api *API) {
api.okErrFn = fn
}
}
// WithPrettyJSON sets the json encoder to marshal indent or not.
func WithPrettyJSON(b bool) APIOptFn {
return func(api *API) {
api.prettyJSON = b
}
}
// WithEncodeGZIP sets the encoder to gzip contents.
func WithEncodeGZIP() APIOptFn {
return func(api *API) {
api.encodeGZIP = true
}
}
// WithUnmarshalErrFn sets the error handler for errors that occur when unmarshalling
// the request body.
func WithUnmarshalErrFn(fn func(encoding string, err error) error) APIOptFn {
return func(api *API) {
api.unmarshalErrFn = fn
}
}
// NewAPI creates a new API type.
func NewAPI(opts ...APIOptFn) *API {
api := API{
logger: zap.NewNop(),
prettyJSON: true,
unmarshalErrFn: func(encoding string, err error) error {
return &errors.Error{
Code: errors.EInvalid,
Msg: fmt.Sprintf("failed to unmarshal %s: %s", encoding, err),
}
},
errFn: func(ctx context.Context, err error) (interface{}, int, error) {
msg := err.Error()
if msg == "" {
msg = "an internal error has occurred"
}
code := errors.ErrorCode(err)
return ErrBody{
Code: code,
Msg: msg,
}, ErrorCodeToStatusCode(ctx, code), nil
},
}
for _, o := range opts {
o(&api)
}
return &api
}
// DecodeJSON decodes reader with json.
func (a *API) DecodeJSON(r io.Reader, v interface{}) error {
return a.decode("json", json.NewDecoder(r), v)
}
// DecodeGob decodes reader with gob.
func (a *API) DecodeGob(r io.Reader, v interface{}) error {
return a.decode("gob", gob.NewDecoder(r), v)
}
type (
decoder interface {
Decode(interface{}) error
}
oker interface {
OK() error
}
)
func (a *API) decode(encoding string, dec decoder, v interface{}) error {
if err := dec.Decode(v); err != nil {
if a != nil && a.unmarshalErrFn != nil {
return a.unmarshalErrFn(encoding, err)
}
return err
}
if vv, ok := v.(oker); ok {
err := vv.OK()
if a != nil && a.okErrFn != nil {
return a.okErrFn(err)
}
return err
}
return nil
}
// Respond writes to the response writer, handling all errors in writing.
func (a *API) Respond(w http.ResponseWriter, r *http.Request, status int, v interface{}) {
if status == http.StatusNoContent {
w.WriteHeader(status)
return
}
var writer io.WriteCloser = noopCloser{Writer: w}
// we'll double close to make sure its always closed even
//on issues before the write
defer writer.Close()
if a != nil && a.encodeGZIP {
w.Header().Set("Content-Encoding", "gzip")
writer = gzip.NewWriter(w)
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
// this marshal block is to catch failures before they hit the http writer.
// default behavior for http.ResponseWriter is when body is written and no
// status is set, it writes a 200. Or if a status is set before encoding
// and an error occurs, there is no means to write a proper status code
// (i.e. 500) when that is to occur. This brings that step out before
// and then writes the data and sets the status code after marshaling
// succeeds.
var (
b []byte
err error
)
if a == nil || a.prettyJSON {
b, err = json.MarshalIndent(v, "", "\t")
} else {
b, err = json.Marshal(v)
}
if err != nil {
a.Err(w, r, err)
return
}
a.write(w, writer, status, b)
}
// Write allows the user to write raw bytes to the response writer. This
// operation does not have a fail case, all failures here will be logged.
func (a *API) Write(w http.ResponseWriter, status int, b []byte) {
if status == http.StatusNoContent {
w.WriteHeader(status)
return
}
var writer io.WriteCloser = noopCloser{Writer: w}
// we'll double close to make sure its always closed even
//on issues before the write
defer writer.Close()
if a != nil && a.encodeGZIP {
w.Header().Set("Content-Encoding", "gzip")
writer = gzip.NewWriter(w)
}
a.write(w, writer, status, b)
}
func (a *API) write(w http.ResponseWriter, wc io.WriteCloser, status int, b []byte) {
w.WriteHeader(status)
if _, err := wc.Write(b); err != nil {
a.logErr("failed to write to response writer", zap.Error(err))
}
if err := wc.Close(); err != nil {
a.logErr("failed to close response writer", zap.Error(err))
}
}
// Err is used for writing an error to the response.
func (a *API) Err(w http.ResponseWriter, r *http.Request, err error) {
if err == nil {
return
}
a.logErr("api error encountered", zap.Error(err))
v, status, err := a.errFn(r.Context(), err)
if err != nil {
a.logErr("failed to write err to response writer", zap.Error(err))
a.Respond(w, r, http.StatusInternalServerError, ErrBody{
Code: "internal error",
Msg: "an unexpected error occurred",
})
return
}
if eb, ok := v.(ErrBody); ok {
w.Header().Set(PlatformErrorCodeHeader, eb.Code)
}
a.Respond(w, r, status, v)
}
func (a *API) logErr(msg string, fields ...zap.Field) {
if a == nil || a.logger == nil {
return
}
a.logger.Error(msg, fields...)
}
type noopCloser struct {
io.Writer
}
func (n noopCloser) Close() error {
return nil
}
// ErrBody is an err response body.
type ErrBody struct {
Code string `json:"code"`
Msg string `json:"message"`
}

View File

@ -0,0 +1,309 @@
package http_test
import (
"bytes"
"encoding/gob"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"github.com/influxdata/influxdb/kit/platform/errors"
kithttp "github.com/influxdata/influxdb/kit/transport/http"
"github.com/influxdata/influxdb/pkg/testttp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_API(t *testing.T) {
t.Run("Decode", func(t *testing.T) {
t.Run("valid foo no errors", func(t *testing.T) {
expected := validatFoo{
Foo: "valid",
Bar: 10,
}
t.Run("json", func(t *testing.T) {
var api *kithttp.API // shows it is safe to use a nil value
var out validatFoo
err := api.DecodeJSON(encodeJSON(t, expected), &out)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if expected != out {
t.Fatalf("unexpected vals:\n\texpected: %#v\n\tgot: %#v", expected, out)
}
})
t.Run("gob", func(t *testing.T) {
var out validatFoo
err := kithttp.NewAPI().DecodeGob(encodeGob(t, expected), &out)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if expected != out {
t.Fatalf("unexpected vals:\n\texpected: %#v\n\tgot: %#v", expected, out)
}
})
})
t.Run("unmarshals fine with ok error", func(t *testing.T) {
badFoo := validatFoo{
Foo: "",
Bar: 0,
}
t.Run("json", func(t *testing.T) {
var out validatFoo
err := kithttp.NewAPI().DecodeJSON(encodeJSON(t, badFoo), &out)
if err == nil {
t.Fatal("expected an err")
}
})
t.Run("gob", func(t *testing.T) {
var out validatFoo
err := kithttp.NewAPI().DecodeGob(encodeGob(t, badFoo), &out)
if err == nil {
t.Fatal("expected an err")
}
})
})
t.Run("unmarshal error", func(t *testing.T) {
invalidBody := []byte("[}-{]")
var out validatFoo
err := kithttp.NewAPI().DecodeJSON(bytes.NewReader(invalidBody), &out)
if err == nil {
t.Fatal("expected an error")
}
})
t.Run("unmarshal err fn wraps unmarshalling error", func(t *testing.T) {
t.Run("json", func(t *testing.T) {
invalidBody := []byte("[}-{]")
api := kithttp.NewAPI(kithttp.WithUnmarshalErrFn(unmarshalErrFn))
var out validatFoo
err := api.DecodeJSON(bytes.NewReader(invalidBody), &out)
expectInfluxdbError(t, errors.EInvalid, err)
})
t.Run("gob", func(t *testing.T) {
invalidBody := []byte("[}-{]")
api := kithttp.NewAPI(kithttp.WithUnmarshalErrFn(unmarshalErrFn))
var out validatFoo
err := api.DecodeGob(bytes.NewReader(invalidBody), &out)
expectInfluxdbError(t, errors.EInvalid, err)
})
})
t.Run("ok error fn wraps ok error", func(t *testing.T) {
badFoo := validatFoo{Foo: ""}
t.Run("json", func(t *testing.T) {
api := kithttp.NewAPI(kithttp.WithOKErrFn(okErrFn))
var out validatFoo
err := api.DecodeJSON(encodeJSON(t, badFoo), &out)
expectInfluxdbError(t, errors.EUnprocessableEntity, err)
})
t.Run("gob", func(t *testing.T) {
api := kithttp.NewAPI(kithttp.WithOKErrFn(okErrFn))
var out validatFoo
err := api.DecodeGob(encodeGob(t, badFoo), &out)
expectInfluxdbError(t, errors.EUnprocessableEntity, err)
})
})
})
t.Run("Respond", func(t *testing.T) {
tests := []int{
http.StatusCreated,
http.StatusOK,
http.StatusNoContent,
http.StatusForbidden,
http.StatusInternalServerError,
}
for _, statusCode := range tests {
fn := func(t *testing.T) {
responder := kithttp.NewAPI()
svr := func(w http.ResponseWriter, r *http.Request) {
responder.Respond(w, r, statusCode, map[string]string{
"foo": "bar",
})
}
expectedBodyFn := func(body *bytes.Buffer) {
var resp map[string]string
require.NoError(t, json.NewDecoder(body).Decode(&resp))
assert.Equal(t, "bar", resp["foo"])
}
if statusCode == http.StatusNoContent {
expectedBodyFn = func(body *bytes.Buffer) {
require.Zero(t, body.Len())
}
}
testttp.
Get(t, "/foo").
Do(http.HandlerFunc(svr)).
ExpectStatus(statusCode).
ExpectBody(expectedBodyFn)
}
t.Run(http.StatusText(statusCode), fn)
}
})
t.Run("Err", func(t *testing.T) {
tests := []struct {
statusCode int
expectedErr *errors.Error
}{
{
statusCode: http.StatusBadRequest,
expectedErr: &errors.Error{
Code: errors.EInvalid,
Msg: "failed to unmarshal",
},
},
{
statusCode: http.StatusForbidden,
expectedErr: &errors.Error{
Code: errors.EForbidden,
Msg: "forbidden",
},
},
{
statusCode: http.StatusUnprocessableEntity,
expectedErr: &errors.Error{
Code: errors.EUnprocessableEntity,
Msg: "failed validation",
},
},
{
statusCode: http.StatusInternalServerError,
expectedErr: &errors.Error{
Code: errors.EInternal,
Msg: "internal error",
},
},
}
for _, tt := range tests {
fn := func(t *testing.T) {
responder := kithttp.NewAPI()
svr := func(w http.ResponseWriter, r *http.Request) {
responder.Err(w, r, tt.expectedErr)
}
testttp.
Get(t, "/foo").
Do(http.HandlerFunc(svr)).
ExpectStatus(tt.statusCode).
ExpectBody(func(body *bytes.Buffer) {
var err kithttp.ErrBody
require.NoError(t, json.NewDecoder(body).Decode(&err))
assert.Equal(t, tt.expectedErr.Msg, err.Msg)
assert.Equal(t, tt.expectedErr.Code, err.Code)
})
}
t.Run(http.StatusText(tt.statusCode), fn)
}
})
}
func expectInfluxdbError(t *testing.T, expectedCode string, err error) {
t.Helper()
if err == nil {
t.Fatal("expected an error")
}
iErr, ok := err.(*errors.Error)
if !ok {
t.Fatalf("expected an influxdb error; got=%#v", err)
}
if got := iErr.Code; expectedCode != got {
t.Fatalf("unexpected error code; expected=%s got=%s", expectedCode, got)
}
}
func encodeGob(t *testing.T, v interface{}) io.Reader {
t.Helper()
var buf bytes.Buffer
if err := gob.NewEncoder(&buf).Encode(v); err != nil {
t.Fatal(err)
}
return &buf
}
func encodeJSON(t *testing.T, v interface{}) io.Reader {
t.Helper()
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(v); err != nil {
t.Fatal(err)
}
return &buf
}
func okErrFn(err error) error {
return &errors.Error{
Code: errors.EUnprocessableEntity,
Msg: "failed validation",
Err: err,
}
}
func unmarshalErrFn(encoding string, err error) error {
return &errors.Error{
Code: errors.EInvalid,
Msg: fmt.Sprintf("invalid %s request body", encoding),
Err: err,
}
}
type validatFoo struct {
Foo string `gob:"foo"`
Bar int `gob:"bar"`
}
func (v *validatFoo) OK() error {
var errs multiErr
if v.Foo == "" {
errs = append(errs, "foo must be at least 1 char")
}
if v.Bar < 0 {
errs = append(errs, "bar must be a positive real number")
}
return errs.toError()
}
type multiErr []string
func (m multiErr) toError() error {
if len(m) > 0 {
return m
}
return nil
}
func (m multiErr) Error() string {
return strings.Join(m, "; ")
}

View File

@ -0,0 +1,188 @@
package http
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
"net/http"
"strings"
errors2 "github.com/influxdata/influxdb/kit/platform/errors"
)
// ErrorHandler is the error handler in http package.
type ErrorHandler int
// HandleHTTPError encodes err with the appropriate status code and format,
// sets the X-Platform-Error-Code headers on the response.
// We're no longer using X-Influx-Error and X-Influx-Reference.
// and sets the response status to the corresponding status code.
func (h ErrorHandler) HandleHTTPError(ctx context.Context, err error, w http.ResponseWriter) {
if err == nil {
return
}
code := errors2.ErrorCode(err)
var msg string
if err, ok := err.(*errors2.Error); ok {
msg = err.Error()
} else {
msg = "An internal error has occurred"
}
WriteErrorResponse(ctx, w, code, msg)
}
func WriteErrorResponse(ctx context.Context, w http.ResponseWriter, code string, msg string) {
w.Header().Set(PlatformErrorCodeHeader, code)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(ErrorCodeToStatusCode(ctx, code))
e := struct {
Code string `json:"code"`
Message string `json:"message"`
}{
Code: code,
Message: msg,
}
b, _ := json.Marshal(e)
_, _ = w.Write(b)
}
// StatusCodeToErrorCode maps a http status code integer to an
// influxdb error code string.
func StatusCodeToErrorCode(statusCode int) string {
errorCode, ok := httpStatusCodeToInfluxDBError[statusCode]
if ok {
return errorCode
}
return errors2.EInternal
}
// ErrorCodeToStatusCode maps an influxdb error code string to a
// http status code integer.
func ErrorCodeToStatusCode(ctx context.Context, code string) int {
// If the client disconnects early or times out then return a different
// error than the passed in error code. Client timeouts return a 408
// while disconnections return a non-standard Nginx HTTP 499 code.
if err := ctx.Err(); err == context.DeadlineExceeded {
return http.StatusRequestTimeout
} else if err == context.Canceled {
return 499 // https://httpstatuses.com/499
}
// Otherwise map internal error codes to HTTP status codes.
statusCode, ok := influxDBErrorToStatusCode[code]
if ok {
return statusCode
}
return http.StatusInternalServerError
}
// influxDBErrorToStatusCode is a mapping of ErrorCode to http status code.
var influxDBErrorToStatusCode = map[string]int{
errors2.EInternal: http.StatusInternalServerError,
errors2.ENotImplemented: http.StatusNotImplemented,
errors2.EInvalid: http.StatusBadRequest,
errors2.EUnprocessableEntity: http.StatusUnprocessableEntity,
errors2.EEmptyValue: http.StatusBadRequest,
errors2.EConflict: http.StatusUnprocessableEntity,
errors2.ENotFound: http.StatusNotFound,
errors2.EUnavailable: http.StatusServiceUnavailable,
errors2.EForbidden: http.StatusForbidden,
errors2.ETooManyRequests: http.StatusTooManyRequests,
errors2.EUnauthorized: http.StatusUnauthorized,
errors2.EMethodNotAllowed: http.StatusMethodNotAllowed,
errors2.ETooLarge: http.StatusRequestEntityTooLarge,
}
var httpStatusCodeToInfluxDBError = map[int]string{}
func init() {
for k, v := range influxDBErrorToStatusCode {
httpStatusCodeToInfluxDBError[v] = k
}
}
// CheckErrorStatus for status and any error in the response.
func CheckErrorStatus(code int, res *http.Response) error {
err := CheckError(res)
if err != nil {
return err
}
if res.StatusCode != code {
return fmt.Errorf("unexpected status code: %s", res.Status)
}
return nil
}
// CheckError reads the http.Response and returns an error if one exists.
// It will automatically recognize the errors returned by Influx services
// and decode the error into an internal error type. If the error cannot
// be determined in that way, it will create a generic error message.
//
// If there is no error, then this returns nil.
func CheckError(resp *http.Response) (err error) {
switch resp.StatusCode / 100 {
case 4, 5:
// We will attempt to parse this error outside of this block.
case 2:
return nil
default:
// TODO(jsternberg): Figure out what to do here?
return &errors2.Error{
Code: errors2.EInternal,
Msg: fmt.Sprintf("unexpected status code: %d %s", resp.StatusCode, resp.Status),
}
}
perr := &errors2.Error{
Code: StatusCodeToErrorCode(resp.StatusCode),
}
if resp.StatusCode == http.StatusUnsupportedMediaType {
perr.Msg = fmt.Sprintf("invalid media type: %q", resp.Header.Get("Content-Type"))
return perr
}
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
// Assume JSON if there is no content-type.
contentType = "application/json"
}
mediatype, _, _ := mime.ParseMediaType(contentType)
var buf bytes.Buffer
if _, err := io.Copy(&buf, resp.Body); err != nil {
perr.Msg = "failed to read error response"
perr.Err = err
return perr
}
switch mediatype {
case "application/json":
if err := json.Unmarshal(buf.Bytes(), perr); err != nil {
perr.Msg = fmt.Sprintf("attempted to unmarshal error as JSON but failed: %q", err)
perr.Err = firstLineAsError(buf)
}
default:
perr.Err = firstLineAsError(buf)
}
if perr.Code == "" {
// given it was unset during attempt to unmarshal as JSON
perr.Code = StatusCodeToErrorCode(resp.StatusCode)
}
return perr
}
func firstLineAsError(buf bytes.Buffer) error {
line, _ := buf.ReadString('\n')
return errors.New(strings.TrimSuffix(line, "\n"))
}

View File

@ -0,0 +1,56 @@
package http_test
import (
"context"
"fmt"
"net/http/httptest"
"testing"
"github.com/influxdata/influxdb/kit/platform/errors"
kithttp "github.com/influxdata/influxdb/kit/transport/http"
)
func TestEncodeError(t *testing.T) {
ctx := context.TODO()
w := httptest.NewRecorder()
kithttp.ErrorHandler(0).HandleHTTPError(ctx, nil, w)
if w.Code != 200 {
t.Errorf("expected status code 200, got: %d", w.Code)
}
}
func TestEncodeErrorWithError(t *testing.T) {
ctx := context.TODO()
err := &errors.Error{
Code: errors.EInternal,
Msg: "an error occurred",
Err: fmt.Errorf("there's an error here, be aware"),
}
w := httptest.NewRecorder()
kithttp.ErrorHandler(0).HandleHTTPError(ctx, err, w)
if w.Code != 500 {
t.Errorf("expected status code 500, got: %d", w.Code)
}
errHeader := w.Header().Get("X-Platform-Error-Code")
if errHeader != errors.EInternal {
t.Errorf("expected X-Platform-Error-Code: %s, got: %s", errors.EInternal, errHeader)
}
// The http handler will flatten the message and it will not
// have an error property, so reading the serialization results
// in a different error.
pe := kithttp.CheckError(w.Result()).(*errors.Error)
if want, got := errors.EInternal, pe.Code; want != got {
t.Errorf("unexpected code -want/+got:\n\t- %q\n\t+ %q", want, got)
}
if want, got := "an error occurred: there's an error here, be aware", pe.Msg; want != got {
t.Errorf("unexpected message -want/+got:\n\t- %q\n\t+ %q", want, got)
}
}

View File

@ -0,0 +1,39 @@
package http
import (
"context"
"net/http"
"github.com/influxdata/influxdb/kit/feature"
)
// Enabler allows the switching between two HTTP Handlers
type Enabler interface {
Enabled(ctx context.Context, flagger ...feature.Flagger) bool
}
// FeatureHandler is used to switch requests between an existing and a feature flagged
// HTTP Handler on a per-request basis
type FeatureHandler struct {
enabler Enabler
flagger feature.Flagger
oldHandler http.Handler
newHandler http.Handler
prefix string
}
func NewFeatureHandler(e Enabler, f feature.Flagger, old, new http.Handler, prefix string) *FeatureHandler {
return &FeatureHandler{e, f, old, new, prefix}
}
func (h *FeatureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if h.enabler.Enabled(r.Context(), h.flagger) {
h.newHandler.ServeHTTP(w, r)
return
}
h.oldHandler.ServeHTTP(w, r)
}
func (h *FeatureHandler) Prefix() string {
return h.prefix
}

View File

@ -0,0 +1,11 @@
package http
import "net/http"
// ResourceHandler is an HTTP handler for a resource. The prefix
// describes the url path prefix that relates to the handler
// endpoints.
type ResourceHandler interface {
Prefix() string
http.Handler
}

View File

@ -0,0 +1,182 @@
package http
import (
"context"
"fmt"
"net/http"
"path"
"strings"
"time"
"github.com/influxdata/influxdb/kit/platform"
"github.com/influxdata/influxdb/kit/platform/errors"
"github.com/go-chi/chi"
"github.com/influxdata/influxdb/kit/tracing"
ua "github.com/mileusna/useragent"
"github.com/prometheus/client_golang/prometheus"
)
// Middleware constructor.
type Middleware func(http.Handler) http.Handler
func SetCORS(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if origin := r.Header.Get("Origin"); origin != "" {
// Access-Control-Allow-Origin must be present in every response
w.Header().Set("Access-Control-Allow-Origin", origin)
}
if r.Method == http.MethodOptions {
// allow and stop processing in pre-flight requests
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, Authorization, User-Agent")
w.WriteHeader(http.StatusNoContent)
return
}
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
func Metrics(name string, reqMetric *prometheus.CounterVec, durMetric *prometheus.HistogramVec) Middleware {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
statusW := NewStatusResponseWriter(w)
defer func(start time.Time) {
label := prometheus.Labels{
"handler": name,
"method": r.Method,
"path": normalizePath(r.URL.Path),
"status": statusW.StatusCodeClass(),
"response_code": fmt.Sprintf("%d", statusW.Code()),
"user_agent": UserAgent(r),
}
durMetric.With(label).Observe(time.Since(start).Seconds())
reqMetric.With(label).Inc()
}(time.Now())
next.ServeHTTP(statusW, r)
}
return http.HandlerFunc(fn)
}
}
func SkipOptions(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
// Preflight CORS requests from the browser will send an options request,
// so we need to make sure we satisfy them
if origin := r.Header.Get("Origin"); origin == "" && r.Method == http.MethodOptions {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
func Trace(name string) Middleware {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
span, r := tracing.ExtractFromHTTPRequest(r, name)
defer span.Finish()
span.LogKV("user_agent", UserAgent(r))
for k, v := range r.Header {
if len(v) == 0 {
continue
}
if k == "Authorization" || k == "User-Agent" {
continue
}
// If header has multiple values, only the first value will be logged on the trace.
span.LogKV(k, v[0])
}
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
}
func UserAgent(r *http.Request) string {
header := r.Header.Get("User-Agent")
if header == "" {
return "unknown"
}
return ua.Parse(header).Name
}
func normalizePath(p string) string {
var parts []string
for head, tail := shiftPath(p); ; head, tail = shiftPath(tail) {
piece := head
if len(piece) == platform.IDLength {
if _, err := platform.IDFromString(head); err == nil {
piece = ":id"
}
}
parts = append(parts, piece)
if tail == "/" {
break
}
}
return "/" + path.Join(parts...)
}
func shiftPath(p string) (head, tail string) {
p = path.Clean("/" + p)
i := strings.Index(p[1:], "/") + 1
if i <= 0 {
return p[1:], "/"
}
return p[1:i], p[i:]
}
type OrgContext string
const CtxOrgKey OrgContext = "orgID"
// ValidResource make sure a resource exists when a sub system needs to be mounted to an api
func ValidResource(api *API, lookupOrgByResourceID func(context.Context, platform.ID) (platform.ID, error)) Middleware {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
statusW := NewStatusResponseWriter(w)
id, err := platform.IDFromString(chi.URLParam(r, "id"))
if err != nil {
api.Err(w, r, platform.ErrCorruptID(err))
return
}
ctx := r.Context()
orgID, err := lookupOrgByResourceID(ctx, *id)
if err != nil {
// if this function returns an error we will squash the error message and replace it with a not found error
api.Err(w, r, &errors.Error{
Code: errors.ENotFound,
Msg: "404 page not found",
})
return
}
// embed OrgID into context
next.ServeHTTP(statusW, r.WithContext(context.WithValue(ctx, CtxOrgKey, orgID)))
}
return http.HandlerFunc(fn)
}
}
// OrgIDFromContext ....
func OrgIDFromContext(ctx context.Context) *platform.ID {
v := ctx.Value(CtxOrgKey)
if v == nil {
return nil
}
id := v.(platform.ID)
return &id
}

View File

@ -0,0 +1,95 @@
package http
import (
"net/http"
"path"
"testing"
"github.com/influxdata/influxdb/kit/platform"
"github.com/influxdata/influxdb/pkg/testttp"
"github.com/stretchr/testify/assert"
)
func Test_normalizePath(t *testing.T) {
tests := []struct {
name string
path string
expected string
}{
{
name: "1",
path: path.Join("/api/v2/organizations", platform.ID(2).String()),
expected: "/api/v2/organizations/:id",
},
{
name: "2",
path: "/api/v2/organizations",
expected: "/api/v2/organizations",
},
{
name: "3",
path: "/",
expected: "/",
},
{
name: "4",
path: path.Join("/api/v2/organizations", platform.ID(2).String(), "users", platform.ID(3).String()),
expected: "/api/v2/organizations/:id/users/:id",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := normalizePath(tt.path)
assert.Equal(t, tt.expected, actual)
})
}
}
func TestCors(t *testing.T) {
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("nextHandler"))
})
tests := []struct {
name string
method string
headers []string
expectedStatus int
expectedHeaders map[string]string
}{
{
name: "OPTIONS without Origin",
method: "OPTIONS",
expectedStatus: http.StatusMethodNotAllowed,
},
{
name: "OPTIONS with Origin",
method: "OPTIONS",
headers: []string{"Origin", "http://myapp.com"},
expectedStatus: http.StatusNoContent,
},
{
name: "GET with Origin",
method: "GET",
headers: []string{"Origin", "http://anotherapp.com"},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "http://anotherapp.com",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svr := SkipOptions(SetCORS(nextHandler))
testttp.
HTTP(t, tt.method, "/", nil).
Headers("", "", tt.headers...).
Do(svr).
ExpectStatus(tt.expectedStatus).
ExpectHeaders(tt.expectedHeaders)
})
}
}

View File

@ -0,0 +1,58 @@
package http
import "net/http"
type StatusResponseWriter struct {
statusCode int
responseBytes int
http.ResponseWriter
}
func NewStatusResponseWriter(w http.ResponseWriter) *StatusResponseWriter {
return &StatusResponseWriter{
ResponseWriter: w,
}
}
func (w *StatusResponseWriter) Write(b []byte) (int, error) {
n, err := w.ResponseWriter.Write(b)
w.responseBytes += n
return n, err
}
// WriteHeader writes the header and captures the status code.
func (w *StatusResponseWriter) WriteHeader(statusCode int) {
w.statusCode = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *StatusResponseWriter) Code() int {
code := w.statusCode
if code == 0 {
// When statusCode is 0 then WriteHeader was never called and we can assume that
// the ResponseWriter wrote an http.StatusOK.
code = http.StatusOK
}
return code
}
func (w *StatusResponseWriter) ResponseBytes() int {
return w.responseBytes
}
func (w *StatusResponseWriter) StatusCodeClass() string {
class := "XXX"
switch w.Code() / 100 {
case 1:
class = "1XX"
case 2:
class = "2XX"
case 3:
class = "3XX"
case 4:
class = "4XX"
case 5:
class = "5XX"
}
return class
}

195
pkg/testttp/http.go Normal file
View File

@ -0,0 +1,195 @@
package testttp
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
// Req is a request builder.
type Req struct {
t testing.TB
req *http.Request
}
// HTTP runs creates a request for an http call.
func HTTP(t testing.TB, method, addr string, body io.Reader) *Req {
return &Req{
t: t,
req: httptest.NewRequest(method, addr, body),
}
}
// Delete creates a DELETE request.
func Delete(t testing.TB, addr string) *Req {
return HTTP(t, http.MethodDelete, addr, nil)
}
// Get creates a GET request.
func Get(t testing.TB, addr string) *Req {
return HTTP(t, http.MethodGet, addr, nil)
}
// Patch creates a PATCH request.
func Patch(t testing.TB, addr string, body io.Reader) *Req {
return HTTP(t, http.MethodPatch, addr, body)
}
// PatchJSON creates a PATCH request with a json encoded body.
func PatchJSON(t testing.TB, addr string, v interface{}) *Req {
return HTTP(t, http.MethodPatch, addr, mustEncodeJSON(t, v))
}
// Post creates a POST request.
func Post(t testing.TB, addr string, body io.Reader) *Req {
return HTTP(t, http.MethodPost, addr, body)
}
// PostJSON returns a POST request with a json encoded body.
func PostJSON(t testing.TB, addr string, v interface{}) *Req {
return Post(t, addr, mustEncodeJSON(t, v))
}
// Put creates a PUT request.
func Put(t testing.TB, addr string, body io.Reader) *Req {
return HTTP(t, http.MethodPut, addr, body)
}
// PutJSON creates a PUT request with a json encoded body.
func PutJSON(t testing.TB, addr string, v interface{}) *Req {
return HTTP(t, http.MethodPut, addr, mustEncodeJSON(t, v))
}
// Do runs the request against the provided handler.
func (r *Req) Do(handler http.Handler) *Resp {
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, r.req)
return &Resp{
t: r.t,
debug: true,
Req: r.req,
Rec: rec,
}
}
func (r *Req) SetFormValue(k, v string) *Req {
if r.req.Form == nil {
r.req.Form = make(url.Values)
}
r.req.Form.Set(k, v)
return r
}
// Headers allows the user to set headers on the http request.
func (r *Req) Headers(k, v string, rest ...string) *Req {
headers := append(rest, k, v)
for i := 0; i < len(headers); i += 2 {
if i+1 >= len(headers) {
break
}
k, v := headers[i], headers[i+1]
r.req.Header.Add(k, v)
}
return r
}
// WithCtx sets the ctx on the request.
func (r *Req) WithCtx(ctx context.Context) *Req {
r.req = r.req.WithContext(ctx)
return r
}
// WrapCtx provides means to wrap a request context. This is useful for stuffing in the
// auth stuffs that are required at times.
func (r *Req) WrapCtx(fn func(ctx context.Context) context.Context) *Req {
return r.WithCtx(fn(r.req.Context()))
}
// Resp is a http recorder wrapper.
type Resp struct {
t testing.TB
debug bool
Req *http.Request
Rec *httptest.ResponseRecorder
}
// Debug sets the debugger. If true, the debugger will print the body of the response
// when the expected status is not received.
func (r *Resp) Debug(b bool) *Resp {
r.debug = b
return r
}
// Expect allows the assertions against the raw Resp.
func (r *Resp) Expect(fn func(*Resp)) *Resp {
fn(r)
return r
}
// ExpectStatus compares the expected status code against the recorded status code.
func (r *Resp) ExpectStatus(code int) *Resp {
r.t.Helper()
if r.Rec.Code != code {
r.t.Errorf("unexpected status code: expected=%d got=%d", code, r.Rec.Code)
if r.debug {
r.t.Logf("body: %v", r.Rec.Body.String())
}
}
return r
}
// ExpectBody provides an assertion against the recorder body.
func (r *Resp) ExpectBody(fn func(body *bytes.Buffer)) *Resp {
fn(r.Rec.Body)
return r
}
// ExpectHeaders asserts that multiple headers with values exist in the recorder.
func (r *Resp) ExpectHeaders(h map[string]string) *Resp {
for k, v := range h {
r.ExpectHeader(k, v)
}
return r
}
// ExpectHeader asserts that the header is in the recorder.
func (r *Resp) ExpectHeader(k, v string) *Resp {
r.t.Helper()
vals, ok := r.Rec.Header()[k]
if !ok {
r.t.Errorf("did not find expected header: %q", k)
return r
}
for _, vv := range vals {
if vv == v {
return r
}
}
r.t.Errorf("did not find expected value for header %q; got: %v", k, vals)
return r
}
func mustEncodeJSON(t testing.TB, v interface{}) *bytes.Buffer {
t.Helper()
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(v); err != nil {
t.Fatal(err)
}
return &buf
}

174
pkg/testttp/http_test.go Normal file
View File

@ -0,0 +1,174 @@
package testttp_test
import (
"bytes"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"github.com/influxdata/influxdb/pkg/testttp"
)
func TestHTTP(t *testing.T) {
svr := newMux()
t.Run("Delete", func(t *testing.T) {
testttp.
Delete(t, "/").
Do(svr).
ExpectStatus(http.StatusNoContent)
})
t.Run("Get", func(t *testing.T) {
testttp.
Get(t, "/").
Do(svr).
ExpectStatus(http.StatusOK).
ExpectBody(assertBody(t, http.MethodGet))
})
t.Run("Patch", func(t *testing.T) {
testttp.
Patch(t, "/", nil).
Do(svr).
ExpectStatus(http.StatusPartialContent).
ExpectBody(assertBody(t, http.MethodPatch))
})
t.Run("PatchJSON", func(t *testing.T) {
testttp.
PatchJSON(t, "/", map[string]string{"k": "t"}).
Do(svr).
ExpectStatus(http.StatusPartialContent).
ExpectBody(assertBody(t, http.MethodPatch))
})
t.Run("Post", func(t *testing.T) {
t.Run("basic", func(t *testing.T) {
testttp.
Post(t, "/", nil).
Do(svr).
ExpectStatus(http.StatusCreated).
ExpectBody(assertBody(t, http.MethodPost))
})
t.Run("with form values", func(t *testing.T) {
svr := http.NewServeMux()
svr.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
w.WriteHeader(http.StatusOK)
w.Write([]byte(r.FormValue("key")))
}))
testttp.
Post(t, "/", nil).
SetFormValue("key", "val").
Do(svr).
ExpectStatus(http.StatusOK).
ExpectBody(func(body *bytes.Buffer) {
if expected, got := "val", body.String(); expected != got {
t.Fatalf("did not get form value; expected=%q got=%q", expected, got)
}
})
})
})
t.Run("PostJSON", func(t *testing.T) {
testttp.
PostJSON(t, "/", map[string]string{"k": "v"}).
Do(svr).
ExpectStatus(http.StatusCreated).
ExpectBody(assertBody(t, http.MethodPost))
})
t.Run("Put", func(t *testing.T) {
testttp.
Put(t, "/", nil).
Do(svr).
ExpectStatus(http.StatusAccepted).
ExpectBody(assertBody(t, http.MethodPut))
})
t.Run("PutJSON", func(t *testing.T) {
testttp.
PutJSON(t, "/", map[string]string{"k": "t"}).
Do(svr).
ExpectStatus(http.StatusAccepted).
ExpectBody(assertBody(t, http.MethodPut))
})
t.Run("Headers", func(t *testing.T) {
testttp.
Post(t, "/", strings.NewReader(`a: foo`)).
Headers("Content-Type", "text/yml").
Do(svr).
Expect(func(resp *testttp.Resp) {
equals(t, "text/yml", resp.Req.Header.Get("Content-Type"))
})
})
}
type foo struct {
Name, Thing, Method string
}
func newMux() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
switch req.Method {
case http.MethodGet:
writeFn(w, req.Method, http.StatusOK)
case http.MethodPost:
writeFn(w, req.Method, http.StatusCreated)
case http.MethodPut:
writeFn(w, req.Method, http.StatusAccepted)
case http.MethodPatch:
writeFn(w, req.Method, http.StatusPartialContent)
case http.MethodDelete:
w.WriteHeader(http.StatusNoContent)
}
})
return mux
}
func assertBody(t *testing.T, method string) func(*bytes.Buffer) {
return func(buf *bytes.Buffer) {
var f foo
if err := json.NewDecoder(buf).Decode(&f); err != nil {
t.Fatal(err)
}
expected := foo{Name: "name", Thing: "thing", Method: method}
equals(t, expected, f)
}
}
func writeFn(w http.ResponseWriter, method string, statusCode int) {
f := foo{Name: "name", Thing: "thing", Method: method}
r, err := encodeBuf(f)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.WriteHeader(statusCode)
if _, err := io.Copy(w, r); err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
}
func equals(t *testing.T, expected, actual interface{}) {
t.Helper()
if expected == actual {
return
}
t.Errorf("expected: %v\tactual: %v", expected, actual)
}
func encodeBuf(v interface{}) (io.Reader, error) {
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(v); err != nil {
return nil, err
}
return &buf, nil
}