diff --git a/kit/check/check.go b/kit/check/check.go index ac039d16dc..f2dba60153 100644 --- a/kit/check/check.go +++ b/kit/check/check.go @@ -8,6 +8,7 @@ import ( "fmt" "net/http" "sort" + "sync/atomic" ) // Status string to indicate the overall status of the check. @@ -25,8 +26,10 @@ const ( // Check wraps a map of service names to status checkers. type Check struct { - healthChecks []Checker - readyChecks []Checker + healthChecks []Checker + readyChecks []Checker + manualOverride atomic.Value + manualHealthState atomic.Value passthroughHandler http.Handler } @@ -38,7 +41,10 @@ type Checker interface { // NewCheck returns a Health with a default checker. func NewCheck() *Check { - return &Check{} + ch := &Check{} + ch.manualOverride.Store(false) + ch.manualHealthState.Store(false) + return ch } // AddHealthCheck adds the check to the list of ready checks. @@ -68,9 +74,22 @@ func (c *Check) CheckHealth(ctx context.Context) Response { Status: StatusPass, Checks: make(Responses, len(c.healthChecks)), } + override := c.manualOverride.Load().(bool) + if override { + if c.manualHealthState.Load().(bool) { + response.Status = StatusPass + } else { + response.Status = StatusFail + } + overrideResponse := Response{ + Name: "manual-override", + Message: "health manually overriden", + } + response.Checks = append(response.Checks, overrideResponse) + } for i, ch := range c.healthChecks { resp := ch.Check(ctx) - if resp.Status != StatusPass { + if resp.Status != StatusPass && !override { response.Status = resp.Status } response.Checks[i] = resp @@ -125,6 +144,19 @@ func (c *Check) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "/ready": resp = c.CheckReady(r.Context()) case "/health": + query := r.URL.Query() + switch query.Get("force") { + case "true": + c.manualOverride.Store(true) + switch query.Get("healthy") { + case "true": + c.manualHealthState.Store(true) + case "false": + c.manualHealthState.Store(false) + } + case "false": + c.manualOverride.Store(false) + } resp = c.CheckHealth(r.Context()) } diff --git a/kit/check/check_test.go b/kit/check/check_test.go index fdbd695e62..fb24289b54 100644 --- a/kit/check/check_test.go +++ b/kit/check/check_test.go @@ -165,6 +165,66 @@ func TestHealthSorting(t *testing.T) { } } +func TestForceHealth(t *testing.T) { + c, ts := buildCheckWithServer() + defer ts.Close() + + c.AddHealthCheck(mockPass("a")) + + _, err := http.Get(ts.URL + "/health?force=true&healthy=false") + if err != nil { + t.Fatal(err) + } + + resp, err := http.Get(ts.URL + "/health") + if err != nil { + t.Fatal(err) + } + actual, err := respBuilder(resp.Body) + if err != nil { + t.Fatal(err) + } + + expected := &Response{ + Name: "Health", + Status: "fail", + Checks: Responses{ + Response{Name: "manual-override", Message: "health manually overriden"}, + Response{Name: "a", Status: "pass"}, + }, + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("unexpected response. expected %v, actual %v", expected, actual) + } + + _, err = http.Get(ts.URL + "/health?force=false") + if err != nil { + t.Fatal(err) + } + + expected = &Response{ + Name: "Health", + Status: "pass", + Checks: Responses{ + Response{Name: "a", Status: "pass"}, + }, + } + + resp, err = http.Get(ts.URL + "/health") + if err != nil { + t.Fatal(err) + } + actual, err = respBuilder(resp.Body) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("unexpected response. expected %v, actual %v", expected, actual) + } +} + func TestNoCrossOver(t *testing.T) { c, ts := buildCheckWithServer() defer ts.Close()