diff --git a/raft/handler.go b/raft/handler.go index 2d517165e1..7f8018fecb 100644 --- a/raft/handler.go +++ b/raft/handler.go @@ -16,7 +16,7 @@ type Handler struct { RemovePeer(id uint64) error Heartbeat(term, commitIndex, leaderID uint64) (currentIndex uint64, err error) WriteEntriesTo(w io.Writer, id, term, index uint64) error - RequestVote(term, candidateID, lastLogIndex, lastLogTerm uint64) error + RequestVote(term, candidateID, lastLogIndex, lastLogTerm uint64) (peerTerm uint64, err error) } } @@ -123,7 +123,7 @@ func (h *Handler) serveHeartbeat(w http.ResponseWriter, r *http.Request) { // Execute heartbeat on the log. currentIndex, err := h.Log.Heartbeat(term, commitIndex, leaderID) - // Return current term and index. + // Return current index. w.Header().Set("X-Raft-Index", strconv.FormatUint(currentIndex, 10)) // Write error, if applicable. @@ -201,8 +201,14 @@ func (h *Handler) serveRequestVote(w http.ResponseWriter, r *http.Request) { return } + // Request vote from log. + peerTerm, err := h.Log.RequestVote(term, candidateID, lastLogIndex, lastLogTerm) + + // Write current term. + w.Header().Set("X-Raft-Term", strconv.FormatUint(peerTerm, 10)) + // Write error, if applicable. - if err := h.Log.RequestVote(term, candidateID, lastLogIndex, lastLogTerm); err != nil { + if err != nil { w.Header().Set("X-Raft-Error", err.Error()) w.WriteHeader(http.StatusInternalServerError) return diff --git a/raft/handler_test.go b/raft/handler_test.go index bb36ac19cf..22fd8f984e 100644 --- a/raft/handler_test.go +++ b/raft/handler_test.go @@ -276,7 +276,7 @@ func TestHandler_HandleStream_Error(t *testing.T) { // Ensure a vote request can be sent over HTTP. func TestHandler_HandleRequestVote(t *testing.T) { h := NewHandler() - h.RequestVoteFunc = func(term, candidateID, lastLogIndex, lastLogTerm uint64) error { + h.RequestVoteFunc = func(term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error) { if term != 1 { t.Fatalf("unexpected term: %d", term) } else if candidateID != 2 { @@ -286,7 +286,7 @@ func TestHandler_HandleRequestVote(t *testing.T) { } else if lastLogTerm != 4 { t.Fatalf("unexpected last log term: %d", lastLogTerm) } - return nil + return 5, nil } s := httptest.NewServer(h) defer s.Close() @@ -298,6 +298,8 @@ func TestHandler_HandleRequestVote(t *testing.T) { t.Fatalf("unexpected error: %s", err) } else if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status: %d", resp.StatusCode) + } else if term := resp.Header.Get("X-Raft-Term"); term != "5" { + t.Fatalf("unexpected raft term: %s", term) } else if s := resp.Header.Get("X-Raft-Error"); s != "" { t.Fatalf("unexpected raft error: %s", s) } @@ -306,8 +308,8 @@ func TestHandler_HandleRequestVote(t *testing.T) { // Ensure sending invalid parameters in a vote request returns an error. func TestHandler_HandleRequestVote_Error(t *testing.T) { h := NewHandler() - h.RequestVoteFunc = func(term, candidateID, lastLogIndex, lastLogTerm uint64) error { - return raft.ErrStaleTerm + h.RequestVoteFunc = func(term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error) { + return 100, raft.ErrStaleTerm } s := httptest.NewServer(h) defer s.Close() @@ -373,7 +375,7 @@ type Handler struct { RemovePeerFunc func(id uint64) error HeartbeatFunc func(term, commitIndex, leaderID uint64) (currentIndex uint64, err error) WriteEntriesToFunc func(w io.Writer, id, term, index uint64) error - RequestVoteFunc func(term, candidateID, lastLogIndex, lastLogTerm uint64) error + RequestVoteFunc func(term, candidateID, lastLogIndex, lastLogTerm uint64) (peerTerm uint64, err error) } // NewHandler returns a new instance of Handler. @@ -394,6 +396,6 @@ func (h *Handler) WriteEntriesTo(w io.Writer, id, term, index uint64) error { return h.WriteEntriesToFunc(w, id, term, index) } -func (h *Handler) RequestVote(term, candidateID, lastLogIndex, lastLogTerm uint64) error { +func (h *Handler) RequestVote(term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error) { return h.RequestVoteFunc(term, candidateID, lastLogIndex, lastLogTerm) } diff --git a/raft/log.go b/raft/log.go index d68af64608..9089f95d41 100644 --- a/raft/log.go +++ b/raft/log.go @@ -172,7 +172,7 @@ type Log struct { Leave(u url.URL, id uint64) error Heartbeat(u url.URL, term, commitIndex, leaderID uint64) (lastIndex uint64, err error) ReadFrom(u url.URL, id, term, index uint64) (io.ReadCloser, error) - RequestVote(u url.URL, term, candidateID, lastLogIndex, lastLogTerm uint64) error + RequestVote(u url.URL, term, candidateID, lastLogIndex, lastLogTerm uint64) (peerTerm uint64, err error) } // Clock is an abstraction of time. @@ -200,8 +200,8 @@ func NewLog() *Log { Clock: NewClock(), Transport: &HTTPTransport{}, Rand: rand.NewSource(time.Now().UnixNano()).Int63, - heartbeats: make(chan heartbeat, 1), - terms: make(chan uint64, 1), + heartbeats: make(chan heartbeat, 10), + terms: make(chan uint64, 10), Logger: log.New(os.Stderr, "[raft] ", log.LstdFlags), } l.updateLogPrefix() @@ -525,10 +525,23 @@ func (l *Log) writeTerm(term uint64) error { } // setTerm sets the current term and clears the vote. -func (l *Log) setTerm(term uint64) { +func (l *Log) setTerm(term uint64) error { l.Logger.Printf("changing term: %d => %d", l.term, term) + + if err := l.writeTerm(term); err != nil { + return err + } + l.term = term l.votedFor = 0 + return nil +} + +// mustSetTerm sets the current term and clears the vote. Panic on error. +func (l *Log) mustSetTerm(term uint64) { + if err := l.setTerm(term); err != nil { + panic("unable to set term: " + err.Error()) + } } // readConfig reads the configuration from disk. @@ -602,10 +615,9 @@ func (l *Log) Initialize() error { // Automatically promote to leader. term := uint64(1) - if err := l.writeTerm(term); err != nil { - return fmt.Errorf("write term: %s", err) + if err := l.setTerm(term); err != nil { + return fmt.Errorf("set term: %s", err) } - l.setTerm(term) l.lastLogTerm = term l.leaderID = l.id @@ -851,7 +863,7 @@ func (l *Log) followerLoop(closing <-chan struct{}) State { // Update term, commit index & leader. l.mu.Lock() if hb.term > l.term { - l.setTerm(hb.term) + l.mustSetTerm(hb.term) } if hb.commitIndex > l.commitIndex { l.commitIndex = hb.commitIndex @@ -862,7 +874,7 @@ func (l *Log) followerLoop(closing <-chan struct{}) State { case term := <-l.terms: l.mu.Lock() if term > l.term { - l.setTerm(term) + l.mustSetTerm(term) } l.mu.Unlock() } @@ -969,7 +981,7 @@ func (l *Log) candidateLoop(closing <-chan struct{}) State { case hb := <-l.heartbeats: l.mu.Lock() if hb.term >= term { - l.setTerm(hb.term) + l.mustSetTerm(hb.term) l.leaderID = hb.leaderID l.mu.Unlock() return Follower @@ -984,7 +996,7 @@ func (l *Log) candidateLoop(closing <-chan struct{}) State { // Check against the current term since that may have changed. l.mu.Lock() if newTerm >= l.term { - l.setTerm(newTerm) + l.mustSetTerm(newTerm) l.mu.Unlock() return Follower } @@ -1018,8 +1030,15 @@ func (l *Log) elect(term uint64, elected chan struct{}, wg *sync.WaitGroup) { continue } go func(n *ConfigNode) { - if err := l.Transport.RequestVote(n.URL, term, id, lastLogIndex, lastLogTerm); err != nil { - l.tracef("sendVoteRequests: %s: %s", n.URL.String(), err) + peerTerm, err := l.Transport.RequestVote(n.URL, term, id, lastLogIndex, lastLogTerm) + l.Logger.Printf("send req vote(term=%d, candidateID=%d, lastLogIndex=%d, lastLogTerm=%d) (term=%d, err=%v)", term, id, lastLogIndex, lastLogTerm, peerTerm, err) + + // If an error occured then send the peer's term. + if err != nil { + select { + case l.terms <- peerTerm: + default: + } return } votes <- struct{}{} @@ -1075,7 +1094,7 @@ func (l *Log) leaderLoop(closing <-chan struct{}) State { case newTerm := <-l.terms: // step down on higher term if newTerm > term { l.mu.Lock() - l.setTerm(newTerm) + l.mustSetTerm(newTerm) l.truncateTo(l.commitIndex) l.mu.Unlock() return Follower @@ -1085,7 +1104,7 @@ func (l *Log) leaderLoop(closing <-chan struct{}) State { case hb := <-l.heartbeats: // step down on higher term if hb.term > term { l.mu.Lock() - l.setTerm(hb.term) + l.mustSetTerm(hb.term) l.truncateTo(l.commitIndex) l.mu.Unlock() return Follower @@ -1570,17 +1589,17 @@ func (l *Log) Heartbeat(term, commitIndex, leaderID uint64) (currentIndex uint64 } // RequestVote requests a vote from the log. -func (l *Log) RequestVote(term, candidateID, lastLogIndex, lastLogTerm uint64) (err error) { +func (l *Log) RequestVote(term, candidateID, lastLogIndex, lastLogTerm uint64) (peerTerm uint64, err error) { l.mu.Lock() defer l.mu.Unlock() // Check if log is closed. if !l.opened() { - return ErrClosed + return l.term, ErrClosed } defer func() { - l.tracef("RV(term=%d, candidateID=%d, lastLogIndex=%d, lastLogTerm=%d) (err=%v)", term, candidateID, lastLogIndex, lastLogTerm, err) + l.Logger.Printf("recv req vote(term=%d, candidateID=%d, lastLogIndex=%d, lastLogTerm=%d) (err=%v)", term, candidateID, lastLogIndex, lastLogTerm, err) }() // Deny vote if: @@ -1588,13 +1607,13 @@ func (l *Log) RequestVote(term, candidateID, lastLogIndex, lastLogTerm uint64) ( // 2. Already voted for a different candidate in this term. (§5.2) // 3. Candidate log is less up-to-date than local log. (§5.4) if term < l.term { - return ErrStaleTerm + return l.term, ErrStaleTerm } else if term == l.term && l.votedFor != 0 && l.votedFor != candidateID { - return ErrAlreadyVoted + return l.term, ErrAlreadyVoted } else if lastLogTerm < l.lastLogTerm { - return ErrOutOfDateLog + return l.term, ErrOutOfDateLog } else if lastLogTerm == l.lastLogTerm && lastLogIndex < l.lastLogIndex { - return ErrOutOfDateLog + return l.term, ErrOutOfDateLog } // Notify term change. @@ -1609,7 +1628,7 @@ func (l *Log) RequestVote(term, candidateID, lastLogIndex, lastLogTerm uint64) ( l.term = term l.votedFor = candidateID - return nil + return l.term, nil } // WriteEntriesTo attaches a writer to the log from a given index. diff --git a/raft/transport.go b/raft/transport.go index ab35f5ed36..965a0f0c71 100644 --- a/raft/transport.go +++ b/raft/transport.go @@ -140,7 +140,7 @@ func (t *HTTPTransport) ReadFrom(uri url.URL, id, term, index uint64) (io.ReadCl } // RequestVote requests a vote for a candidate in a given term. -func (t *HTTPTransport) RequestVote(uri url.URL, term, candidateID, lastLogIndex, lastLogTerm uint64) error { +func (t *HTTPTransport) RequestVote(uri url.URL, term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error) { // Construct URL. u := uri u.Path = path.Join(u.Path, "raft/vote") @@ -156,14 +156,20 @@ func (t *HTTPTransport) RequestVote(uri url.URL, term, candidateID, lastLogIndex // Send HTTP request. resp, err := http.Get(u.String()) if err != nil { - return err + return 0, err } _ = resp.Body.Close() - // Parse returned error. - if s := resp.Header.Get("X-Raft-Error"); s != "" { - return errors.New(s) + // Parse returned term. + peerTerm, err := strconv.ParseUint(resp.Header.Get("X-Raft-Term"), 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid returned term: %q", resp.Header.Get("X-Raft-Term")) } - return nil + // Parse returned error. + if s := resp.Header.Get("X-Raft-Error"); s != "" { + return peerTerm, errors.New(s) + } + + return peerTerm, nil } diff --git a/raft/transport_test.go b/raft/transport_test.go index 2d74ebd177..5f7fcb7ec4 100644 --- a/raft/transport_test.go +++ b/raft/transport_test.go @@ -322,14 +322,17 @@ func TestHTTPTransport_RequestVote(t *testing.T) { if lastLogTerm := r.FormValue("lastLogTerm"); lastLogTerm != `4` { t.Fatalf("unexpected last log term: %v", lastLogTerm) } + w.Header().Set("X-Raft-Term", `100`) w.WriteHeader(http.StatusOK) })) defer s.Close() // Execute heartbeat against test server. u, _ := url.Parse(s.URL) - if err := (&raft.HTTPTransport{}).RequestVote(*u, 1, 2, 3, 4); err != nil { + if peerTerm, err := (&raft.HTTPTransport{}).RequestVote(*u, 1, 2, 3, 4); err != nil { t.Fatalf("unexpected error: %s", err) + } else if peerTerm != 100 { + t.Fatalf("unexpected peer term: %d", peerTerm) } } @@ -343,7 +346,7 @@ func TestHTTPTransport_RequestVote_Error(t *testing.T) { defer s.Close() u, _ := url.Parse(s.URL) - if err := (&raft.HTTPTransport{}).RequestVote(*u, 0, 0, 0, 0); err == nil { + if _, err := (&raft.HTTPTransport{}).RequestVote(*u, 0, 0, 0, 0); err == nil { t.Errorf("expected error") } else if err.Error() != `already voted` { t.Errorf("unexpected error: %s", err) @@ -353,7 +356,7 @@ func TestHTTPTransport_RequestVote_Error(t *testing.T) { // Ensure that requesting a vote over HTTP to a stopped server returns an error. func TestHTTPTransport_RequestVote_ErrConnectionRefused(t *testing.T) { u, _ := url.Parse("http://localhost:41932") - if err := (&raft.HTTPTransport{}).RequestVote(*u, 0, 0, 0, 0); err == nil { + if _, err := (&raft.HTTPTransport{}).RequestVote(*u, 0, 0, 0, 0); err == nil { t.Fatal("expected error") } else if !is_connection_refused(err) { t.Fatalf("unexpected error: %s", err) @@ -430,10 +433,10 @@ func (t *Transport) ReadFrom(u url.URL, id, term, index uint64) (io.ReadCloser, } // RequestVote calls RequestVote() on the target log. -func (t *Transport) RequestVote(u url.URL, term, candidateID, lastLogIndex, lastLogTerm uint64) error { +func (t *Transport) RequestVote(u url.URL, term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error) { l, err := t.log(u) if err != nil { - return err + return 0, err } return l.RequestVote(term, candidateID, lastLogIndex, lastLogTerm) }