diff --git a/httpd/handler.go b/httpd/handler.go index 2cb601e2bb..1c33b862c9 100644 --- a/httpd/handler.go +++ b/httpd/handler.go @@ -308,36 +308,68 @@ func (h *Handler) serveWait(w http.ResponseWriter, r *http.Request) { } else { d = time.Duration(timeout) * time.Millisecond } - err := h.pollForIndex(index, d) - if err != nil { - w.WriteHeader(http.StatusRequestTimeout) - return + poller := &indexPoller{ + h: h, + quit: make(chan bool), + } + if notify, ok := w.(http.CloseNotifier); ok { + go func(poller *indexPoller) { + <-notify.CloseNotify() + poller.Quit() + }(poller) + } + + err := poller.PollForIndex(index, d) + switch err { + case errPollTimedOut: + w.WriteHeader(http.StatusRequestTimeout) + case nil: + w.Write([]byte(fmt.Sprintf("%d", h.server.Index()))) } - w.Write([]byte(fmt.Sprintf("%d", h.server.Index()))) } -// pollForIndex will poll until either the index is met or it times out -// timeout is in milliseconds -func (h *Handler) pollForIndex(index uint64, timeout time.Duration) error { - done := make(chan struct{}) +type indexPoller struct { + h *Handler + quit chan bool +} +var ( + errPollTimedOut = errors.New("timed out") + errAborted = errors.New("aborted") +) + +func (p *indexPoller) PollForIndex(index uint64, timeout time.Duration) error { + aborted := make(chan bool) + done := make(chan bool) go func() { for { - if h.server.Index() >= index { - done <- struct{}{} + select { + case <-p.quit: + aborted <- true + return + default: + if p.h.server.Index() >= index { + done <- true + return + } + time.Sleep(10 * time.Millisecond) } - time.Sleep(10 * time.Millisecond) } }() - for { - select { - case <-done: - return nil - case <-time.After(timeout): - return fmt.Errorf("timed out") - } + select { + case <-time.After(timeout): + return errPollTimedOut + case <-aborted: + return errAborted + case <-done: + return nil } + return nil +} + +func (p *indexPoller) Quit() { + p.quit <- true } // serveDataNodes returns a list of all data nodes in the cluster.