Refactor raft state handling.

This commit changes the state handling of the raft log. The actions
for related to each raft state are strictly confined within that state's
loop. To transition between states, the raft log now much clean up all
its actions before moving on.

This fixes issues where goroutines were kicked off for one state but
were delayed in their scheduling so they would begin after the log
had already changed to another state.
pull/1615/head
Ben Johnson 2015-02-16 15:45:37 -07:00
parent 7b4d2675b0
commit f6ceb9bd32
7 changed files with 707 additions and 593 deletions

View File

@ -1,6 +1,7 @@
package raft
import (
"math/rand"
"time"
)
@ -39,8 +40,11 @@ func NewClock() *Clock {
// AfterApplyInterval returns a channel that fires after the apply interval.
func (c *Clock) AfterApplyInterval() <-chan chan struct{} { return newClockChan(c.ApplyInterval) }
// AfterElectionTimeout returns a channel that fires after the election timeout.
func (c *Clock) AfterElectionTimeout() <-chan chan struct{} { return newClockChan(c.ElectionTimeout) }
// AfterElectionTimeout returns a channel that fires after a duration that is
// between the election timeout and double the election timeout.
func (c *Clock) AfterElectionTimeout() <-chan chan struct{} {
return newClockChan(c.ElectionTimeout + time.Duration(rand.Intn(int(c.ElectionTimeout))))
}
// AfterHeartbeatInterval returns a channel that fires after the heartbeat interval.
func (c *Clock) AfterHeartbeatInterval() <-chan chan struct{} {

View File

@ -12,11 +12,11 @@ import (
// Handler represents an HTTP endpoint for Raft to communicate over.
type Handler struct {
Log interface {
AddPeer(u *url.URL) (uint64, *Config, error)
AddPeer(u *url.URL) (uint64, uint64, *Config, error)
RemovePeer(id uint64) error
Heartbeat(term, commitIndex, leaderID uint64) (currentIndex, currentTerm uint64, err error)
Heartbeat(term, commitIndex, leaderID uint64) (currentIndex uint64, err error)
WriteEntriesTo(w io.Writer, id, term, index uint64) error
RequestVote(term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error)
RequestVote(term, candidateID, lastLogIndex, lastLogTerm uint64) error
}
}
@ -60,7 +60,7 @@ func (h *Handler) serveJoin(w http.ResponseWriter, r *http.Request) {
}
// Add peer to the log.
id, config, err := h.Log.AddPeer(u)
id, leaderID, config, err := h.Log.AddPeer(u)
if err != nil {
w.Header().Set("X-Raft-Error", err.Error())
w.WriteHeader(http.StatusInternalServerError)
@ -69,6 +69,7 @@ func (h *Handler) serveJoin(w http.ResponseWriter, r *http.Request) {
// Return member's id in the cluster.
w.Header().Set("X-Raft-ID", strconv.FormatUint(id, 10))
w.Header().Set("X-Raft-Leader-ID", strconv.FormatUint(leaderID, 10))
w.WriteHeader(http.StatusOK)
// Write config to the body.
@ -120,11 +121,10 @@ func (h *Handler) serveHeartbeat(w http.ResponseWriter, r *http.Request) {
}
// Execute heartbeat on the log.
currentIndex, currentTerm, err := h.Log.Heartbeat(term, commitIndex, leaderID)
currentIndex, err := h.Log.Heartbeat(term, commitIndex, leaderID)
// Return current term and index.
w.Header().Set("X-Raft-Index", strconv.FormatUint(currentIndex, 10))
w.Header().Set("X-Raft-Term", strconv.FormatUint(currentTerm, 10))
// Write error, if applicable.
if err != nil {
@ -201,14 +201,8 @@ func (h *Handler) serveRequestVote(w http.ResponseWriter, r *http.Request) {
return
}
// Execute heartbeat on the log.
currentTerm, err := h.Log.RequestVote(term, candidateID, lastLogIndex, lastLogTerm)
// Return current term and index.
w.Header().Set("X-Raft-Term", strconv.FormatUint(currentTerm, 10))
// Write error, if applicable.
if err != nil {
if err := h.Log.RequestVote(term, candidateID, lastLogIndex, lastLogTerm); err != nil {
w.Header().Set("X-Raft-Error", err.Error())
w.WriteHeader(http.StatusInternalServerError)
return

View File

@ -14,11 +14,11 @@ import (
// Ensure a node can join a cluster over HTTP.
func TestHandler_HandleJoin(t *testing.T) {
h := NewHandler()
h.AddPeerFunc = func(u *url.URL) (uint64, *raft.Config, error) {
h.AddPeerFunc = func(u *url.URL) (uint64, uint64, *raft.Config, error) {
if u.String() != "http://localhost:1000" {
t.Fatalf("unexpected url: %s", u)
}
return 2, &raft.Config{}, nil
return 2, 3, &raft.Config{}, nil
}
s := httptest.NewServer(h)
defer s.Close()
@ -34,14 +34,16 @@ func TestHandler_HandleJoin(t *testing.T) {
t.Fatalf("unexpected raft error: %s", s)
} else if s = resp.Header.Get("X-Raft-ID"); s != "2" {
t.Fatalf("unexpected raft id: %s", s)
} else if s = resp.Header.Get("X-Raft-Leader-ID"); s != "3" {
t.Fatalf("unexpected raft leader id: %s", s)
}
}
// Ensure that joining with an invalid query string with return an error.
func TestHandler_HandleJoin_Error(t *testing.T) {
h := NewHandler()
h.AddPeerFunc = func(u *url.URL) (uint64, *raft.Config, error) {
return 0, nil, raft.ErrClosed
h.AddPeerFunc = func(u *url.URL) (uint64, uint64, *raft.Config, error) {
return 0, 0, nil, raft.ErrClosed
}
s := httptest.NewServer(h)
defer s.Close()
@ -123,7 +125,7 @@ func TestHandler_HandleLeave_Error(t *testing.T) {
// Ensure a heartbeat can be sent over HTTP.
func TestHandler_HandleHeartbeat(t *testing.T) {
h := NewHandler()
h.HeartbeatFunc = func(term, commitIndex, leaderID uint64) (currentIndex, currentTerm uint64, err error) {
h.HeartbeatFunc = func(term, commitIndex, leaderID uint64) (currentIndex uint64, err error) {
if term != 1 {
t.Fatalf("unexpected term: %d", term)
} else if commitIndex != 2 {
@ -131,7 +133,7 @@ func TestHandler_HandleHeartbeat(t *testing.T) {
} else if leaderID != 3 {
t.Fatalf("unexpected leader id: %d", leaderID)
}
return 4, 5, nil
return 4, nil
}
s := httptest.NewServer(h)
defer s.Close()
@ -147,8 +149,6 @@ func TestHandler_HandleHeartbeat(t *testing.T) {
t.Fatalf("unexpected raft error: %s", s)
} else if s = resp.Header.Get("X-Raft-Index"); s != "4" {
t.Fatalf("unexpected raft index: %s", s)
} else if s = resp.Header.Get("X-Raft-Term"); s != "5" {
t.Fatalf("unexpected raft term: %s", s)
}
}
@ -182,8 +182,8 @@ func TestHandler_HandleHeartbeat_Error(t *testing.T) {
// Ensure that sending a heartbeat to a closed log returns an error.
func TestHandler_HandleHeartbeat_ErrClosed(t *testing.T) {
h := NewHandler()
h.HeartbeatFunc = func(term, commitIndex, leaderID uint64) (currentIndex, currentTerm uint64, err error) {
return 0, 0, raft.ErrClosed
h.HeartbeatFunc = func(term, commitIndex, leaderID uint64) (currentIndex uint64, err error) {
return 0, raft.ErrClosed
}
s := httptest.NewServer(h)
defer s.Close()
@ -271,7 +271,7 @@ func TestHandler_HandleStream_Error(t *testing.T) {
// Ensure a vote request can be sent over HTTP.
func TestHandler_HandleRequestVote(t *testing.T) {
h := NewHandler()
h.RequestVoteFunc = func(term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error) {
h.RequestVoteFunc = func(term, candidateID, lastLogIndex, lastLogTerm uint64) error {
if term != 1 {
t.Fatalf("unexpected term: %d", term)
} else if candidateID != 2 {
@ -281,7 +281,7 @@ func TestHandler_HandleRequestVote(t *testing.T) {
} else if lastLogTerm != 4 {
t.Fatalf("unexpected last log term: %d", lastLogTerm)
}
return 5, nil
return nil
}
s := httptest.NewServer(h)
defer s.Close()
@ -295,16 +295,14 @@ func TestHandler_HandleRequestVote(t *testing.T) {
t.Fatalf("unexpected status: %d", resp.StatusCode)
} else if s := resp.Header.Get("X-Raft-Error"); s != "" {
t.Fatalf("unexpected raft error: %s", s)
} else if s = resp.Header.Get("X-Raft-Term"); s != "5" {
t.Fatalf("unexpected raft term: %s", s)
}
}
// Ensure sending invalid parameters in a vote request returns an error.
func TestHandler_HandleRequestVote_Error(t *testing.T) {
h := NewHandler()
h.RequestVoteFunc = func(term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error) {
return 0, raft.ErrStaleTerm
h.RequestVoteFunc = func(term, candidateID, lastLogIndex, lastLogTerm uint64) error {
return raft.ErrStaleTerm
}
s := httptest.NewServer(h)
defer s.Close()
@ -366,11 +364,11 @@ func TestHandler_Ping(t *testing.T) {
// Handler represents a test wrapper for the raft.Handler.
type Handler struct {
*raft.Handler
AddPeerFunc func(u *url.URL) (uint64, *raft.Config, error)
AddPeerFunc func(u *url.URL) (uint64, uint64, *raft.Config, error)
RemovePeerFunc func(id uint64) error
HeartbeatFunc func(term, commitIndex, leaderID uint64) (currentIndex, currentTerm uint64, err error)
HeartbeatFunc func(term, commitIndex, leaderID uint64) (currentIndex uint64, err error)
WriteEntriesToFunc func(w io.Writer, id, term, index uint64) error
RequestVoteFunc func(term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error)
RequestVoteFunc func(term, candidateID, lastLogIndex, lastLogTerm uint64) error
}
// NewHandler returns a new instance of Handler.
@ -380,10 +378,10 @@ func NewHandler() *Handler {
return h
}
func (h *Handler) AddPeer(u *url.URL) (uint64, *raft.Config, error) { return h.AddPeerFunc(u) }
func (h *Handler) RemovePeer(id uint64) error { return h.RemovePeerFunc(id) }
func (h *Handler) AddPeer(u *url.URL) (uint64, uint64, *raft.Config, error) { return h.AddPeerFunc(u) }
func (h *Handler) RemovePeer(id uint64) error { return h.RemovePeerFunc(id) }
func (h *Handler) Heartbeat(term, commitIndex, leaderID uint64) (currentIndex, currentTerm uint64, err error) {
func (h *Handler) Heartbeat(term, commitIndex, leaderID uint64) (currentIndex uint64, err error) {
return h.HeartbeatFunc(term, commitIndex, leaderID)
}
@ -391,6 +389,6 @@ func (h *Handler) WriteEntriesTo(w io.Writer, id, term, index uint64) error {
return h.WriteEntriesToFunc(w, id, term, index)
}
func (h *Handler) RequestVote(term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error) {
func (h *Handler) RequestVote(term, candidateID, lastLogIndex, lastLogTerm uint64) error {
return h.RequestVoteFunc(term, candidateID, lastLogIndex, lastLogTerm)
}

File diff suppressed because it is too large Load Diff

View File

@ -7,11 +7,12 @@ import (
"io"
"io/ioutil"
"log"
"math/rand"
"net/url"
"os"
"strings"
"sync"
"testing"
"time"
"github.com/influxdb/influxdb/raft"
)
@ -79,7 +80,7 @@ func TestLog_Apply(t *testing.T) {
// Single node clusters should apply to FSM immediately.
l.Wait(index)
if n := len(l.FSM.Commands); n != 1 {
if n := len(l.FSM.(*FSM).Commands); n != 1 {
t.Fatalf("unexpected command count: %d", n)
}
}
@ -96,7 +97,7 @@ func TestLog_Config_Closed(t *testing.T) {
// Ensure that log ids in a cluster are set sequentially.
func TestCluster_ID_Sequential(t *testing.T) {
c := NewCluster()
c := NewCluster(fsmFunc)
defer c.Close()
for i, l := range c.Logs {
if l.ID() != uint64(i+1) {
@ -107,7 +108,7 @@ func TestCluster_ID_Sequential(t *testing.T) {
// Ensure that cluster starts with one leader and multiple followers.
func TestCluster_State(t *testing.T) {
c := NewCluster()
c := NewCluster(fsmFunc)
defer c.Close()
if state := c.Logs[0].State(); state != raft.Leader {
t.Fatalf("unexpected state(0): %s", state)
@ -122,7 +123,7 @@ func TestCluster_State(t *testing.T) {
// Ensure that each node's configuration matches in the cluster.
func TestCluster_Config(t *testing.T) {
c := NewCluster()
c := NewCluster(fsmFunc)
defer c.Close()
config := jsonify(c.Logs[0].Config())
for _, l := range c.Logs[1:] {
@ -134,7 +135,7 @@ func TestCluster_Config(t *testing.T) {
// Ensure that a command can be applied to a cluster and distributed appropriately.
func TestCluster_Apply(t *testing.T) {
c := NewCluster()
c := NewCluster(fsmFunc)
defer c.Close()
// Apply a command.
@ -149,41 +150,30 @@ func TestCluster_Apply(t *testing.T) {
c.Logs[2].MustWaitUncommitted(4)
// Should not apply immediately.
if n := len(leader.FSM.Commands); n != 0 {
if n := len(leader.FSM.(*FSM).Commands); n != 0 {
t.Fatalf("unexpected pre-heartbeat command count: %d", n)
}
// Run the heartbeat on the leader and have all logs apply.
// Only the leader should have the changes applied.
c.Logs[0].Clock.heartbeat()
c.Logs[0].HeartbeatUntil(4)
c.Logs[0].Clock.apply()
c.Logs[1].Clock.apply()
c.Logs[2].Clock.apply()
if n := len(c.Logs[0].FSM.Commands); n != 1 {
if n := len(c.Logs[0].FSM.(*FSM).Commands); n != 1 {
t.Fatalf("unexpected command count(0): %d", n)
}
if n := len(c.Logs[1].FSM.Commands); n != 0 {
if n := len(c.Logs[1].FSM.(*FSM).Commands); n != 1 {
t.Fatalf("unexpected command count(1): %d", n)
}
if n := len(c.Logs[2].FSM.Commands); n != 0 {
t.Fatalf("unexpected command count(2): %d", n)
}
// Wait for another heartbeat and all logs should be in sync.
c.Logs[0].Clock.heartbeat()
c.Logs[1].Clock.apply()
c.Logs[2].Clock.apply()
if n := len(c.Logs[1].FSM.Commands); n != 1 {
t.Fatalf("unexpected command count(1): %d", n)
}
if n := len(c.Logs[2].FSM.Commands); n != 1 {
if n := len(c.Logs[2].FSM.(*FSM).Commands); n != 1 {
t.Fatalf("unexpected command count(2): %d", n)
}
}
// Ensure that a new leader can be elected.
func TestLog_Elect(t *testing.T) {
c := NewCluster()
c := NewCluster(fsmFunc)
defer c.Close()
// Stop leader.
@ -222,7 +212,7 @@ func TestLog_Elect(t *testing.T) {
}
c.MustWaitUncommitted(index)
c.Logs[1].Clock.heartbeat()
c.Logs[1].HeartbeatUntil(index)
c.Logs[1].Clock.heartbeat()
c.Logs[0].Clock.apply()
c.Logs[1].Clock.apply()
@ -251,108 +241,155 @@ func TestState_String(t *testing.T) {
}
}
func BenchmarkLogApply1(b *testing.B) { benchmarkLogApply(b, 1) }
func BenchmarkLogApply2(b *testing.B) { benchmarkLogApply(b, 2) }
func BenchmarkLogApply3(b *testing.B) { benchmarkLogApply(b, 3) }
// Ensure a cluster of nodes can successfully re-elect while applying commands.
func TestCluster_Elect_RealTime(t *testing.T) {
if testing.Short() {
t.Skip("skip: short mode")
}
// Create a cluster with a real-time clock.
c := NewRealTimeCluster(3, indexFSMFunc)
minIndex := c.Logs[0].AppliedIndex()
commandN := uint64(1000) - minIndex
// Run a loop to continually apply commands.
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
leader := c.Logs[0]
for i := uint64(0); i < commandN; i++ {
for {
// Apply entry to leader.
// If not leader, find new leader.
index, err := leader.Apply(make([]byte, 50))
if err == raft.ErrNotLeader {
for _, l := range c.Logs {
if l.State() == raft.Leader {
leader = l
break
}
}
continue
} else if err != nil {
t.Fatalf("apply: index=%d, err=%s", index, err)
} else {
break
}
}
time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond)
}
}()
// Run a loop to periodically kill off nodes.
wg.Add(1)
go func() {
defer wg.Done()
// Wait for nodes to get going.
time.Sleep(500 * time.Millisecond)
// Choose random log.
// i := rand.Intn(len(c.Logs))
i := 0
l := c.Logs[i]
// Restart the log.
path := l.Path()
l.Log.Close()
if err := l.Log.Open(path); err != nil {
t.Fatalf("reopen(%d): %s", i, err)
}
}()
// Wait for all logs to catch up.
wg.Wait()
for i, l := range c.Logs {
if err := l.Wait(commandN + minIndex); err != nil {
t.Errorf("wait(%d): %s", i, err)
}
}
// Verify FSM indicies match.
for i, l := range c.Logs {
if exp, fsm := commandN+minIndex, l.FSM.(*IndexFSM); exp != fsm.index {
t.Errorf("fsm index mismatch(%d): exp=%d, got=%d", i, exp, fsm.index)
}
}
}
func BenchmarkClusterApply1(b *testing.B) { benchmarkClusterApply(b, 1) }
func BenchmarkClusterApply2(b *testing.B) { benchmarkClusterApply(b, 2) }
func BenchmarkClusterApply3(b *testing.B) { benchmarkClusterApply(b, 3) }
// Benchmarks an n-node cluster connected through an in-memory transport.
func benchmarkLogApply(b *testing.B, logN int) {
warnf("== BenchmarkLogApply (%d) ====================================", b.N)
func benchmarkClusterApply(b *testing.B, logN int) {
warnf("== BenchmarkClusterApply (%d) ====================================", b.N)
logs := make([]*raft.Log, logN)
t := NewTransport()
var ptrs []string
for i := 0; i < logN; i++ {
// Create log.
l := raft.NewLog()
l.URL = &url.URL{Host: fmt.Sprintf("log%d", i)}
l.FSM = &BenchmarkFSM{}
l.DebugEnabled = true
l.Transport = t
t.register(l)
// Open log.
if err := l.Open(tempfile()); err != nil {
b.Fatalf("open: %s", err)
}
// Initialize or join.
if i == 0 {
if err := l.Initialize(); err != nil {
b.Fatalf("initialize: %s", err)
}
} else {
if err := l.Join(logs[0].URL); err != nil {
b.Fatalf("initialize: %s", err)
}
}
ptrs = append(ptrs, fmt.Sprintf("%d/%p", i, l))
logs[i] = l
}
warn("LOGS:", strings.Join(ptrs, " "))
c := NewRealTimeCluster(logN, indexFSMFunc)
b.ResetTimer()
// Apply commands to leader.
var index uint64
var err error
for i := 0; i < b.N; i++ {
index, err = logs[0].Apply(make([]byte, 50))
index, err = c.Logs[0].Apply(make([]byte, 50))
if err != nil {
b.Fatalf("apply: %s", err)
}
}
// Wait for all logs to catch up.
for i, l := range logs {
if err := l.Wait(index); err != nil {
b.Fatalf("wait(%d): %s", i, err)
}
for _, l := range c.Logs {
l.MustWait(index)
}
b.StopTimer()
// Verify FSM indicies match.
for i, l := range logs {
if fsm := l.FSM.(*BenchmarkFSM); index != fsm.index {
for i, l := range c.Logs {
if fsm := l.FSM.(*IndexFSM); index != fsm.index {
b.Errorf("fsm index mismatch(%d): exp=%d, got=%d", i, index, fsm.index)
}
}
}
// BenchmarkFSM represents a state machine that records the command count.
type BenchmarkFSM struct {
// IndexFSM represents a state machine that only records the last applied index.
type IndexFSM struct {
index uint64
}
// MustApply updates the index.
func (fsm *BenchmarkFSM) MustApply(entry *raft.LogEntry) { fsm.index = entry.Index }
func (fsm *IndexFSM) MustApply(entry *raft.LogEntry) { fsm.index = entry.Index }
// Index returns the highest applied index.
func (fsm *BenchmarkFSM) Index() (uint64, error) { return fsm.index, nil }
func (fsm *IndexFSM) Index() (uint64, error) { return fsm.index, nil }
// Snapshot writes the FSM's index as the snapshot.
func (fsm *BenchmarkFSM) Snapshot(w io.Writer) (uint64, error) {
func (fsm *IndexFSM) Snapshot(w io.Writer) (uint64, error) {
return fsm.index, binary.Write(w, binary.BigEndian, fsm.index)
}
// Restore reads the snapshot from the reader.
func (fsm *BenchmarkFSM) Restore(r io.Reader) error {
func (fsm *IndexFSM) Restore(r io.Reader) error {
return binary.Read(r, binary.BigEndian, &fsm.index)
}
func indexFSMFunc() raft.FSM { return &IndexFSM{} }
// Cluster represents a collection of nodes that share the same mock clock.
type Cluster struct {
Logs []*Log
}
// NewCluster creates a new 3 log cluster.
func NewCluster() *Cluster {
func NewCluster(fsmFn func() raft.FSM) *Cluster {
c := &Cluster{}
t := NewTransport()
logN := 3
for i := 0; i < logN; i++ {
l := NewLog(&url.URL{Host: fmt.Sprintf("log%d", i)})
l.Log.FSM = fsmFn()
l.Transport = t
c.Logs = append(c.Logs, l)
t.register(l.Log)
@ -383,7 +420,7 @@ func NewCluster() *Cluster {
go func() {
c.Logs[0].MustWaitUncommitted(3)
c.Logs[1].MustWaitUncommitted(3)
c.Logs[0].Clock.heartbeat()
c.Logs[0].HeartbeatUntil(3)
c.Logs[0].Clock.apply()
c.Logs[1].Clock.apply()
c.Logs[2].Clock.apply()
@ -400,6 +437,43 @@ func NewCluster() *Cluster {
return c
}
// NewRealTimeCluster a new cluster with n logs.
// All logs use a real-time clock instead of a test clock.
func NewRealTimeCluster(logN int, fsmFn func() raft.FSM) *Cluster {
c := &Cluster{}
t := NewTransport()
for i := 0; i < logN; i++ {
l := NewLog(&url.URL{Host: fmt.Sprintf("log%d", i)})
l.Log.FSM = fsmFn()
l.Clock = nil
l.Log.Clock = raft.NewClock()
l.Transport = t
c.Logs = append(c.Logs, l)
t.register(l.Log)
warnf("Log %s: %p", l.URL.String(), l.Log)
}
warn("")
// Initialize leader.
c.Logs[0].MustOpen()
c.Logs[0].MustInitialize()
// Join remaining nodes.
for i := 1; i < logN; i++ {
c.Logs[i].MustOpen()
c.Logs[i].MustJoin(c.Logs[0].URL)
}
// Ensure nodes are ready.
index := c.Logs[0].Index()
for i := 0; i < logN; i++ {
c.Logs[i].MustWait(index)
}
return c
}
// Close closes all logs in the cluster.
func (c *Cluster) Close() {
for _, l := range c.Logs {
@ -418,7 +492,7 @@ func (c *Cluster) Leader() *Log {
return leader
}
// WaitUncommitted waits until all logs in the cluster have reached a given uncomiitted index.
// WaitCommitted waits until all logs in the cluster have reached a given uncommitted index.
func (c *Cluster) MustWaitUncommitted(index uint64) {
for _, l := range c.Logs {
l.MustWaitUncommitted(index)
@ -437,14 +511,12 @@ func (c *Cluster) flush() {
type Log struct {
*raft.Log
Clock *Clock
FSM *FSM
}
// NewLog returns a new instance of Log.
func NewLog(u *url.URL) *Log {
l := &Log{Log: raft.NewLog(), Clock: NewClock(), FSM: &FSM{}}
l := &Log{Log: raft.NewLog(), Clock: NewClock()}
l.URL = u
l.Log.FSM = l.FSM
l.Log.Clock = l.Clock
l.Rand = seq()
l.DebugEnabled = true
@ -457,6 +529,7 @@ func NewLog(u *url.URL) *Log {
// NewInitializedLog returns a new initialized Node.
func NewInitializedLog(u *url.URL) *Log {
l := NewLog(u)
l.Log.FSM = &FSM{}
l.MustOpen()
l.MustInitialize()
return l
@ -473,13 +546,22 @@ func (l *Log) MustOpen() {
func (l *Log) MustInitialize() {
go func() {
l.MustWaitUncommitted(1)
l.Clock.apply()
if l.Clock != nil {
l.Clock.apply()
}
}()
if err := l.Initialize(); err != nil {
panic("initialize: " + err.Error())
}
}
// MustJoin joins the log to another log. Panic on error.
func (l *Log) MustJoin(u *url.URL) {
if err := l.Join(u); err != nil {
panic("join: " + err.Error())
}
}
// Close closes the log and HTTP server.
func (l *Log) Close() error {
defer os.RemoveAll(l.Log.Path())
@ -487,6 +569,20 @@ func (l *Log) Close() error {
return nil
}
// MustWaits waits for at least a given applied index. Panic on error.
func (l *Log) MustWait(index uint64) {
if err := l.Log.Wait(index); err != nil {
panic(l.URL.String() + " wait: " + err.Error())
}
}
// MustCommitted waits for at least a given committed index. Panic on error.
func (l *Log) MustWaitCommitted(index uint64) {
if err := l.Log.WaitCommitted(index); err != nil {
panic(l.URL.String() + " wait committed: " + err.Error())
}
}
// MustWaitUncommitted waits for at least a given uncommitted index. Panic on error.
func (l *Log) MustWaitUncommitted(index uint64) {
if err := l.Log.WaitUncommitted(index); err != nil {
@ -494,6 +590,17 @@ func (l *Log) MustWaitUncommitted(index uint64) {
}
}
// HeartbeatUtil continues to heartbeat until an index is committed.
func (l *Log) HeartbeatUntil(index uint64) {
for {
time.Sleep(1 * time.Millisecond)
l.Clock.heartbeat()
if l.CommitIndex() >= index {
return
}
}
}
// FSM represents a simple state machine that records all commands.
type FSM struct {
MaxIndex uint64
@ -535,18 +642,7 @@ func (fsm *FSM) Restore(r io.Reader) error {
return nil
}
// MockFSM represents a state machine that can be mocked out.
type MockFSM struct {
ApplyFunc func(*raft.LogEntry) error
IndexFunc func() (uint64, error)
SnapshotFunc func(w io.Writer) (index uint64, err error)
RestoreFunc func(r io.Reader) error
}
func (fsm *MockFSM) Apply(e *raft.LogEntry) error { return fsm.ApplyFunc(e) }
func (fsm *MockFSM) Index() (uint64, error) { return fsm.IndexFunc() }
func (fsm *MockFSM) Snapshot(w io.Writer) (uint64, error) { return fsm.SnapshotFunc(w) }
func (fsm *MockFSM) Restore(r io.Reader) error { return fsm.RestoreFunc(r) }
func fsmFunc() raft.FSM { return &FSM{} }
// seq implements the raft.Log#Rand interface and returns incrementing ints.
func seq() func() int64 {
@ -587,3 +683,13 @@ func warnf(msg string, v ...interface{}) {
fmt.Fprintf(os.Stderr, msg+"\n", v...)
}
}
// u64tob converts a uint64 into an 8-byte slice.
func u64tob(v uint64) []byte {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, v)
return b
}
// btou64 converts an 8-byte slice into an uint64.
func btou64(b []byte) uint64 { return binary.BigEndian.Uint64(b) }

View File

@ -15,7 +15,7 @@ import (
type HTTPTransport struct{}
// Join requests membership into a node's cluster.
func (t *HTTPTransport) Join(uri *url.URL, nodeURL *url.URL) (uint64, *Config, error) {
func (t *HTTPTransport) Join(uri *url.URL, nodeURL *url.URL) (uint64, uint64, *Config, error) {
// Construct URL.
u := *uri
u.Path = path.Join(u.Path, "raft/join")
@ -24,29 +24,34 @@ func (t *HTTPTransport) Join(uri *url.URL, nodeURL *url.URL) (uint64, *Config, e
// Send HTTP request.
resp, err := http.Get(u.String())
if err != nil {
return 0, nil, err
return 0, 0, nil, err
}
defer func() { _ = resp.Body.Close() }()
// Parse returned error.
if s := resp.Header.Get("X-Raft-Error"); s != "" {
return 0, nil, errors.New(s)
return 0, 0, nil, errors.New(s)
}
// Parse returned id.
idString := resp.Header.Get("X-Raft-ID")
id, err := strconv.ParseUint(idString, 10, 64)
id, err := strconv.ParseUint(resp.Header.Get("X-Raft-ID"), 10, 64)
if err != nil {
return 0, nil, fmt.Errorf("invalid id: %q", idString)
return 0, 0, nil, fmt.Errorf("invalid id: %q", resp.Header.Get("X-Raft-ID"))
}
// Parse returned id.
leaderID, err := strconv.ParseUint(resp.Header.Get("X-Raft-Leader-ID"), 10, 64)
if err != nil {
return 0, 0, nil, fmt.Errorf("invalid leader id: %q", resp.Header.Get("X-Raft-Leader-ID"))
}
// Unmarshal config.
var config *Config
if err := json.NewDecoder(resp.Body).Decode(&config); err != nil {
return 0, nil, fmt.Errorf("config unmarshal: %s", err)
return 0, 0, nil, fmt.Errorf("config unmarshal: %s", err)
}
return id, config, nil
return id, leaderID, config, nil
}
// Leave removes a node from a cluster's membership.
@ -72,7 +77,7 @@ func (t *HTTPTransport) Leave(uri *url.URL, id uint64) error {
}
// Heartbeat checks the status of a follower.
func (t *HTTPTransport) Heartbeat(uri *url.URL, term, commitIndex, leaderID uint64) (uint64, uint64, error) {
func (t *HTTPTransport) Heartbeat(uri *url.URL, term, commitIndex, leaderID uint64) (uint64, error) {
// Construct URL.
u := *uri
u.Path = path.Join(u.Path, "raft/heartbeat")
@ -87,7 +92,7 @@ func (t *HTTPTransport) Heartbeat(uri *url.URL, term, commitIndex, leaderID uint
// Send HTTP request.
resp, err := http.Get(u.String())
if err != nil {
return 0, 0, err
return 0, err
}
_ = resp.Body.Close()
@ -95,22 +100,15 @@ func (t *HTTPTransport) Heartbeat(uri *url.URL, term, commitIndex, leaderID uint
newIndexString := resp.Header.Get("X-Raft-Index")
newIndex, err := strconv.ParseUint(newIndexString, 10, 64)
if err != nil {
return 0, 0, fmt.Errorf("invalid index: %q", newIndexString)
}
// Parse returned term.
newTermString := resp.Header.Get("X-Raft-Term")
newTerm, err := strconv.ParseUint(newTermString, 10, 64)
if err != nil {
return 0, 0, fmt.Errorf("invalid term: %q", newTermString)
return 0, fmt.Errorf("invalid index: %q", newIndexString)
}
// Parse returned error.
if s := resp.Header.Get("X-Raft-Error"); s != "" {
return newIndex, newTerm, errors.New(s)
return newIndex, errors.New(s)
}
return newIndex, newTerm, nil
return newIndex, nil
}
// ReadFrom streams the log from a leader.
@ -142,7 +140,7 @@ func (t *HTTPTransport) ReadFrom(uri *url.URL, id, term, index uint64) (io.ReadC
}
// RequestVote requests a vote for a candidate in a given term.
func (t *HTTPTransport) RequestVote(uri *url.URL, term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error) {
func (t *HTTPTransport) RequestVote(uri *url.URL, term, candidateID, lastLogIndex, lastLogTerm uint64) error {
// Construct URL.
u := *uri
u.Path = path.Join(u.Path, "raft/vote")
@ -158,21 +156,14 @@ func (t *HTTPTransport) RequestVote(uri *url.URL, term, candidateID, lastLogInde
// Send HTTP request.
resp, err := http.Get(u.String())
if err != nil {
return 0, err
return err
}
_ = resp.Body.Close()
// Parse returned term.
newTermString := resp.Header.Get("X-Raft-Term")
newTerm, err := strconv.ParseUint(newTermString, 10, 64)
if err != nil {
return 0, fmt.Errorf("invalid term: %q", newTermString)
}
// Parse returned error.
if s := resp.Header.Get("X-Raft-Error"); s != "" {
return newTerm, errors.New(s)
return errors.New(s)
}
return newTerm, nil
return nil
}

View File

@ -27,17 +27,20 @@ func TestHTTPTransport_Join(t *testing.T) {
t.Fatalf("unexpected term: %q", s)
}
w.Header().Set("X-Raft-ID", "1")
w.Header().Set("X-Raft-Leader-ID", "2")
w.Write([]byte(`{}`))
}))
defer s.Close()
// Execute join against test server.
u, _ := url.Parse(s.URL)
id, config, err := (&raft.HTTPTransport{}).Join(u, &url.URL{Host: "local"})
id, leaderID, config, err := (&raft.HTTPTransport{}).Join(u, &url.URL{Host: "local"})
if err != nil {
t.Fatalf("unexpected error: %s", err)
} else if id != 1 {
t.Fatalf("unexpected id: %d", id)
} else if leaderID != 2 {
t.Fatalf("unexpected leader id: %d", leaderID)
} else if config == nil {
t.Fatalf("unexpected config")
}
@ -45,7 +48,7 @@ func TestHTTPTransport_Join(t *testing.T) {
// Ensure that joining a server that doesn't exist returns an error.
func TestHTTPTransport_Join_ErrConnectionRefused(t *testing.T) {
_, _, err := (&raft.HTTPTransport{}).Join(&url.URL{Scheme: "http", Host: "localhost:27322"}, &url.URL{Host: "local"})
_, _, _, err := (&raft.HTTPTransport{}).Join(&url.URL{Scheme: "http", Host: "localhost:27322"}, &url.URL{Host: "local"})
if err == nil || !strings.Contains(err.Error(), "connection refused") {
t.Fatalf("unexpected error: %s", err)
}
@ -61,7 +64,7 @@ func TestHTTPTransport_Join_ErrInvalidID(t *testing.T) {
// Execute join against test server.
u, _ := url.Parse(s.URL)
_, _, err := (&raft.HTTPTransport{}).Join(u, &url.URL{Host: "local"})
_, _, _, err := (&raft.HTTPTransport{}).Join(u, &url.URL{Host: "local"})
if err == nil || err.Error() != `invalid id: "xxx"` {
t.Fatalf("unexpected error: %s", err)
}
@ -72,13 +75,14 @@ func TestHTTPTransport_Join_ErrInvalidConfig(t *testing.T) {
// Start mock HTTP server.
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Raft-ID", "1")
w.Header().Set("X-Raft-Leader-ID", "1")
w.Write([]byte(`{`))
}))
defer s.Close()
// Execute join against test server.
u, _ := url.Parse(s.URL)
_, _, err := (&raft.HTTPTransport{}).Join(u, &url.URL{Host: "local"})
_, _, _, err := (&raft.HTTPTransport{}).Join(u, &url.URL{Host: "local"})
if err == nil || err.Error() != `config unmarshal: unexpected EOF` {
t.Fatalf("unexpected error: %s", err)
}
@ -95,7 +99,7 @@ func TestHTTPTransport_Join_Err(t *testing.T) {
// Execute join against test server.
u, _ := url.Parse(s.URL)
_, _, err := (&raft.HTTPTransport{}).Join(u, &url.URL{Host: "local"})
_, _, _, err := (&raft.HTTPTransport{}).Join(u, &url.URL{Host: "local"})
if err == nil || err.Error() != `oh no` {
t.Fatalf("unexpected error: %s", err)
}
@ -161,20 +165,17 @@ func TestHTTPTransport_Heartbeat(t *testing.T) {
t.Fatalf("unexpected leader id: %q", leaderID)
}
w.Header().Set("X-Raft-Index", "4")
w.Header().Set("X-Raft-Term", "5")
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
// Execute heartbeat against test server.
u, _ := url.Parse(s.URL)
newIndex, newTerm, err := (&raft.HTTPTransport{}).Heartbeat(u, 1, 2, 3)
newIndex, err := (&raft.HTTPTransport{}).Heartbeat(u, 1, 2, 3)
if err != nil {
t.Fatalf("unexpected error: %s", err)
} else if newIndex != 4 {
t.Fatalf("unexpected new index: %d", newIndex)
} else if newTerm != 5 {
t.Fatalf("unexpected new term: %d", newTerm)
}
}
@ -182,25 +183,22 @@ func TestHTTPTransport_Heartbeat(t *testing.T) {
func TestHTTPTransport_Heartbeat_Err(t *testing.T) {
var tests = []struct {
index string
term string
errstr string
err string
}{
{index: "", term: "", err: `invalid index: ""`},
{index: "1000", term: "", err: `invalid term: ""`},
{index: "1", term: "2", errstr: "bad heartbeat", err: `bad heartbeat`},
{index: "", err: `invalid index: ""`},
{index: "1", errstr: "bad heartbeat", err: `bad heartbeat`},
}
for i, tt := range tests {
// Start mock HTTP server.
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Raft-Index", tt.index)
w.Header().Set("X-Raft-Term", tt.term)
w.Header().Set("X-Raft-Error", tt.errstr)
w.WriteHeader(http.StatusOK)
}))
u, _ := url.Parse(s.URL)
_, _, err := (&raft.HTTPTransport{}).Heartbeat(u, 1, 2, 3)
_, err := (&raft.HTTPTransport{}).Heartbeat(u, 1, 2, 3)
if err == nil {
t.Errorf("%d. expected error", i)
} else if tt.err != err.Error() {
@ -213,7 +211,7 @@ func TestHTTPTransport_Heartbeat_Err(t *testing.T) {
// Ensure an HTTP heartbeat to a stopped server returns an error.
func TestHTTPTransport_Heartbeat_ErrConnectionRefused(t *testing.T) {
u, _ := url.Parse("http://localhost:41932")
_, _, err := (&raft.HTTPTransport{}).Heartbeat(u, 0, 0, 0)
_, err := (&raft.HTTPTransport{}).Heartbeat(u, 0, 0, 0)
if err == nil {
t.Fatal("expected error")
} else if !strings.Contains(err.Error(), `connection refused`) {
@ -303,35 +301,14 @@ func TestHTTPTransport_RequestVote(t *testing.T) {
if lastLogTerm := r.FormValue("lastLogTerm"); lastLogTerm != `4` {
t.Fatalf("unexpected last log term: %v", lastLogTerm)
}
w.Header().Set("X-Raft-Term", "5")
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
// Execute heartbeat against test server.
u, _ := url.Parse(s.URL)
newTerm, err := (&raft.HTTPTransport{}).RequestVote(u, 1, 2, 3, 4)
if err != nil {
if err := (&raft.HTTPTransport{}).RequestVote(u, 1, 2, 3, 4); err != nil {
t.Fatalf("unexpected error: %s", err)
} else if newTerm != 5 {
t.Fatalf("unexpected new term: %d", newTerm)
}
}
// Ensure that a returned vote with an invalid term returns an error.
func TestHTTPTransport_RequestVote_ErrInvalidTerm(t *testing.T) {
// Start mock HTTP server.
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Raft-Term", `xxx`)
}))
defer s.Close()
u, _ := url.Parse(s.URL)
_, err := (&raft.HTTPTransport{}).RequestVote(u, 0, 0, 0, 0)
if err == nil {
t.Errorf("expected error")
} else if err.Error() != `invalid term: "xxx"` {
t.Errorf("unexpected error: %s", err)
}
}
@ -345,8 +322,7 @@ func TestHTTPTransport_RequestVote_Error(t *testing.T) {
defer s.Close()
u, _ := url.Parse(s.URL)
_, err := (&raft.HTTPTransport{}).RequestVote(u, 0, 0, 0, 0)
if err == nil {
if err := (&raft.HTTPTransport{}).RequestVote(u, 0, 0, 0, 0); err == nil {
t.Errorf("expected error")
} else if err.Error() != `already voted` {
t.Errorf("unexpected error: %s", err)
@ -356,8 +332,7 @@ func TestHTTPTransport_RequestVote_Error(t *testing.T) {
// Ensure that requesting a vote over HTTP to a stopped server returns an error.
func TestHTTPTransport_RequestVote_ErrConnectionRefused(t *testing.T) {
u, _ := url.Parse("http://localhost:41932")
_, err := (&raft.HTTPTransport{}).RequestVote(u, 0, 0, 0, 0)
if err == nil {
if err := (&raft.HTTPTransport{}).RequestVote(u, 0, 0, 0, 0); err == nil {
t.Fatal("expected error")
} else if !strings.Contains(err.Error(), `connection refused`) {
t.Fatalf("unexpected error: %s", err)
@ -389,10 +364,10 @@ func (t *Transport) log(u *url.URL) (*raft.Log, error) {
}
// Join calls the AddPeer method on the target log.
func (t *Transport) Join(u *url.URL, nodeURL *url.URL) (uint64, *raft.Config, error) {
func (t *Transport) Join(u *url.URL, nodeURL *url.URL) (uint64, uint64, *raft.Config, error) {
l, err := t.log(u)
if err != nil {
return 0, nil, err
return 0, 0, nil, err
}
return l.AddPeer(nodeURL)
}
@ -407,10 +382,10 @@ func (t *Transport) Leave(u *url.URL, id uint64) error {
}
// Heartbeat calls the Heartbeat method on the target log.
func (t *Transport) Heartbeat(u *url.URL, term, commitIndex, leaderID uint64) (lastIndex, currentTerm uint64, err error) {
func (t *Transport) Heartbeat(u *url.URL, term, commitIndex, leaderID uint64) (lastIndex uint64, err error) {
l, err := t.log(u)
if err != nil {
return 0, 0, err
return 0, err
}
return l.Heartbeat(term, commitIndex, leaderID)
}
@ -434,10 +409,10 @@ func (t *Transport) ReadFrom(u *url.URL, id, term, index uint64) (io.ReadCloser,
}
// RequestVote calls RequestVote() on the target log.
func (t *Transport) RequestVote(u *url.URL, term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error) {
func (t *Transport) RequestVote(u *url.URL, term, candidateID, lastLogIndex, lastLogTerm uint64) error {
l, err := t.log(u)
if err != nil {
return 0, err
return err
}
return l.RequestVote(term, candidateID, lastLogIndex, lastLogTerm)
}