Merge pull request #1515 from influxdb/raft

Refactor raft package
pull/1517/head
Ben Johnson 2015-02-05 07:50:07 -07:00
commit 046a69dd8f
11 changed files with 953 additions and 832 deletions

View File

@ -190,7 +190,7 @@ func (b *Broker) load() error {
// save persists the broker metadata to disk.
func (b *Broker) save() error {
if b.path == "" {
return fmt.Errorf("broker not open")
return ErrClosed
}
// Calculate header under lock.
@ -216,7 +216,7 @@ func (b *Broker) save() error {
// mustSave persists the broker metadata to disk. Panics on any error other than ErrClosed.
func (b *Broker) mustSave() {
if err := b.save(); err != nil {
if err := b.save(); err != nil && err != ErrClosed {
panic(err.Error())
}
}

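Note on this hunk: replacing the ad-hoc fmt.Errorf with the package-level ErrClosed sentinel is what lets mustSave tolerate a closed broker by identity comparison. A minimal standalone sketch of the idiom (the helpers below are hypothetical stand-ins, not the broker's real code):

package main

import (
	"errors"
	"fmt"
)

// ErrClosed stands in for the sentinel introduced above; comparing by
// identity (err != ErrClosed) only works because the value is shared.
var ErrClosed = errors.New("broker not open")

func save(open bool) error {
	if !open {
		return ErrClosed
	}
	return nil
}

// mustSave mirrors the diff: tolerate a closed broker, panic on anything else.
func mustSave(open bool) {
	if err := save(open); err != nil && err != ErrClosed {
		panic(err.Error())
	}
}

func main() {
	mustSave(false) // closed broker is expected; no panic
	fmt.Println("ok")
}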
View File

@ -11,7 +11,7 @@ import (
// Handler represents an HTTP handler by the broker.
type Handler struct {
raftHandler *raft.HTTPHandler
raftHandler *raft.Handler
broker *Broker
}
@ -30,7 +30,7 @@ func (h *Handler) SetBroker(b *Broker) {
h.broker = b
if b != nil {
h.raftHandler = raft.NewHTTPHandler(b.log)
h.raftHandler = &raft.Handler{Log: b.log}
} else {
h.raftHandler = nil
}

View File

@ -1,280 +1,93 @@
package raft
import (
"sort"
"sync"
"time"
)
// Clock represents an interface to the functions in the standard library time
// package. Two implementations are available in the clock package. The first
// is a real-time clock which simply wraps the time package's functions. The
// second is a mock clock which will only make forward progress when
// programmatically adjusted.
type Clock interface {
Add(d time.Duration)
After(d time.Duration) <-chan time.Time
AfterFunc(d time.Duration, f func()) *Timer
Now() time.Time
Sleep(d time.Duration)
Tick(d time.Duration) <-chan time.Time
Ticker(d time.Duration) *Ticker
Timer(d time.Duration) *Timer
const (
// DefaultApplyInterval is the default time between checks to apply commands.
DefaultApplyInterval = 10 * time.Millisecond
// DefaultElectionTimeout is the default time before starting an election.
DefaultElectionTimeout = 500 * time.Millisecond
// DefaultHeartbeatInterval is the default time to wait between heartbeats.
DefaultHeartbeatInterval = 150 * time.Millisecond
// DefaultReconnectTimeout is the default time to wait before reconnecting.
DefaultReconnectTimeout = 10 * time.Millisecond
// DefaultWaitInterval is the default time to wait for log sync.
DefaultWaitInterval = 1 * time.Millisecond
)
// Clock implements an interface to the real-time clock.
type Clock struct {
ApplyInterval time.Duration
ElectionTimeout time.Duration
HeartbeatInterval time.Duration
ReconnectTimeout time.Duration
WaitInterval time.Duration
}
// clock implements a real-time clock by simply wrapping the time package functions.
type clock struct{}
func (c *clock) Add(d time.Duration) { panic("real-time clock cannot manually adjust") }
func (c *clock) After(d time.Duration) <-chan time.Time { return time.After(d) }
func (c *clock) AfterFunc(d time.Duration, f func()) *Timer {
return &Timer{timer: time.AfterFunc(d, f)}
// NewClock returns an instance of Clock with defaults set.
func NewClock() *Clock {
return &Clock{
ApplyInterval: DefaultApplyInterval,
ElectionTimeout: DefaultElectionTimeout,
HeartbeatInterval: DefaultHeartbeatInterval,
ReconnectTimeout: DefaultReconnectTimeout,
WaitInterval: DefaultWaitInterval,
}
}
func (c *clock) Now() time.Time { return time.Now() }
// AfterApplyInterval returns a channel that fires after the apply interval.
func (c *Clock) AfterApplyInterval() <-chan chan struct{} { return newClockChan(c.ApplyInterval) }
func (c *clock) Sleep(d time.Duration) { time.Sleep(d) }
// AfterElectionTimeout returns a channel that fires after the election timeout.
func (c *Clock) AfterElectionTimeout() <-chan chan struct{} { return newClockChan(c.ElectionTimeout) }
func (c *clock) Tick(d time.Duration) <-chan time.Time { return time.Tick(d) }
// AfterHeartbeatInterval returns a channel that fires after the heartbeat interval.
func (c *Clock) AfterHeartbeatInterval() <-chan chan struct{} {
return newClockChan(c.HeartbeatInterval)
}
func (c *clock) Ticker(d time.Duration) *Ticker {
t := time.NewTicker(d)
// AfterReconnectTimeout returns a channel that fires after the reconnection timeout.
func (c *Clock) AfterReconnectTimeout() <-chan chan struct{} { return newClockChan(c.ReconnectTimeout) }
// AfterWaitInterval returns a channel that fires after the wait interval.
func (c *Clock) AfterWaitInterval() <-chan chan struct{} { return newClockChan(c.WaitInterval) }
// HeartbeatTicker returns a Ticker that ticks every heartbeat.
func (c *Clock) HeartbeatTicker() *Ticker {
t := time.NewTicker(c.HeartbeatInterval)
return &Ticker{C: t.C, ticker: t}
}
func (c *clock) Timer(d time.Duration) *Timer {
t := time.NewTimer(d)
return &Timer{C: t.C, timer: t}
}
// mockClock represents a mock clock that only moves forward programmatically.
// It can be preferable to a real-time clock when testing time-based functionality.
type mockClock struct {
mu sync.Mutex
now time.Time // current time
timers clockTimers // tickers & timers
}
// NewMockClock returns an instance of a mock clock.
// The current time of the mock clock on initialization is the Unix epoch.
func NewMockClock() Clock {
return &mockClock{now: time.Unix(0, 0)}
}
// Add moves the current time of the mock clock forward by the duration.
// This should only be called from a single goroutine at a time.
func (m *mockClock) Add(d time.Duration) {
// Calculate the final current time.
t := m.now.Add(d)
// Continue to execute timers until there are no more before the new time.
for {
if !m.runNextTimer(t) {
break
}
}
// Ensure that we end with the new time.
m.mu.Lock()
m.now = t
m.mu.Unlock()
// Give a small buffer to make sure the other goroutines get handled.
gosched()
}
// runNextTimer executes the next timer in chronological order and moves the
// current time to the timer's next tick time. The timer is not executed if
// its next tick time is after the max time. Returns true if a timer is executed.
func (m *mockClock) runNextTimer(max time.Time) bool {
m.mu.Lock()
// Sort timers by time.
sort.Sort(m.timers)
// If we have no more timers then exit.
if len(m.timers) == 0 {
m.mu.Unlock()
return false
}
// Retrieve next timer. Exit if next tick is after new time.
t := m.timers[0]
if t.Next().After(max) {
m.mu.Unlock()
return false
}
// Move "now" forward and unlock clock.
m.now = t.Next()
m.mu.Unlock()
// Execute timer.
t.Tick(m.now)
return true
}
// After waits for the duration to elapse and then sends the current time on the returned channel.
func (m *mockClock) After(d time.Duration) <-chan time.Time {
return m.Timer(d).C
}
// AfterFunc waits for the duration to elapse and then executes a function.
// A Timer is returned that can be stopped.
func (m *mockClock) AfterFunc(d time.Duration, f func()) *Timer {
t := m.Timer(d)
t.C = nil
t.fn = f
return t
}
// Now returns the current wall time on the mock clock.
func (m *mockClock) Now() time.Time {
m.mu.Lock()
defer m.mu.Unlock()
return m.now
}
// Sleep pauses the goroutine for the given duration on the mock clock.
// The clock must be moved forward in a separate goroutine.
func (m *mockClock) Sleep(d time.Duration) {
<-m.After(d)
}
// Tick is a convenience function for Ticker().
// It will return a ticker channel that cannot be stopped.
func (m *mockClock) Tick(d time.Duration) <-chan time.Time {
return m.Ticker(d).C
}
// Ticker creates a new instance of Ticker.
func (m *mockClock) Ticker(d time.Duration) *Ticker {
m.mu.Lock()
defer m.mu.Unlock()
ch := make(chan time.Time, 1)
t := &Ticker{
C: ch,
c: ch,
mock: m,
d: d,
next: m.now.Add(d),
}
m.timers = append(m.timers, (*internalTicker)(t))
return t
}
// Timer creates a new instance of Timer.
func (m *mockClock) Timer(d time.Duration) *Timer {
m.mu.Lock()
defer m.mu.Unlock()
ch := make(chan time.Time, 1)
t := &Timer{
C: ch,
c: ch,
mock: m,
next: m.now.Add(d),
}
m.timers = append(m.timers, (*internalTimer)(t))
return t
}
func (m *mockClock) removeClockTimer(t clockTimer) {
m.mu.Lock()
defer m.mu.Unlock()
for i, timer := range m.timers {
if timer == t {
copy(m.timers[i:], m.timers[i+1:])
m.timers[len(m.timers)-1] = nil
m.timers = m.timers[:len(m.timers)-1]
break
}
}
sort.Sort(m.timers)
}
// clockTimer represents an object with an associated start time.
type clockTimer interface {
Next() time.Time
Tick(time.Time)
}
// clockTimers represents a list of sortable timers.
type clockTimers []clockTimer
func (a clockTimers) Len() int { return len(a) }
func (a clockTimers) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a clockTimers) Less(i, j int) bool { return a[i].Next().Before(a[j].Next()) }
// Timer represents a single event.
// The current time will be sent on C, unless the timer was created by AfterFunc.
type Timer struct {
C <-chan time.Time
c chan time.Time
timer *time.Timer // realtime impl, if set
next time.Time // next tick time
mock *mockClock // mock clock, if set
fn func() // AfterFunc function, if set
}
// Stop turns off the ticker.
func (t *Timer) Stop() {
if t.timer != nil {
t.timer.Stop()
} else {
t.mock.removeClockTimer((*internalTimer)(t))
}
}
type internalTimer Timer
func (t *internalTimer) Next() time.Time { return t.next }
func (t *internalTimer) Tick(now time.Time) {
if t.fn != nil {
t.fn()
} else {
t.c <- now
}
t.mock.removeClockTimer((*internalTimer)(t))
gosched()
}
// Now returns the current wall clock time.
func (c *Clock) Now() time.Time { return time.Now() }
// Ticker holds a channel that receives "ticks" at regular intervals.
type Ticker struct {
C <-chan time.Time
c chan time.Time
ticker *time.Ticker // realtime impl, if set
next time.Time // next tick time
mock *mockClock // mock clock, if set
d time.Duration // time between ticks
ticker *time.Ticker // realtime impl, if set
}
// Stop turns off the ticker.
func (t *Ticker) Stop() {
if t.ticker != nil {
t.ticker.Stop()
} else {
t.mock.removeClockTimer((*internalTicker)(t))
}
}
type internalTicker Ticker
func (t *internalTicker) Next() time.Time { return t.next }
func (t *internalTicker) Tick(now time.Time) {
select {
case t.c <- now:
case <-time.After(goschedTimeout):
}
t.next = now.Add(t.d)
gosched()
}
// goschedTimeout is the amount of wall time to sleep during goroutine scheduling.
var goschedTimeout = 1 * time.Millisecond
// Sleep momentarily so that other goroutines can process.
func gosched() {
time.Sleep(goschedTimeout)
// newClockChan returns a channel that sends a channel after a given duration.
// The channel being sent, over the channel that is returned, can be used to
// notify the sender when an action is done.
func newClockChan(d time.Duration) <-chan chan struct{} {
ch := make(chan chan struct{})
go func() {
time.Sleep(d)
ch <- make(chan struct{})
}()
return ch
}

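The chan chan struct{} shape above is the core of this refactor: the outer channel is the tick, and the inner channel is a completion handshake that the receiver closes when its work is done. The production Clock never waits on that confirmation, but the test clock does, which makes timing deterministic. A self-contained sketch of both sides:

package main

import (
	"fmt"
	"time"
)

// newClockChan follows the helper above: after d elapses, deliver a fresh
// confirmation channel to whoever is listening.
func newClockChan(d time.Duration) <-chan chan struct{} {
	ch := make(chan chan struct{})
	go func() {
		time.Sleep(d)
		ch <- make(chan struct{})
	}()
	return ch
}

func main() {
	// Receiver side, as the Log's loops use it: take the tick, do the work,
	// then close the inner channel to signal completion.
	confirm := <-newClockChan(10 * time.Millisecond)
	fmt.Println("heartbeat work done")
	close(confirm)

	// A sender that cares about completion blocks on the inner channel;
	// here it returns immediately because confirm is already closed.
	<-confirm
}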
raft/clock_test.go (new file)
View File

@ -0,0 +1,80 @@
package raft_test
import (
"flag"
"time"
)
var (
goschedTimeout = flag.Duration("gosched", 100*time.Millisecond, "gosched() delay")
)
// DefaultTime represents the time that the test clock is initialized to.
// Defaults to midnight on Jan 1, 2000 UTC.
var DefaultTime = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
// Clock represents a testable clock.
type Clock struct {
now time.Time
applyChan chan chan struct{}
electionChan chan chan struct{}
heartbeatChan chan chan struct{}
reconnectChan chan chan struct{}
NowFunc func() time.Time
AfterApplyIntervalFunc func() <-chan chan struct{}
AfterElectionTimeoutFunc func() <-chan chan struct{}
AfterHeartbeatIntervalFunc func() <-chan chan struct{}
AfterReconnectTimeoutFunc func() <-chan chan struct{}
}
// NewClock returns an instance of Clock with defaults set.
func NewClock() *Clock {
c := &Clock{
now: DefaultTime,
applyChan: make(chan chan struct{}, 0),
electionChan: make(chan chan struct{}, 0),
heartbeatChan: make(chan chan struct{}, 0),
reconnectChan: make(chan chan struct{}, 0),
}
// Set default functions.
c.NowFunc = func() time.Time { return c.now }
c.AfterApplyIntervalFunc = func() <-chan chan struct{} { return c.applyChan }
c.AfterElectionTimeoutFunc = func() <-chan chan struct{} { return c.electionChan }
c.AfterHeartbeatIntervalFunc = func() <-chan chan struct{} { return c.heartbeatChan }
c.AfterReconnectTimeoutFunc = func() <-chan chan struct{} { return c.reconnectChan }
return c
}
func (c *Clock) apply() {
ch := make(chan struct{}, 0)
c.applyChan <- ch
<-ch
}
func (c *Clock) election() {
ch := make(chan struct{}, 0)
c.electionChan <- ch
<-ch
}
func (c *Clock) heartbeat() {
ch := make(chan struct{}, 0)
c.heartbeatChan <- ch
<-ch
}
func (c *Clock) reconnect() {
ch := make(chan struct{}, 0)
c.reconnectChan <- ch
<-ch
}
func (c *Clock) Now() time.Time { return c.NowFunc() }
func (c *Clock) AfterApplyInterval() <-chan chan struct{} { return c.AfterApplyIntervalFunc() }
func (c *Clock) AfterElectionTimeout() <-chan chan struct{} { return c.AfterElectionTimeoutFunc() }
func (c *Clock) AfterHeartbeatInterval() <-chan chan struct{} { return c.AfterHeartbeatIntervalFunc() }
func (c *Clock) AfterReconnectTimeout() <-chan chan struct{} { return c.AfterReconnectTimeoutFunc() }
func gosched() { time.Sleep(*goschedTimeout) }

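To see why the Func-field clock makes tests deterministic: each helper (apply, election, heartbeat, reconnect) pushes a confirmation channel into the matching loop and blocks until the loop closes it, so a test advances the log exactly one step at a time. A condensed, runnable sketch of the mechanism (one interval only; names hypothetical):

package main

import "fmt"

// mockClock condenses the test Clock above to a single interval.
type mockClock struct{ heartbeatChan chan chan struct{} }

func (c *mockClock) AfterHeartbeatInterval() <-chan chan struct{} { return c.heartbeatChan }

// heartbeat hands a confirmation channel to the consumer loop and blocks
// until the loop closes it.
func (c *mockClock) heartbeat() {
	ch := make(chan struct{})
	c.heartbeatChan <- ch
	<-ch
}

func main() {
	c := &mockClock{heartbeatChan: make(chan chan struct{})}

	// Consumer goroutine, standing in for the Log's leaderLoop.
	go func() {
		confirm := <-c.AfterHeartbeatInterval()
		fmt.Println("sent heartbeats")
		close(confirm)
	}()

	c.heartbeat() // returns only after the loop has run exactly once
	fmt.Println("heartbeat observed")
}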
View File

@ -9,18 +9,19 @@ import (
"strconv"
)
// HTTPHandler represents an HTTP endpoint for Raft to communicate over.
type HTTPHandler struct {
log *Log
}
// NewHTTPHandler returns a new instance of HTTPHandler associated with a log.
func NewHTTPHandler(log *Log) *HTTPHandler {
return &HTTPHandler{log: log}
// Handler represents an HTTP endpoint for Raft to communicate over.
type Handler struct {
Log interface {
AddPeer(u *url.URL) (uint64, *Config, error)
RemovePeer(id uint64) error
Heartbeat(term, commitIndex, leaderID uint64) (currentIndex, currentTerm uint64, err error)
WriteEntriesTo(w io.Writer, id, term, index uint64) error
RequestVote(term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error)
}
}
// ServeHTTP handles all incoming HTTP requests.
func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch path.Base(r.URL.Path) {
case "join":
h.serveJoin(w, r)
@ -40,7 +41,7 @@ func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// serveJoin serves a Raft membership addition to the underlying log.
func (h *HTTPHandler) serveJoin(w http.ResponseWriter, r *http.Request) {
func (h *Handler) serveJoin(w http.ResponseWriter, r *http.Request) {
// TODO(benbjohnson): Redirect to leader.
// Parse argument.
@ -59,7 +60,7 @@ func (h *HTTPHandler) serveJoin(w http.ResponseWriter, r *http.Request) {
}
// Add peer to the log.
id, config, err := h.log.AddPeer(u)
id, config, err := h.Log.AddPeer(u)
if err != nil {
w.Header().Set("X-Raft-Error", err.Error())
w.WriteHeader(http.StatusInternalServerError)
@ -75,7 +76,7 @@ func (h *HTTPHandler) serveJoin(w http.ResponseWriter, r *http.Request) {
}
// serveLeave removes a member from the cluster.
func (h *HTTPHandler) serveLeave(w http.ResponseWriter, r *http.Request) {
func (h *Handler) serveLeave(w http.ResponseWriter, r *http.Request) {
// TODO(benbjohnson): Redirect to leader.
// Parse arguments.
@ -87,7 +88,7 @@ func (h *HTTPHandler) serveLeave(w http.ResponseWriter, r *http.Request) {
}
// Remove a peer from the log.
if err := h.log.RemovePeer(id); err != nil {
if err := h.Log.RemovePeer(id); err != nil {
w.Header().Set("X-Raft-Error", err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
@ -97,7 +98,7 @@ func (h *HTTPHandler) serveLeave(w http.ResponseWriter, r *http.Request) {
}
// serveHeartbeat serves a Raft heartbeat to the underlying log.
func (h *HTTPHandler) serveHeartbeat(w http.ResponseWriter, r *http.Request) {
func (h *Handler) serveHeartbeat(w http.ResponseWriter, r *http.Request) {
var err error
var term, commitIndex, leaderID uint64
@ -119,7 +120,7 @@ func (h *HTTPHandler) serveHeartbeat(w http.ResponseWriter, r *http.Request) {
}
// Execute heartbeat on the log.
currentIndex, currentTerm, err := h.log.Heartbeat(term, commitIndex, leaderID)
currentIndex, currentTerm, err := h.Log.Heartbeat(term, commitIndex, leaderID)
// Return current term and index.
w.Header().Set("X-Raft-Index", strconv.FormatUint(currentIndex, 10))
@ -136,7 +137,7 @@ func (h *HTTPHandler) serveHeartbeat(w http.ResponseWriter, r *http.Request) {
}
// serveStream provides a streaming log endpoint.
func (h *HTTPHandler) serveStream(w http.ResponseWriter, r *http.Request) {
func (h *Handler) serveStream(w http.ResponseWriter, r *http.Request) {
var err error
var id, index, term uint64
@ -166,7 +167,7 @@ func (h *HTTPHandler) serveStream(w http.ResponseWriter, r *http.Request) {
// TODO(benbjohnson): Redirect to leader.
// Write to the response.
if err := h.log.WriteEntriesTo(w, id, term, index); err != nil && err != io.EOF {
if err := h.Log.WriteEntriesTo(w, id, term, index); err != nil && err != io.EOF {
w.Header().Set("X-Raft-Error", err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
@ -174,7 +175,7 @@ func (h *HTTPHandler) serveStream(w http.ResponseWriter, r *http.Request) {
}
// serveRequestVote serves a vote request to the underlying log.
func (h *HTTPHandler) serveRequestVote(w http.ResponseWriter, r *http.Request) {
func (h *Handler) serveRequestVote(w http.ResponseWriter, r *http.Request) {
var err error
var term, candidateID, lastLogIndex, lastLogTerm uint64
@ -201,7 +202,7 @@ func (h *HTTPHandler) serveRequestVote(w http.ResponseWriter, r *http.Request) {
}
// Execute the vote request on the log.
currentTerm, err := h.log.RequestVote(term, candidateID, lastLogIndex, lastLogTerm)
currentTerm, err := h.Log.RequestVote(term, candidateID, lastLogIndex, lastLogTerm)
// Return the current term.
w.Header().Set("X-Raft-Term", strconv.FormatUint(currentTerm, 10))

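Construction changes accordingly: instead of the old NewHTTPHandler(log) constructor, a Handler is now a struct literal whose exported Log field is a narrow interface, so anything implementing those five methods can be served. A hedged wiring sketch (address and mount point are hypothetical):

package main

import (
	"net/http"
	"net/url"

	"github.com/influxdb/influxdb/raft"
)

func main() {
	l := raft.NewLog()
	l.URL = &url.URL{Scheme: "http", Host: "localhost:2380"} // hypothetical address

	// *raft.Log satisfies the Handler.Log interface directly.
	h := &raft.Handler{Log: l}

	// h is an http.Handler, so any mux or server can serve the
	// join/leave/heartbeat/stream/vote endpoints.
	http.Handle("/", h)
	_ = http.ListenAndServe // not started in this sketch
}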
View File

@ -1,67 +1,79 @@
package raft_test
import (
"encoding/binary"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"testing"
"time"
"github.com/influxdb/influxdb/raft"
)
// Ensure a node can join a cluster over HTTP.
func TestHTTPHandler_HandleJoin(t *testing.T) {
n := NewInitNode()
defer n.Close()
func TestHandler_HandleJoin(t *testing.T) {
h := NewHandler()
h.AddPeerFunc = func(u *url.URL) (uint64, *raft.Config, error) {
if u.String() != "http://localhost:1000" {
t.Fatalf("unexpected url: %s", u)
}
return 2, &raft.Config{}, nil
}
s := httptest.NewServer(h)
defer s.Close()
// Send request to join cluster.
go func() { n.Clock().Add(n.Log.ApplyInterval) }()
resp, err := http.Get(n.Server.URL + "/join?url=" + url.QueryEscape("http://localhost:1000"))
resp, err := http.Get(s.URL + "/join?url=" + url.QueryEscape("http://localhost:1000"))
defer resp.Body.Close()
if err != nil {
t.Fatalf("unexpected error: %s", err)
} else if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %d: %s", resp.StatusCode, resp.Header.Get("X-Raft-Error"))
}
if s := resp.Header.Get("X-Raft-Error"); s != "" {
} else if s := resp.Header.Get("X-Raft-Error"); s != "" {
t.Fatalf("unexpected raft error: %s", s)
}
if s := resp.Header.Get("X-Raft-ID"); s != "2" {
} else if s = resp.Header.Get("X-Raft-ID"); s != "2" {
t.Fatalf("unexpected raft id: %s", s)
}
}
// Ensure a heartbeat can be sent over HTTP.
func TestHTTPHandler_HandleHeartbeat(t *testing.T) {
t.Skip()
n := NewInitNode()
defer n.Close()
func TestHandler_HandleHeartbeat(t *testing.T) {
h := NewHandler()
h.HeartbeatFunc = func(term, commitIndex, leaderID uint64) (currentIndex, currentTerm uint64, err error) {
if term != 1 {
t.Fatalf("unexpected term: %d", term)
} else if commitIndex != 2 {
t.Fatalf("unexpected commit index: %d", commitIndex)
} else if leaderID != 3 {
t.Fatalf("unexpected leader id: %d", leaderID)
}
return 4, 5, nil
}
s := httptest.NewServer(h)
defer s.Close()
// Send heartbeat.
resp, err := http.Get(n.Server.URL + "/heartbeat?term=1&commitIndex=0&leaderID=1")
resp, err := http.Get(s.URL + "/heartbeat?term=1&commitIndex=2&leaderID=3")
defer resp.Body.Close()
if err != nil {
t.Fatalf("unexpected error: %s", err)
} else if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %d", resp.StatusCode)
}
if s := resp.Header.Get("X-Raft-Error"); s != "" {
} else if s := resp.Header.Get("X-Raft-Error"); s != "" {
t.Fatalf("unexpected raft error: %s", s)
}
if s := resp.Header.Get("X-Raft-Index"); s != "1" {
} else if s = resp.Header.Get("X-Raft-Index"); s != "4" {
t.Fatalf("unexpected raft index: %s", s)
}
if s := resp.Header.Get("X-Raft-Term"); s != "1" {
} else if s = resp.Header.Get("X-Raft-Term"); s != "5" {
t.Fatalf("unexpected raft term: %s", s)
}
}
// Ensure that sending a heartbeat with invalid parameters returns an error.
func TestHTTPHandler_HandleHeartbeat_Error(t *testing.T) {
// TODO corylanou: racy failing test. Stack trace here: https://gist.github.com/corylanou/5864e2058656fd6e542f
t.Skip()
func TestHandler_HandleHeartbeat_Error(t *testing.T) {
h := NewHandler()
s := httptest.NewServer(h)
defer s.Close()
var tests = []struct {
query string
@ -72,55 +84,59 @@ func TestHTTPHandler_HandleHeartbeat_Error(t *testing.T) {
{query: `term=1&commitIndex=0&leaderID=XXX`, err: `invalid leader id`},
}
for i, tt := range tests {
func() {
n := NewInitNode()
defer n.Close()
// Send heartbeat.
resp, err := http.Get(n.Server.URL + "/heartbeat?" + tt.query)
defer resp.Body.Close()
if err != nil {
t.Fatalf("%d. unexpected error: %s", i, err)
} else if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("%d. unexpected status: %d", i, resp.StatusCode)
}
if s := resp.Header.Get("X-Raft-Error"); s != tt.err {
t.Fatalf("%d. unexpected raft error: %s", i, s)
}
}()
resp, err := http.Get(s.URL + "/heartbeat?" + tt.query)
resp.Body.Close()
if err != nil {
t.Errorf("%d. unexpected error: %s", i, err)
} else if resp.StatusCode != http.StatusBadRequest {
t.Errorf("%d. unexpected status: %d", i, resp.StatusCode)
} else if s := resp.Header.Get("X-Raft-Error"); s != tt.err {
t.Errorf("%d. unexpected raft error: %s", i, s)
}
}
}
// Ensure that sending a heartbeat to a closed log returns an error.
func TestHTTPHandler_HandleHeartbeat_ErrClosed(t *testing.T) {
// TODO corylanou: racy failing test. Stack trace here:https://gist.github.com/corylanou/02ea4cc47a479df39706
t.Skip()
n := NewInitNode()
n.Log.Close()
defer n.Close()
func TestHandler_HandleHeartbeat_ErrClosed(t *testing.T) {
h := NewHandler()
h.HeartbeatFunc = func(term, commitIndex, leaderID uint64) (currentIndex, currentTerm uint64, err error) {
return 0, 0, raft.ErrClosed
}
s := httptest.NewServer(h)
defer s.Close()
// Send heartbeat.
resp, err := http.Get(n.Server.URL + "/heartbeat?term=1&commitIndex=0&leaderID=1")
resp, err := http.Get(s.URL + "/heartbeat?term=0&commitIndex=0&leaderID=0")
defer resp.Body.Close()
if err != nil {
t.Fatalf("unexpected error: %s", err)
} else if resp.StatusCode != http.StatusInternalServerError {
t.Fatalf("unexpected status: %d", resp.StatusCode)
}
if s := resp.Header.Get("X-Raft-Error"); s != "log closed" {
} else if s := resp.Header.Get("X-Raft-Error"); s != "log closed" {
t.Fatalf("unexpected raft error: %s", s)
}
}
// Ensure a stream can be retrieved over HTTP.
func TestHTTPHandler_HandleStream(t *testing.T) {
// TODO corylanou: racy failing test. Stack trace here: https://gist.github.com/corylanou/fc4e97afd31f793af426
t.Skip()
n := NewInitNode()
defer n.Close()
func TestHandler_HandleStream(t *testing.T) {
h := NewHandler()
h.WriteEntriesToFunc = func(w io.Writer, id, term, index uint64) error {
if w == nil {
t.Fatalf("expected writer")
} else if id != 1 {
t.Fatalf("unexpected id: %d", id)
} else if term != 2 {
t.Fatalf("unexpected term: %d", term)
}
w.Write([]byte("ok"))
return nil
}
s := httptest.NewServer(h)
defer s.Close()
// Connect to stream.
resp, err := http.Get(n.Server.URL + "/stream?id=1&term=1")
resp, err := http.Get(s.URL + "/stream?id=1&term=2")
if err != nil {
t.Fatalf("unexpected error: %s", err)
} else if resp.StatusCode != http.StatusOK {
@ -128,62 +144,24 @@ func TestHTTPHandler_HandleStream(t *testing.T) {
}
defer resp.Body.Close()
// Ensure the stream is connected before applying a command.
time.Sleep(10 * time.Millisecond)
// Add an entry.
if _, err := n.Log.Apply([]byte("xyz")); err != nil {
t.Fatal(err)
}
// Move log's clock ahead & flush data.
n.Log.Clock.Add(n.Log.HeartbeatInterval)
n.Log.Flush()
// Read entries from stream.
var e raft.LogEntry
dec := raft.NewLogEntryDecoder(resp.Body)
// First entry should be the configuration.
if err := dec.Decode(&e); err != nil {
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("unexpected error: %s", err)
} else if e.Type != 0xFE {
t.Fatalf("expected configuration type: %d", e.Type)
}
// Next entry should be the snapshot.
if err := dec.Decode(&e); err != nil {
t.Fatalf("unexpected error: %s", err)
} else if !reflect.DeepEqual(&e, &raft.LogEntry{Type: 0xFF, Data: nil}) {
t.Fatalf("expected snapshot type: %d", e.Type)
}
// Read off the snapshot.
var fsm FSM
if err := fsm.Restore(resp.Body); err != nil {
t.Fatalf("restore: %s", err)
}
// Read off the snapshot index.
var index uint64
if err := binary.Read(resp.Body, binary.BigEndian, &index); err != nil {
t.Fatalf("read snapshot index: %s", err)
} else if index != 1 {
t.Fatalf("unexpected snapshot index: %d", index)
}
// Next entry should be the command.
if err := dec.Decode(&e); err != nil {
t.Fatalf("unexpected error: %s", err)
} else if !reflect.DeepEqual(&e, &raft.LogEntry{Index: 2, Term: 1, Data: []byte("xyz")}) {
t.Fatalf("unexpected entry: %#v", &e)
} else if string(b) != "ok" {
t.Fatalf("unexpected body: %s", b)
}
}
// Ensure that requesting a stream with invalid parameters returns an error.
func TestHTTPHandler_HandleStream_Error(t *testing.T) {
// TODO corylanou: raft racy test. gist: https://gist.github.com/corylanou/aa4e75c4d873ea48fc90
t.Skip()
func TestHandler_HandleStream_Error(t *testing.T) {
h := NewHandler()
h.WriteEntriesToFunc = func(w io.Writer, id, term, index uint64) error {
return raft.ErrNotLeader
}
s := httptest.NewServer(h)
defer s.Close()
var tests = []struct {
query string
code int
@ -192,51 +170,62 @@ func TestHTTPHandler_HandleStream_Error(t *testing.T) {
{query: `id=1&term=XXX&index=0`, code: http.StatusBadRequest, err: `invalid term`},
{query: `id=1&term=1&index=XXX`, code: http.StatusBadRequest, err: `invalid index`},
{query: `id=XXX&term=1&index=XXX`, code: http.StatusBadRequest, err: `invalid id`},
{query: `id=1&term=2&index=0`, code: http.StatusInternalServerError, err: `not leader`},
{query: `id=0&term=1&index=2`, code: http.StatusInternalServerError, err: `not leader`},
}
for i, tt := range tests {
func() {
n := NewInitNode()
defer n.Close()
// Connect to stream.
resp, err := http.Get(n.Server.URL + "/stream?" + tt.query)
defer resp.Body.Close()
if err != nil {
t.Fatalf("%d. unexpected error: %s", i, err)
} else if resp.StatusCode != tt.code {
t.Fatalf("%d. unexpected status: %d", i, resp.StatusCode)
}
if s := resp.Header.Get("X-Raft-Error"); s != tt.err {
t.Fatalf("%d. unexpected raft error: %s", i, s)
}
}()
resp, err := http.Get(s.URL + "/stream?" + tt.query)
resp.Body.Close()
if err != nil {
t.Fatalf("%d. unexpected error: %s", i, err)
} else if resp.StatusCode != tt.code {
t.Fatalf("%d. unexpected status: %d", i, resp.StatusCode)
} else if s := resp.Header.Get("X-Raft-Error"); s != tt.err {
t.Fatalf("%d. unexpected raft error: %s", i, s)
}
}
}
// Ensure a vote request can be sent over HTTP.
func TestHTTPHandler_HandleRequestVote(t *testing.T) {
n := NewInitNode()
defer n.Close()
func TestHandler_HandleRequestVote(t *testing.T) {
h := NewHandler()
h.RequestVoteFunc = func(term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error) {
if term != 1 {
t.Fatalf("unexpected term: %d", term)
} else if candidateID != 2 {
t.Fatalf("unexpected candidate id: %d", candidateID)
} else if lastLogIndex != 3 {
t.Fatalf("unexpected last log index: %d", lastLogIndex)
} else if lastLogTerm != 4 {
t.Fatalf("unexpected last log term: %d", lastLogTerm)
}
return 5, nil
}
s := httptest.NewServer(h)
defer s.Close()
// Send vote request.
resp, err := http.Get(n.Server.URL + "/vote?term=5&candidateID=2&lastLogIndex=3&lastLogTerm=4")
resp, err := http.Get(s.URL + "/vote?term=1&candidateID=2&lastLogIndex=3&lastLogTerm=4")
defer resp.Body.Close()
if err != nil {
t.Fatalf("unexpected error: %s", err)
} else if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %d", resp.StatusCode)
}
if s := resp.Header.Get("X-Raft-Error"); s != "" {
} else if s := resp.Header.Get("X-Raft-Error"); s != "" {
t.Fatalf("unexpected raft error: %s", s)
}
if s := resp.Header.Get("X-Raft-Term"); s != "1" {
} else if s = resp.Header.Get("X-Raft-Term"); s != "5" {
t.Fatalf("unexpected raft term: %s", s)
}
}
// Ensure sending invalid parameters in a vote request returns an error.
func TestHTTPHandler_HandleRequestVote_Error(t *testing.T) {
func TestHandler_HandleRequestVote_Error(t *testing.T) {
h := NewHandler()
h.RequestVoteFunc = func(term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error) {
return 0, raft.ErrStaleTerm
}
s := httptest.NewServer(h)
defer s.Close()
var tests = []struct {
query string
code int
@ -249,32 +238,25 @@ func TestHTTPHandler_HandleRequestVote_Error(t *testing.T) {
{query: `term=0&candidateID=2&lastLogIndex=0&lastLogTerm=0`, code: http.StatusInternalServerError, err: `stale term`},
}
for i, tt := range tests {
func() {
n := NewInitNode()
defer n.Close()
// Send vote request.
resp, err := http.Get(n.Server.URL + "/vote?" + tt.query)
defer resp.Body.Close()
if err != nil {
t.Fatalf("%d. unexpected error: %s", i, err)
} else if resp.StatusCode != tt.code {
t.Fatalf("%d. unexpected status: %d", i, resp.StatusCode)
}
if s := resp.Header.Get("X-Raft-Error"); s != tt.err {
t.Fatalf("%d. unexpected raft error: %s", i, s)
}
}()
resp, err := http.Get(s.URL + "/vote?" + tt.query)
defer resp.Body.Close()
if err != nil {
t.Fatalf("%d. unexpected error: %s", i, err)
} else if resp.StatusCode != tt.code {
t.Fatalf("%d. unexpected status: %d", i, resp.StatusCode)
} else if s := resp.Header.Get("X-Raft-Error"); s != tt.err {
t.Fatalf("%d. unexpected raft error: %s", i, s)
}
}
}
// Ensure an invalid path returns a 404.
func TestHTTPHandler_NotFound(t *testing.T) {
n := NewInitNode()
defer n.Close()
func TestHandler_NotFound(t *testing.T) {
s := httptest.NewServer(NewHandler())
defer s.Close()
// Send a request to an unknown path.
resp, err := http.Get(n.Server.URL + "/aaaaahhhhh")
resp, err := http.Get(s.URL + "/aaaaahhhhh")
defer resp.Body.Close()
if err != nil {
t.Fatalf("unexpected error: %s", err)
@ -282,3 +264,35 @@ func TestHTTPHandler_NotFound(t *testing.T) {
t.Fatalf("unexpected status: %d", resp.StatusCode)
}
}
// Handler represents a test wrapper for the raft.Handler.
type Handler struct {
*raft.Handler
AddPeerFunc func(u *url.URL) (uint64, *raft.Config, error)
RemovePeerFunc func(id uint64) error
HeartbeatFunc func(term, commitIndex, leaderID uint64) (currentIndex, currentTerm uint64, err error)
WriteEntriesToFunc func(w io.Writer, id, term, index uint64) error
RequestVoteFunc func(term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error)
}
// NewHandler returns a new instance of Handler.
func NewHandler() *Handler {
h := &Handler{Handler: &raft.Handler{}}
h.Handler.Log = h
return h
}
func (h *Handler) AddPeer(u *url.URL) (uint64, *raft.Config, error) { return h.AddPeerFunc(u) }
func (h *Handler) RemovePeer(id uint64) error { return h.RemovePeerFunc(id) }
func (h *Handler) Heartbeat(term, commitIndex, leaderID uint64) (currentIndex, currentTerm uint64, err error) {
return h.HeartbeatFunc(term, commitIndex, leaderID)
}
func (h *Handler) WriteEntriesTo(w io.Writer, id, term, index uint64) error {
return h.WriteEntriesToFunc(w, id, term, index)
}
func (h *Handler) RequestVote(term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error) {
return h.RequestVoteFunc(term, candidateID, lastLogIndex, lastLogTerm)
}

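The wrapper above is a common Go test-double idiom: embed the real type, point its dependency back at the wrapper, and let per-test Func fields supply behavior. A minimal self-contained sketch of the same shape (all names hypothetical):

package main

import "fmt"

// Greeter is the narrow dependency interface, mirroring raft.Handler.Log.
type Greeter interface{ Greet(name string) string }

// Consumer stands in for raft.Handler.
type Consumer struct{ Greeter Greeter }

func (c *Consumer) Run() { fmt.Println(c.Greeter.Greet("world")) }

// TestConsumer embeds the consumer and satisfies its dependency itself.
type TestConsumer struct {
	*Consumer
	GreetFunc func(name string) string
}

func (t *TestConsumer) Greet(name string) string { return t.GreetFunc(name) }

func NewTestConsumer() *TestConsumer {
	t := &TestConsumer{Consumer: &Consumer{}}
	t.Consumer.Greeter = t // wire the dependency back to the wrapper
	return t
}

func main() {
	t := NewTestConsumer()
	t.GreetFunc = func(name string) string { return "hello, " + name }
	t.Run() // prints "hello, world"
}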
raft/internal_test.go (new file)
View File

@ -0,0 +1,5 @@
package raft
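// This file is compiled only when testing package raft, so these thin
// exports give the external raft_test package access to the unexported
// wait helpers without widening the public API.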
func (l *Log) WaitUncommitted(index uint64) error { return l.waitUncommitted(index) }
func (l *Log) WaitCommitted(index uint64) error { return l.waitCommitted(index) }
func (l *Log) WaitApplied(index uint64) error { return l.Wait(index) }

View File

@ -22,20 +22,6 @@ import (
"time"
)
const (
// DefaultHeartbeatInterval is the default time to wait between heartbeats.
DefaultHeartbeatInterval = 150 * time.Millisecond
// DefaultElectionTimeout is the default time before starting an election.
DefaultElectionTimeout = 500 * time.Millisecond
// DefaultReconnectTimeout is the default time to wait before reconnecting.
DefaultReconnectTimeout = 10 * time.Millisecond
// DefaultApplyInterval is the default time between checks to apply commands.
DefaultApplyInterval = 10 * time.Millisecond
)
// FSM represents the state machine that the log is applied to.
// The FSM must maintain the highest index that it has seen.
type FSM interface {
@ -115,47 +101,42 @@ type Log struct {
FSM FSM
// The transport used to communicate with other nodes in the cluster.
// If nil, then the DefaultTransport is used.
Transport Transport
Transport interface {
Join(u *url.URL, nodeURL *url.URL) (uint64, *Config, error)
Leave(u *url.URL, id uint64) error
Heartbeat(u *url.URL, term, commitIndex, leaderID uint64) (lastIndex, currentTerm uint64, err error)
ReadFrom(u *url.URL, id, term, index uint64) (io.ReadCloser, error)
RequestVote(u *url.URL, term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error)
}
// The amount of time between Append Entries RPC calls from the leader to
// its followers.
HeartbeatInterval time.Duration
// The amount of time before a follower attempts an election.
ElectionTimeout time.Duration
// The amount of time between stream reconnection attempts.
ReconnectTimeout time.Duration
// The amount of time that the log will wait between applying outstanding
// committed log entries. A lower interval will reduce latency but a higher
// interval will batch more commands together and improve throughput.
ApplyInterval time.Duration
// Clock is an abstraction of the time package. By default it will use
// a real-time clock but a mock clock can be used for testing.
Clock Clock
// Clock is an abstraction of time.
Clock interface {
Now() time.Time
AfterApplyInterval() <-chan chan struct{}
AfterElectionTimeout() <-chan chan struct{}
AfterHeartbeatInterval() <-chan chan struct{}
AfterReconnectTimeout() <-chan chan struct{}
}
// Rand returns a random number.
Rand func() int64
// Sets whether trace messages are logged.
DebugEnabled bool
// This logs some asynchronous errors that occur within the log.
Logger *log.Logger
}
// NewLog creates a new instance of Log with reasonable defaults.
func NewLog() *Log {
return &Log{
Clock: &clock{},
Rand: rand.Int63,
Logger: log.New(os.Stderr, "[raft] ", log.LstdFlags),
HeartbeatInterval: DefaultHeartbeatInterval,
ElectionTimeout: DefaultElectionTimeout,
ReconnectTimeout: DefaultReconnectTimeout,
ApplyInterval: DefaultApplyInterval,
l := &Log{
Clock: NewClock(),
Transport: &HTTPTransport{},
Rand: rand.Int63,
}
l.SetLogOutput(os.Stderr)
return l
}
// Path returns the data path of the Raft log.
@ -206,11 +187,6 @@ func (l *Log) Config() *Config {
return nil
}
// SetLogOutput sets writer for all Raft output.
func (l *Log) SetLogOutput(w io.Writer) {
l.Logger = log.New(w, "[raft] ", log.LstdFlags)
}
// Open initializes the log from a path.
// If the path does not exist then it is created.
func (l *Log) Open(path string) error {
@ -234,7 +210,7 @@ func (l *Log) Open(path string) error {
_ = l.close()
return err
}
l.id = id
l.setID(id)
// Initialize log term.
term, err := l.readTerm()
@ -257,6 +233,7 @@ func (l *Log) Open(path string) error {
if err != nil {
return err
}
l.tracef("Open: fsm: index=%d", index)
l.index = index
l.appliedIndex = index
l.commitIndex = index
@ -322,8 +299,10 @@ func (l *Log) close() error {
}
l.writers = nil
l.tracef("close")
// Clear log info.
l.id = 0
l.setID(0)
l.path = ""
l.index, l.term = 0, 0
l.config = nil
@ -331,6 +310,11 @@ func (l *Log) close() error {
return nil
}
func (l *Log) setID(id uint64) {
l.id = id
l.updateLogPrefix()
}
// readID reads the log identifier from file.
func (l *Log) readID() (uint64, error) {
// Read identifier from disk.
@ -420,14 +404,6 @@ func (l *Log) writeConfig(config *Config) error {
return nil
}
// transport returns the log's transport or the default transport.
func (l *Log) transport() Transport {
if t := l.Transport; t != nil {
return t
}
return DefaultTransport
}
// Initialize a new log.
// Returns an error if log data already exists.
func (l *Log) Initialize() error {
@ -457,7 +433,7 @@ func (l *Log) Initialize() error {
if err := l.writeID(id); err != nil {
return err
}
l.id = id
l.setID(id)
// Automatically promote to leader.
term := uint64(1)
@ -488,6 +464,34 @@ func (l *Log) Initialize() error {
return l.Wait(index)
}
// SetLogOutput sets writer for all Raft output.
func (l *Log) SetLogOutput(w io.Writer) {
l.Logger = log.New(w, "", log.LstdFlags)
l.updateLogPrefix()
}
func (l *Log) updateLogPrefix() {
var host string
if l.URL != nil {
host = l.URL.Host
}
l.Logger.SetPrefix(fmt.Sprintf("[raft] %s ", host))
}
// trace writes a log message if DebugEnabled is true.
func (l *Log) trace(v ...interface{}) {
if l.DebugEnabled {
l.Logger.Print(v...)
}
}
// tracef writes a formatted log message if DebugEnabled is true.
func (l *Log) tracef(msg string, v ...interface{}) {
if l.DebugEnabled {
l.Logger.Printf(msg+"\n", v...)
}
}
// Leader returns the id and URL associated with the current leader.
// Returns zero if there is no current leader.
func (l *Log) Leader() (id uint64, u *url.URL) {
@ -514,29 +518,45 @@ func (l *Log) leader() (id uint64, u *url.URL) {
// Join contacts a node in the cluster to request membership.
// A log cannot join a cluster if it has already been initialized.
func (l *Log) Join(u *url.URL) error {
l.mu.Lock()
defer l.mu.Unlock()
// Validate under lock.
var nodeURL *url.URL
if err := func() error {
l.mu.Lock()
defer l.mu.Unlock()
// Check if open.
if !l.opened() {
return ErrClosed
} else if l.id != 0 {
return ErrInitialized
} else if l.URL == nil {
return ErrURLRequired
if !l.opened() {
return ErrClosed
} else if l.id != 0 {
return ErrInitialized
} else if l.URL == nil {
return ErrURLRequired
}
nodeURL = l.URL
return nil
}(); err != nil {
return err
}
l.tracef("Join: %s", u)
// Send join request.
id, config, err := l.transport().Join(u, l.URL)
id, config, err := l.Transport.Join(u, nodeURL)
if err != nil {
return err
}
l.tracef("Join: confirmed")
// Re-acquire the lock once the join request has returned.
l.mu.Lock()
defer l.mu.Unlock()
// Write identifier.
if err := l.writeID(id); err != nil {
return err
}
l.id = id
l.setID(id)
// Write config.
if err := l.writeConfig(config); err != nil {
@ -597,6 +617,7 @@ func (l *Log) setState(state State) {
// followerLoop continually attempts to stream the log from the current leader.
func (l *Log) followerLoop(done chan struct{}) {
l.tracef("followerLoop")
var rch chan struct{}
for {
// Retrieve the term, last index, & leader URL.
@ -611,12 +632,14 @@ func (l *Log) followerLoop(done chan struct{}) {
// If no leader exists then wait momentarily and retry.
if u == nil {
l.tracef("followerLoop: no leader")
time.Sleep(1 * time.Millisecond)
continue
}
// Connect to leader.
r, err := l.transport().ReadFrom(u, id, term, index)
l.tracef("followerLoop: read from: %s, id=%d, term=%d, index=%d", u.String(), id, term, index)
r, err := l.Transport.ReadFrom(u, id, term, index)
if err != nil {
l.Logger.Printf("connect stream: %s", err)
}
@ -663,7 +686,7 @@ func (l *Log) elect(done chan struct{}) {
for _, n := range config.Nodes {
if n.ID != id {
go func(n *ConfigNode) {
peerTerm, err := l.transport().RequestVote(n.URL, term, id, lastLogIndex, lastLogTerm)
peerTerm, err := l.Transport.RequestVote(n.URL, term, id, lastLogIndex, lastLogTerm)
if err != nil {
l.Logger.Printf("request vote: %s", err)
return
@ -677,14 +700,15 @@ func (l *Log) elect(done chan struct{}) {
}
// Wait for responses or timeout.
after := l.Clock.After(l.ElectionTimeout)
after := l.Clock.AfterElectionTimeout()
voteN := 1
loop:
for {
select {
case <-done:
return
case <-after:
case ch := <-after:
defer close(ch)
break loop
case <-ch:
voteN++
@ -713,24 +737,30 @@ loop:
// leaderLoop periodically sends heartbeats to all followers to maintain dominance.
func (l *Log) leaderLoop(done chan struct{}) {
ticker := l.Clock.Ticker(l.HeartbeatInterval)
defer ticker.Stop()
l.tracef("leaderLoop: start")
confirm := make(chan struct{}, 0)
for {
// Send heartbeat to followers.
if err := l.sendHeartbeat(done); err != nil {
close(confirm)
return
}
// Signal clock that the heartbeat has occurred.
close(confirm)
select {
case <-done: // wait for state change.
return
case <-ticker.C: // wait for next heartbeat
case confirm = <-l.Clock.AfterHeartbeatInterval(): // wait for next heartbeat
}
}
}
// sendHeartbeat sends heartbeats to all the nodes.
func (l *Log) sendHeartbeat(done chan struct{}) error {
l.tracef("sendHeartbeat")
// Retrieve config and term.
l.mu.Lock()
if err := check(done); err != nil {
@ -744,6 +774,7 @@ func (l *Log) sendHeartbeat(done chan struct{}) error {
// Ignore if there is no config or nodes yet.
if config == nil || len(config.Nodes) <= 1 {
l.tracef("sendHeartbeat: no peers")
return nil
}
@ -755,12 +786,14 @@ func (l *Log) sendHeartbeat(done chan struct{}) error {
for _, n := range config.Nodes {
if n.ID != l.id {
go func(n *ConfigNode) {
peerIndex, peerTerm, err := l.transport().Heartbeat(n.URL, term, commitIndex, leaderID)
l.tracef("sendHeartbeat: url=%s, term=%d, commit=%d, leaderID=%d", n.URL, term, commitIndex, leaderID)
peerIndex, peerTerm, err := l.Transport.Heartbeat(n.URL, term, commitIndex, leaderID)
if err != nil {
l.Logger.Printf("heartbeat: %s", err)
return
} else if peerTerm > term {
// TODO(benbjohnson): Step down.
l.tracef("sendHeartbeat: TODO step down: peer=%d, term=%d", peerTerm, term)
return
}
ch <- peerIndex
@ -769,7 +802,7 @@ func (l *Log) sendHeartbeat(done chan struct{}) error {
}
// Wait for heartbeat responses or timeout.
after := l.Clock.After(l.HeartbeatInterval)
after := l.Clock.AfterHeartbeatInterval()
indexes := make([]uint64, 1, nodeN)
indexes[0] = localIndex
loop:
@ -777,11 +810,14 @@ loop:
select {
case <-done:
return errDone
case <-after:
case ch := <-after:
defer close(ch)
l.tracef("sendHeartbeat: timeout")
break loop
case index := <-ch:
indexes = append(indexes, index)
if len(indexes) == nodeN {
l.tracef("sendHeartbeat: received heartbeats")
break loop
}
}
@ -791,6 +827,7 @@ loop:
// We don't add the +1 because the slice starts from 0.
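// For example, with nodeN=3 quorumIndex is 1: once the collected indexes
// are ordered, the element at position 1 is the highest index known to be
// replicated on a majority (2 of 3) of the nodes.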
quorumIndex := (nodeN / 2)
if quorumIndex >= len(indexes) {
l.tracef("sendHeartbeat: no quorum: n=%d", quorumIndex)
return nil
}
@ -805,6 +842,7 @@ loop:
return err
}
if newCommitIndex > l.commitIndex {
l.tracef("sending heartbeat: commit index %d => %d", l.commitIndex, newCommitIndex)
l.commitIndex = newCommitIndex
}
l.mu.Unlock()
@ -875,12 +913,45 @@ func (l *Log) Wait(index uint64) error {
} else if appliedIndex >= index {
return nil
}
l.Clock.Sleep(l.ApplyInterval)
time.Sleep(1 * time.Millisecond)
}
}
// waitCommitted blocks until a given committed index is reached.
func (l *Log) waitCommitted(index uint64) error {
for {
l.mu.Lock()
state, committedIndex := l.state, l.commitIndex
l.mu.Unlock()
if state == Stopped {
return ErrClosed
} else if committedIndex >= index {
return nil
}
time.Sleep(1 * time.Millisecond)
}
}
// waitUncommitted blocks until a given uncommitted index is reached.
func (l *Log) waitUncommitted(index uint64) error {
for {
l.mu.Lock()
state, uncommittedIndex := l.state, l.index
l.mu.Unlock()
if state == Stopped {
return ErrClosed
} else if uncommittedIndex >= index {
return nil
}
time.Sleep(1 * time.Millisecond)
}
}
// append adds a log entry to the list of entries.
func (l *Log) append(e *LogEntry) {
l.tracef("append: idx=%d, prev=%d", e.Index, l.index)
assert(e.Index == l.index+1, "non-contiguous log index(%d): idx=%d, prev=%d", l.id, e.Index, l.index)
// Encode entry to a byte slice.
@ -913,14 +984,17 @@ func (l *Log) append(e *LogEntry) {
func (l *Log) applier(done chan chan struct{}) {
for {
// Wait for a close signal or timeout.
var confirm chan struct{}
select {
case ch := <-done:
close(ch)
return
case <-l.Clock.After(l.ApplyInterval):
case confirm = <-l.Clock.AfterApplyInterval():
}
l.tracef("applier")
// Apply all entries committed since the previous apply.
err := func() error {
l.mu.Lock()
@ -937,8 +1011,10 @@ func (l *Log) applier(done chan chan struct{}) {
// Ignore if there are no pending entries.
// Ignore if all entries are applied.
if len(l.entries) == 0 {
l.tracef("applier: no entries")
return nil
} else if l.appliedIndex == l.commitIndex {
l.tracef("applier: up to date")
return nil
}
@ -963,6 +1039,8 @@ func (l *Log) applier(done chan chan struct{}) {
// Iterate over each entry and apply it.
for _, e := range entries {
l.tracef("applier: entry: idx=%d", e.Index)
switch e.Type {
case LogEntryCommand, LogEntryNop:
case LogEntryInitialize:
@ -990,12 +1068,15 @@ func (l *Log) applier(done chan chan struct{}) {
// If error occurred then log it.
// The log will retry after a given timeout.
if err == errDone {
close(confirm)
return
} else if err != nil {
l.Logger.Printf("apply: %s", err)
l.Logger.Printf("apply error: %s", err)
// TODO(benbjohnson): Longer timeout before retry?
continue
}
// Signal clock that apply is done.
close(confirm)
}
}
@ -1105,18 +1186,23 @@ func (l *Log) Heartbeat(term, commitIndex, leaderID uint64) (currentIndex, curre
l.mu.Lock()
defer l.mu.Unlock()
l.tracef("Heartbeat: term=%d, commit=%d, leaderID: %d", term, commitIndex, leaderID)
// Check if log is closed.
if !l.opened() {
l.tracef("Heartbeat: closed")
return 0, 0, ErrClosed
}
// Ignore if the incoming term is less than the log's term.
if term < l.term {
l.tracef("Heartbeat: stale term, ignore")
return l.index, l.term, nil
}
// Step down if we see a higher term.
if term > l.term {
l.tracef("Heartbeat: higher term, stepping down")
l.term = term
l.setState(Follower)
}
@ -1164,12 +1250,14 @@ func (l *Log) RequestVote(term, candidateID, lastLogIndex, lastLogTerm uint64) (
func (l *Log) elector(done chan chan struct{}) {
for {
// Wait for a close signal or election timeout.
var confirm chan struct{}
select {
case ch := <-done:
close(ch)
return
case <-l.Clock.After(l.ElectionTimeout): // TODO(election): Randomize
case confirm = <-l.Clock.AfterElectionTimeout(): // TODO(election): Randomize
}
l.tracef("elector")
// If log is a follower or candidate and an election timeout has passed
// since a contact from a heartbeat then start a new election.
@ -1177,7 +1265,6 @@ func (l *Log) elector(done chan chan struct{}) {
l.mu.Lock()
defer l.mu.Unlock()
// Verify that we're not closing.
// Verify, under lock, that we're not closing.
select {
case ch := <-done:
@ -1189,11 +1276,18 @@ func (l *Log) elector(done chan chan struct{}) {
// Ignore if not a follower or a candidate.
// Ignore if the last contact was less than the election timeout.
if l.state != Follower && l.state != Candidate {
l.tracef("elector: log is not follower or candidate")
return nil
} else if l.lastContact.IsZero() || l.Clock.Now().Sub(l.lastContact) < l.ElectionTimeout {
} else if l.lastContact.IsZero() {
l.tracef("elector: last contact is zero")
return nil
} else if l.Clock.Now().Sub(l.lastContact) < DefaultElectionTimeout { // TODO: Refactor into follower loop and candidate loop.
l.tracef("elector: last contact is less than election timeout")
return nil
}
l.tracef("elector: beginning election in term %d", l.term+1)
// Otherwise start a new election and promote.
term := l.term + 1
if err := l.writeTerm(term); err != nil {
@ -1207,10 +1301,14 @@ func (l *Log) elector(done chan chan struct{}) {
// Check if we exited because we're closing.
if err == errDone {
close(confirm)
return
} else if err != nil {
panic("unreachable")
}
// Signal clock that elector is done.
close(confirm)
}
}
@ -1391,6 +1489,8 @@ func (l *Log) ReadFrom(r io.ReadCloser) error {
// Continually decode entries.
dec := NewLogEntryDecoder(r)
for {
l.tracef("ReadFrom")
// Decode single entry.
var e LogEntry
if err := dec.Decode(&e); err == io.EOF {
@ -1401,6 +1501,8 @@ func (l *Log) ReadFrom(r io.ReadCloser) error {
// If this is a config entry then update the config.
if e.Type == logEntryConfig {
l.tracef("ReadFrom: config")
config := &Config{}
if err := NewConfigDecoder(bytes.NewReader(e.Data)).Decode(config); err != nil {
return err
@ -1418,15 +1520,19 @@ func (l *Log) ReadFrom(r io.ReadCloser) error {
// If this is a snapshot then load it.
if e.Type == logEntrySnapshot {
l.tracef("ReadFrom: snapshot")
if err := l.FSM.Restore(r); err != nil {
return err
}
l.tracef("ReadFrom: snapshot: restored")
// Read the snapshot index off the end of the snapshot.
var index uint64
if err := binary.Read(r, binary.BigEndian, &index); err != nil {
return fmt.Errorf("read snapshot index: %s", err)
}
l.tracef("ReadFrom: snapshot: index=%d", index)
// Update the indices.
l.index = index
@ -1440,7 +1546,14 @@ func (l *Log) ReadFrom(r io.ReadCloser) error {
}
// Append entry to the log.
l.mu.Lock()
if l.state == Stopped {
l.mu.Unlock()
return nil
}
l.tracef("ReadFrom: entry: index=%d / prev=%d / commit=%d", e.Index, l.index, l.commitIndex)
l.append(&e)
l.mu.Unlock()
}
}

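Because Transport is now an anonymous interface field with an HTTPTransport default, tests (see NewTransport() in log_test.go below) can substitute an in-process implementation. A hedged sketch of a do-nothing stub using only the five signatures declared above (stubTransport itself is hypothetical):

package main

import (
	"errors"
	"io"
	"net/url"

	"github.com/influxdb/influxdb/raft"
)

// stubTransport satisfies Log.Transport by structural typing alone.
type stubTransport struct{}

func (stubTransport) Join(u, nodeURL *url.URL) (uint64, *raft.Config, error) {
	return 0, nil, errors.New("stub: no cluster")
}

func (stubTransport) Leave(u *url.URL, id uint64) error { return nil }

func (stubTransport) Heartbeat(u *url.URL, term, commitIndex, leaderID uint64) (uint64, uint64, error) {
	return 0, term, nil
}

func (stubTransport) ReadFrom(u *url.URL, id, term, index uint64) (io.ReadCloser, error) {
	return nil, errors.New("stub: no stream")
}

func (stubTransport) RequestVote(u *url.URL, term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error) {
	return term, nil
}

func main() {
	l := raft.NewLog()
	l.Transport = stubTransport{} // replaces the default HTTPTransport
	_ = l
}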
View File

@ -7,224 +7,226 @@ import (
"io"
"io/ioutil"
"log"
"net/http/httptest"
"net/url"
"os"
"sync"
"testing"
"time"
"github.com/influxdb/influxdb/raft"
)
// Ensure that opening an already open log returns an error.
func TestLog_Open_ErrOpen(t *testing.T) {
n := NewInitNode()
defer n.Close()
if err := n.Log.Open(tempfile()); err != raft.ErrOpen {
l := NewInitializedLog(&url.URL{Host: "log0"})
defer l.Close()
if err := l.Open(tempfile()); err != raft.ErrOpen {
t.Fatal("expected error")
}
}
// Ensure that a log can be checked for being open.
func TestLog_Opened(t *testing.T) {
n := NewInitNode()
if n.Log.Opened() != true {
l := NewInitializedLog(&url.URL{Host: "log0"})
if l.Opened() != true {
t.Fatalf("expected open")
}
n.Close()
if n.Log.Opened() != false {
l.Close()
if l.Opened() != false {
t.Fatalf("expected closed")
}
}
// Ensure that reopening an existing log will restore its ID.
func TestLog_Reopen(t *testing.T) {
n := NewInitNode()
if n.Log.ID() != 1 {
l := NewInitializedLog(&url.URL{Host: "log0"})
if l.ID() != 1 {
t.Fatalf("expected id == 1")
}
path := n.Log.Path()
path := l.Path()
// Close log and make sure id is cleared.
n.Close()
if n.Log.ID() != 0 {
l.Log.Close()
if l.ID() != 0 {
t.Fatalf("expected id == 0")
}
// Re-open and ensure id is restored.
if err := n.Log.Open(path); err != nil {
if err := l.Open(path); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if n.Log.ID() != 1 {
t.Fatalf("expected id == 1")
if id := l.ID(); id != 1 {
t.Fatalf("unexpected id: %d", id)
}
n.Close()
l.Close()
}
// Ensure that a single node-cluster can apply a log entry.
func TestLog_Apply(t *testing.T) {
// TODO corylanou: this test is intermittently failing. Fix and re-enable
// trace can be found here for failing test: https://gist.github.com/corylanou/1bb0a5d11447177e478f
t.Skip()
n := NewInitNode()
defer n.Close()
l := NewInitializedLog(&url.URL{Host: "log0"})
defer l.Close()
// Apply a command.
index, err := n.Log.Apply([]byte("foo"))
index, err := l.Apply([]byte("foo"))
if err != nil {
t.Fatalf("unexpected error: %s", err)
} else if index != 2 {
t.Fatalf("unexpected index: %d", index)
}
// Force apply cycle and then signal wait.
go func() { l.Clock.apply() }()
// Single node clusters should apply to FSM immediately.
n.Log.Wait(index)
if n := len(n.FSM().Commands); n != 1 {
l.Wait(index)
if n := len(l.FSM.Commands); n != 1 {
t.Fatalf("unexpected command count: %d", n)
}
}
// Ensure that log ids are set sequentially.
func TestLog_ID_Sequential(t *testing.T) {
c := NewCluster(3)
defer c.Close()
for i, n := range c.Nodes {
if n.Log.ID() != uint64(i+1) {
t.Fatalf("expected id: %d, got: %d", i+1, n.Log.ID())
}
}
}
// Ensure that cluster starts with one leader and multiple followers.
func TestLog_State(t *testing.T) {
c := NewCluster(3)
defer c.Close()
if state := c.Nodes[0].Log.State(); state != raft.Leader {
t.Fatalf("unexpected state(0): %s", state)
}
if state := c.Nodes[1].Log.State(); state != raft.Follower {
t.Fatalf("unexpected state(1): %s", state)
}
if state := c.Nodes[2].Log.State(); state != raft.Follower {
t.Fatalf("unexpected state(2): %s", state)
}
}
// Ensure that a node has no configuration after it's closed.
func TestLog_Config_Closed(t *testing.T) {
// TODO corylanou: racy test: gist: https://gist.github.com/corylanou/965ccf919e965082c338
t.Skip()
n := NewInitNode()
n.Close()
if n.Log.Config() != nil {
l := NewInitializedLog(&url.URL{Host: "log0"})
defer l.Close()
l.Log.Close()
if l.Config() != nil {
t.Fatal("expected nil config")
}
}
// Ensure that each node's configuration matches in the cluster.
func TestLog_Config(t *testing.T) {
c := NewCluster(3)
// Ensure that log ids in a cluster are set sequentially.
func TestCluster_ID_Sequential(t *testing.T) {
c := NewCluster()
defer c.Close()
config := jsonify(c.Nodes[0].Log.Config())
for _, n := range c.Nodes[1:] {
if b := jsonify(n.Log.Config()); config != b {
t.Fatalf("config mismatch(%d):\n\nexp=%s\n\ngot:%s\n\n", n.Log.ID(), config, b)
for i, l := range c.Logs {
if l.ID() != uint64(i+1) {
t.Fatalf("expected id: %d, got: %d", i+1, l.ID())
}
}
}
// Ensure that a new log can be successfully opened and closed.
func TestLog_Apply_Cluster(t *testing.T) {
// TODO corylanou racy test. gist: https://gist.github.com/corylanou/00d99de1ed9e02873196
t.Skip()
c := NewCluster(3)
// Ensure that cluster starts with one leader and multiple followers.
func TestCluster_State(t *testing.T) {
c := NewCluster()
defer c.Close()
if state := c.Logs[0].State(); state != raft.Leader {
t.Fatalf("unexpected state(0): %s", state)
}
if state := c.Logs[1].State(); state != raft.Follower {
t.Fatalf("unexpected state(1): %s", state)
}
if state := c.Logs[2].State(); state != raft.Follower {
t.Fatalf("unexpected state(2): %s", state)
}
}
// Ensure that each node's configuration matches in the cluster.
func TestCluster_Config(t *testing.T) {
c := NewCluster()
defer c.Close()
config := jsonify(c.Logs[0].Config())
for _, l := range c.Logs[1:] {
if b := jsonify(l.Config()); config != b {
t.Fatalf("config mismatch(%d):\n\nexp=%s\n\ngot:%s\n\n", l.ID(), config, b)
}
}
}
// Ensure that a command can be applied to a cluster and distributed appropriately.
func TestCluster_Apply(t *testing.T) {
c := NewCluster()
defer c.Close()
// Apply a command.
leader := c.Nodes[0]
index, err := leader.Log.Apply([]byte("foo"))
leader := c.Logs[0]
index, err := leader.Apply([]byte("foo"))
if err != nil {
t.Fatalf("unexpected error: %s", err)
} else if index != 4 {
t.Fatalf("unexpected index: %d", index)
}
leader.Log.Flush()
c.Logs[1].MustWaitUncommitted(4)
c.Logs[2].MustWaitUncommitted(4)
// Should not apply immediately.
if n := len(leader.FSM().Commands); n != 0 {
if n := len(leader.FSM.Commands); n != 0 {
t.Fatalf("unexpected pre-heartbeat command count: %d", n)
}
// Wait for a heartbeat and let the log apply the changes.
// Run the heartbeat on the leader and have all logs apply.
// Only the leader should have the changes applied.
c.Clock().Add(leader.Log.HeartbeatInterval)
if n := len(c.Nodes[0].FSM().Commands); n != 1 {
c.Logs[0].Clock.heartbeat()
c.Logs[0].Clock.apply()
c.Logs[1].Clock.apply()
c.Logs[2].Clock.apply()
if n := len(c.Logs[0].FSM.Commands); n != 1 {
t.Fatalf("unexpected command count(0): %d", n)
}
if n := len(c.Nodes[1].FSM().Commands); n != 0 {
if n := len(c.Logs[1].FSM.Commands); n != 0 {
t.Fatalf("unexpected command count(1): %d", n)
}
if n := len(c.Nodes[2].FSM().Commands); n != 0 {
if n := len(c.Logs[2].FSM.Commands); n != 0 {
t.Fatalf("unexpected command count(2): %d", n)
}
// Wait for another heartbeat and all nodes should be in sync.
c.Clock().Add(leader.Log.HeartbeatInterval)
if n := len(c.Nodes[1].FSM().Commands); n != 1 {
// Wait for another heartbeat and all logs should be in sync.
c.Logs[0].Clock.heartbeat()
c.Logs[1].Clock.apply()
c.Logs[2].Clock.apply()
if n := len(c.Logs[1].FSM.Commands); n != 1 {
t.Fatalf("unexpected command count(1): %d", n)
}
if n := len(c.Nodes[2].FSM().Commands); n != 1 {
if n := len(c.Logs[2].FSM.Commands); n != 1 {
t.Fatalf("unexpected command count(2): %d", n)
}
}
// Ensure that a new leader can be elected.
func TestLog_Elect(t *testing.T) {
// TODO: corylanou: racy test. gist: https://gist.github.com/corylanou/2a354673bd863a7c0770
t.Skip()
c := NewCluster(3)
c := NewCluster()
defer c.Close()
n0, n1, n2 := c.Nodes[0], c.Nodes[1], c.Nodes[2]
// Stop leader.
path := n0.Log.Path()
n0.Log.Close()
path := c.Logs[0].Path()
c.Logs[0].Log.Close()
// Wait for election timeout.
c.Clock().Add(2 * n0.Log.ElectionTimeout)
// Signal election on node 1. Then heartbeat to establish leadership.
c.Logs[1].Clock.now = c.Logs[1].Clock.now.Add(raft.DefaultElectionTimeout)
c.Logs[1].Clock.election()
c.Logs[1].Clock.heartbeat()
// Ensure node 1 is elected in the next term.
if state := c.Logs[1].State(); state != raft.Leader {
t.Fatalf("expected node 1 to move to leader: %s", state)
} else if term := c.Logs[1].Term(); term != 2 {
t.Fatalf("expected term 2: got %d", term)
}
// Restart leader and make sure it rejoins as a follower.
if err := c.Logs[0].Open(path); err != nil {
t.Fatalf("unexpected open error: %s", err)
}
// Wait for a heartbeat and verify that node 1 is still the leader.
c.Logs[1].Clock.heartbeat()
if state := c.Logs[1].State(); state != raft.Leader {
t.Fatalf("node 1 unexpectedly deposed: %s", state)
} else if term := c.Logs[1].Term(); term != 2 {
t.Fatalf("expected node 0 to go to term 2: got term %d", term)
}
// Apply a command and ensure it's replicated.
index, err := c.Logs[1].Log.Apply([]byte("abc"))
if err != nil {
t.Fatalf("unexpected apply error: %s", err)
}
c.MustWaitUncommitted(index)
c.Logs[1].Clock.heartbeat()
c.Logs[1].Clock.heartbeat()
c.Logs[0].Clock.apply()
c.Logs[1].Clock.apply()
c.Logs[2].Clock.apply()
if err := c.Logs[0].Wait(index); err != nil {
t.Fatalf("unexpected wait error: %s", err)
}
}
@ -250,137 +252,150 @@ func TestState_String(t *testing.T) {
// Cluster represents a collection of nodes that share the same mock clock.
type Cluster struct {
Logs []*Log
}
// NewCluster creates a new three-log cluster.
func NewCluster() *Cluster {
c := &Cluster{}
for i := 0; i < nodeN; i++ {
n := c.NewNode()
n.Open()
t := NewTransport()
// Initialize the first node.
// Join remaining nodes to the first node.
if i == 0 {
go func() { n.Clock().Add(2 * n.Log.ApplyInterval) }()
if err := n.Log.Initialize(); err != nil {
panic("initialize: " + err.Error())
}
} else {
go func() { n.Clock().Add(n.Log.HeartbeatInterval) }()
if err := n.Log.Join(c.Nodes[0].Log.URL); err != nil {
panic("join: " + err.Error())
}
}
logN := 3
for i := 0; i < logN; i++ {
l := NewLog(&url.URL{Host: fmt.Sprintf("log%d", i)})
l.Transport = t
c.Logs = append(c.Logs, l)
t.register(l.Log)
warnf("Log %s: %p", l.URL.String(), l.Log)
}
warn("")
// Initialize leader.
c.Logs[0].MustOpen()
c.Logs[0].MustInitialize()
// Join second node.
go func() {
c.Logs[0].MustWaitUncommitted(2)
c.Logs[0].Clock.apply()
c.Logs[0].Clock.heartbeat()
c.Logs[1].Clock.apply()
}()
c.Logs[1].MustOpen()
if err := c.Logs[1].Join(c.Logs[0].URL); err != nil {
panic("join: " + err.Error())
}
// Join third node.
go func() {
c.Logs[0].MustWaitUncommitted(3)
c.Logs[0].Clock.heartbeat()
c.Logs[0].Clock.apply()
c.Logs[1].Clock.apply()
c.Logs[2].Clock.apply()
}()
c.Logs[2].MustOpen()
if err := c.Logs[2].Log.Join(c.Logs[0].Log.URL); err != nil {
panic("join: " + err.Error())
}
// Heartbeat the final commit index to all logs and reapply.
c.Logs[0].Clock.heartbeat()
c.Logs[1].Clock.apply()
c.Logs[2].Clock.apply()
return c
}
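// Illustrative sketch (not from the original source): the manual-clock pattern
// these tests rely on. Nothing replicates on its own; a test applies a command,
// waits for it to reach every log uncommitted, then fires explicit heartbeat
// and apply ticks to commit it:
//
//	c := NewCluster()
//	defer c.Close()
//	index, _ := c.Logs[0].Apply([]byte("cmd"))
//	c.MustWaitUncommitted(index)
//	c.flush() // first cycle: the leader commits and applies
//	c.flush() // second cycle: followers learn the commit index and apply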
// Close closes all logs in the cluster.
func (c *Cluster) Close() {
for _, l := range c.Logs {
l.Close()
}
}
// Leader returns the leader log with the highest term.
func (c *Cluster) Leader() *Log {
var leader *Log
for _, l := range c.Logs {
if l.State() == raft.Leader && (leader == nil || leader.Log.Term() < l.Term()) {
leader = l
}
}
return leader
}
// MustWaitUncommitted waits until all logs in the cluster have reached a given uncommitted index. Panic on error.
func (c *Cluster) MustWaitUncommitted(index uint64) {
for _, l := range c.Logs {
l.MustWaitUncommitted(index)
}
}
// flush issues messages to cycle all logs.
func (c *Cluster) flush() {
for _, l := range c.Logs {
l.Clock.heartbeat()
l.Clock.apply()
}
}
// Log represents a test log.
type Log struct {
*raft.Log
Clock *Clock
FSM *FSM
}
// NewLog returns a new instance of Log.
func NewLog(u *url.URL) *Log {
l := &Log{Log: raft.NewLog(), Clock: NewClock(), FSM: &FSM{}}
l.URL = u
l.Log.FSM = l.FSM
l.Log.Clock = l.Clock
l.Rand = seq()
l.DebugEnabled = true
if !testing.Verbose() {
l.Logger = log.New(ioutil.Discard, "", 0)
}
return l
}
// NewInitializedLog returns a new opened and initialized Log.
func NewInitializedLog(u *url.URL) *Log {
l := NewLog(u)
l.MustOpen()
l.MustInitialize()
return l
}
// MustOpen opens the log. Panic on error.
func (l *Log) MustOpen() {
if err := l.Open(tempfile()); err != nil {
panic("open: " + err.Error())
}
}
// MustInitialize initializes the log. Panic on error.
func (l *Log) MustInitialize() {
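// Note: Initialize blocks until its config entry is applied, so the apply
// tick is fired concurrently with the Initialize call below.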
go func() { l.Clock.apply() }()
if err := l.Initialize(); err != nil {
panic("initialize: " + err.Error())
}
}
// Close closes the log and removes its underlying data.
func (l *Log) Close() error {
defer os.RemoveAll(l.Log.Path())
_ = l.Log.Close()
return nil
}
// MustWaitUncommitted waits for at least a given uncommitted index. Panic on error.
func (l *Log) MustWaitUncommitted(index uint64) {
if err := l.Log.WaitUncommitted(index); err != nil {
panic(l.URL.String() + " wait uncommitted: " + err.Error())
}
}
// FSM represents a simple state machine that records all commands.
@ -418,7 +433,10 @@ func (fsm *FSM) Restore(r io.Reader) error {
if _, err := io.ReadFull(r, buf); err != nil {
return err
}
return json.Unmarshal(buf, &fsm)
if err := json.Unmarshal(buf, &fsm); err != nil {
return err
}
return nil
}
// MockFSM represents a state machine that can be mocked out.
@ -462,5 +480,14 @@ func jsonify(v interface{}) string {
return string(b)
}
func warn(v ...interface{}) {
if testing.Verbose() {
fmt.Fprintln(os.Stderr, v...)
}
}
func warnf(msg string, v ...interface{}) {
if testing.Verbose() {
fmt.Fprintf(os.Stderr, msg+"\n", v...)
}
}

View File

@ -11,82 +11,6 @@ import (
"strconv"
)
// HTTPTransport represents a transport for sending RPCs over the HTTP protocol.
type HTTPTransport struct{}

View File

@ -1,5 +1,24 @@
package raft_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/url"
"sync"
"testing"
"time"
// "net/http"
// "net/http/httptest"
// "strings"
// "testing"
"github.com/influxdb/influxdb/raft"
)
/*
import (
"io/ioutil"
"net/http"
@ -11,41 +30,6 @@ import (
"github.com/influxdb/influxdb/raft"
)
// Ensure a heartbeat over HTTP can be read and responded to.
func TestHTTPTransport_Heartbeat(t *testing.T) {
@ -266,3 +250,163 @@ func TestHTTPTransport_RequestVote_ErrConnectionRefused(t *testing.T) {
t.Fatalf("unexpected error: %s", err)
}
}
*/
// Transport represents a test transport that directly calls another log.
// Logs are looked up by hostname only.
type Transport struct {
logs map[string]*raft.Log // logs by host
}
// NewTransport returns a new instance of Transport.
func NewTransport() *Transport {
return &Transport{logs: make(map[string]*raft.Log)}
}
// register registers a log by hostname.
func (t *Transport) register(l *raft.Log) {
t.logs[l.URL.Host] = l
}
// log returns a log registered by hostname.
func (t *Transport) log(u *url.URL) (*raft.Log, error) {
if l := t.logs[u.Host]; l != nil {
return l, nil
}
return nil, fmt.Errorf("log not found: %s", u.String())
}
// Join calls the AddPeer method on the target log.
func (t *Transport) Join(u *url.URL, nodeURL *url.URL) (uint64, *raft.Config, error) {
l, err := t.log(u)
if err != nil {
return 0, nil, err
}
return l.AddPeer(nodeURL)
}
// Leave calls the RemovePeer method on the target log.
func (t *Transport) Leave(u *url.URL, id uint64) error {
l, err := t.log(u)
if err != nil {
return err
}
return l.RemovePeer(id)
}
// Heartbeat calls the Heartbeat method on the target log.
func (t *Transport) Heartbeat(u *url.URL, term, commitIndex, leaderID uint64) (lastIndex, currentTerm uint64, err error) {
l, err := t.log(u)
if err != nil {
return 0, 0, err
}
return l.Heartbeat(term, commitIndex, leaderID)
}
// ReadFrom streams entries from the target log.
func (t *Transport) ReadFrom(u *url.URL, id, term, index uint64) (io.ReadCloser, error) {
l, err := t.log(u)
if err != nil {
return nil, err
}
// Create a streaming buffer that will hang until Close() is called.
buf := newStreamingBuffer()
go func() {
if err := l.WriteEntriesTo(buf.buf, id, term, index); err != nil {
warnf("Transport.ReadFrom: error: %s", err)
}
_ = buf.Close()
}()
return buf, nil
}
// RequestVote calls RequestVote() on the target log.
func (t *Transport) RequestVote(u *url.URL, term, candidateID, lastLogIndex, lastLogTerm uint64) (uint64, error) {
l, err := t.log(u)
if err != nil {
return 0, err
}
return l.RequestVote(term, candidateID, lastLogIndex, lastLogTerm)
}
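// Illustrative sketch (assumed usage, mirroring NewCluster above): each log is
// registered with the in-memory transport by hostname, so inter-log RPCs become
// direct method calls rather than HTTP round trips:
//
//	t := NewTransport()
//	l := NewLog(&url.URL{Host: "log0"})
//	l.Transport = t
//	t.register(l.Log)
//	// A later l2.Join(l.URL) is routed by t.Join straight to l.AddPeer(l2.URL).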
// streamingBuffer implements a streaming bytes buffer.
// This will hang during reads until there is data available or the streamer is closed.
type streamingBuffer struct {
mu sync.Mutex
buf *bytes.Buffer
closed bool
}
// newStreamingBuffer returns a new streamingBuffer.
func newStreamingBuffer() *streamingBuffer {
return &streamingBuffer{buf: bytes.NewBuffer(nil)}
}
// Close marks the buffer as closed.
func (b *streamingBuffer) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
b.closed = true
return nil
}
// Closed returns true if Close() has been called.
func (b *streamingBuffer) Closed() bool {
b.mu.Lock()
defer b.mu.Unlock()
return b.closed
}
func (b *streamingBuffer) Read(p []byte) (n int, err error) {
for {
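// bytes.Buffer returns io.EOF only when it is momentarily empty, so EOF
// here means "no data right now", not necessarily end-of-stream.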
n, err = b.buf.Read(p)
if err == io.EOF && n > 0 { // hit EOF, read data
return n, nil
} else if err == io.EOF { // hit EOF, no data
// If closed then return EOF.
if b.Closed() {
return n, err
}
// If not closed then wait a bit and try again.
time.Sleep(1 * time.Millisecond)
continue
}
// If we've read data or we've hit a non-EOF error then return.
return n, err
}
}
// Ensure the streaming buffer will continue to stream data, if available, after it's closed.
// This is primarily a sanity check to make sure our test buffer isn't causing problems.
func TestStreamingBuffer(t *testing.T) {
// Write some data to buffer.
buf := newStreamingBuffer()
buf.buf.WriteString("foo")
// Read all data out in separate goroutine.
start := make(chan struct{})
ch := make(chan string)
go func() {
close(start)
b, err := ioutil.ReadAll(buf)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
ch <- string(b)
}()
// Wait for reader to kick in.
<-start
// Write some more data and then close.
buf.buf.WriteString("bar")
buf.Close()
// Verify all data was read.
if s := <-ch; s != "foobar" {
t.Fatalf("unexpected output: %s", s)
}
}