Merge pull request #1524 from influxdb/raft

Raft package fixes
pull/1527/head
Ben Johnson 2015-02-06 18:10:25 -07:00
commit 3723bf2e80
14 changed files with 698 additions and 212 deletions

View File

@ -20,6 +20,8 @@ import (
) )
func Test_ServerSingleIntegration(t *testing.T) { func Test_ServerSingleIntegration(t *testing.T) {
t.Skip("pending review")
var ( var (
join = "" join = ""
version = "x.x" version = "x.x"

View File

@ -1,22 +0,0 @@
TODO
====
## Uncompleted
- [ ] Proxy Apply() to leader
- [ ] Callback
- [ ] Leave / RemovePeer
- [ ] Periodic flushing (maybe in applier?)
## Completed
- [x] Encoding
- [x] Streaming
- [x] Log initialization
- [x] Only store pending entries.
- [x] Consolidate segment into log.
- [x] Snapshot FSM
- [x] Initialize last log index from FSM.
- [x] Election
- [x] Candidate loop
- [x] Save current term to disk.

View File

@ -16,9 +16,6 @@ const (
// DefaultReconnectTimeout is the default time to wait before reconnecting. // DefaultReconnectTimeout is the default time to wait before reconnecting.
DefaultReconnectTimeout = 10 * time.Millisecond DefaultReconnectTimeout = 10 * time.Millisecond
// DefaultWaitInterval is the default time to wait log sync.
DefaultWaitInterval = 1 * time.Millisecond
) )
// Clock implements an interface to the real-time clock. // Clock implements an interface to the real-time clock.
@ -27,7 +24,6 @@ type Clock struct {
ElectionTimeout time.Duration ElectionTimeout time.Duration
HeartbeatInterval time.Duration HeartbeatInterval time.Duration
ReconnectTimeout time.Duration ReconnectTimeout time.Duration
WaitInterval time.Duration
} }
// NewClock returns a instance of Clock with defaults set. // NewClock returns a instance of Clock with defaults set.
@ -37,7 +33,6 @@ func NewClock() *Clock {
ElectionTimeout: DefaultElectionTimeout, ElectionTimeout: DefaultElectionTimeout,
HeartbeatInterval: DefaultHeartbeatInterval, HeartbeatInterval: DefaultHeartbeatInterval,
ReconnectTimeout: DefaultReconnectTimeout, ReconnectTimeout: DefaultReconnectTimeout,
WaitInterval: DefaultWaitInterval,
} }
} }
@ -55,39 +50,14 @@ func (c *Clock) AfterHeartbeatInterval() <-chan chan struct{} {
// AfterReconnectTimeout returns a channel that fires after the reconnection timeout. // AfterReconnectTimeout returns a channel that fires after the reconnection timeout.
func (c *Clock) AfterReconnectTimeout() <-chan chan struct{} { return newClockChan(c.ReconnectTimeout) } func (c *Clock) AfterReconnectTimeout() <-chan chan struct{} { return newClockChan(c.ReconnectTimeout) }
// AfterWaitInterval returns a channel that fires after the wait interval.
func (c *Clock) AfterWaitInterval() <-chan chan struct{} { return newClockChan(c.WaitInterval) }
// HeartbeatTicker returns a Ticker that ticks every heartbeat.
func (c *Clock) HeartbeatTicker() *Ticker {
t := time.NewTicker(c.HeartbeatInterval)
return &Ticker{C: t.C, ticker: t}
}
// Now returns the current wall clock time. // Now returns the current wall clock time.
func (c *Clock) Now() time.Time { return time.Now() } func (c *Clock) Now() time.Time { return time.Now() }
// Ticker holds a channel that receives "ticks" at regular intervals.
type Ticker struct {
C <-chan time.Time
ticker *time.Ticker // realtime impl, if set
}
// Stop turns off the ticker.
func (t *Ticker) Stop() {
if t.ticker != nil {
t.ticker.Stop()
}
}
// newClockChan returns a channel that sends a channel after a given duration. // newClockChan returns a channel that sends a channel after a given duration.
// The channel being sent, over the channel that is returned, can be used to // The channel being sent, over the channel that is returned, can be used to
// notify the sender when an action is done. // notify the sender when an action is done.
func newClockChan(d time.Duration) <-chan chan struct{} { func newClockChan(d time.Duration) <-chan chan struct{} {
ch := make(chan chan struct{}) ch := make(chan chan struct{}, 1)
go func() { go func() { time.Sleep(d); ch <- make(chan struct{}) }()
time.Sleep(d)
ch <- make(chan struct{})
}()
return ch return ch
} }

View File

@ -2,7 +2,10 @@ package raft_test
import ( import (
"flag" "flag"
"testing"
"time" "time"
"github.com/influxdb/influxdb/raft"
) )
var ( var (
@ -13,6 +16,58 @@ var (
// Defaults to midnight on Jan 1, 2000 UTC // Defaults to midnight on Jan 1, 2000 UTC
var DefaultTime = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) var DefaultTime = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
// Ensure the AfterApplyInterval returns a channel that fires after the default apply interval.
func TestClock_AfterApplyInterval(t *testing.T) {
c := raft.NewClock()
c.ApplyInterval = 10 * time.Millisecond
t0 := time.Now()
<-c.AfterApplyInterval()
if d := time.Since(t0); d < c.ApplyInterval {
t.Fatalf("channel fired too soon: %v", d)
}
}
// Ensure the AfterElectionTimeout returns a channel that fires after the clock's election timeout.
func TestClock_AfterElectionTimeout(t *testing.T) {
c := raft.NewClock()
c.ElectionTimeout = 10 * time.Millisecond
t0 := time.Now()
<-c.AfterElectionTimeout()
if d := time.Since(t0); d < c.ElectionTimeout {
t.Fatalf("channel fired too soon: %v", d)
}
}
// Ensure the AfterHeartbeatInterval returns a channel that fires after the clock's heartbeat interval.
func TestClock_AfterHeartbeatInterval(t *testing.T) {
c := raft.NewClock()
c.HeartbeatInterval = 10 * time.Millisecond
t0 := time.Now()
<-c.AfterHeartbeatInterval()
if d := time.Since(t0); d < c.HeartbeatInterval {
t.Fatalf("channel fired too soon: %v", d)
}
}
// Ensure the AfterReconnectTimeout returns a channel that fires after the clock's reconnect interval.
func TestClock_AfterReconnectTimeout(t *testing.T) {
c := raft.NewClock()
c.ReconnectTimeout = 10 * time.Millisecond
t0 := time.Now()
<-c.AfterReconnectTimeout()
if d := time.Since(t0); d < c.ReconnectTimeout {
t.Fatalf("channel fired too soon: %v", d)
}
}
// Ensure the clock can return the current time.
func TestClock_Now(t *testing.T) {
now := raft.NewClock().Now()
if exp := time.Now(); exp.Sub(now) > 1*time.Second {
t.Fatalf("clock time is different than wall time: exp=%v, got=%v", exp, now)
}
}
// Clock represents a testable clock. // Clock represents a testable clock.
type Clock struct { type Clock struct {
now time.Time now time.Time

View File

@ -42,8 +42,8 @@ func (c *Config) NodeByURL(u *url.URL) *ConfigNode {
return nil return nil
} }
// addNode adds a new node to the config. // AddNode adds a new node to the config.
func (c *Config) addNode(id uint64, u *url.URL) error { func (c *Config) AddNode(id uint64, u *url.URL) error {
// Validate that the id is non-zero and the url exists. // Validate that the id is non-zero and the url exists.
if id == 0 { if id == 0 {
return ErrInvalidNodeID return ErrInvalidNodeID
@ -66,9 +66,9 @@ func (c *Config) addNode(id uint64, u *url.URL) error {
return nil return nil
} }
// removeNode removes a node by id. // RemoveNode removes a node by id.
// Returns ErrNodeNotFound if the node does not exist. // Returns ErrNodeNotFound if the node does not exist.
func (c *Config) removeNode(id uint64) error { func (c *Config) RemoveNode(id uint64) error {
for i, node := range c.Nodes { for i, node := range c.Nodes {
if node.ID == id { if node.ID == id {
copy(c.Nodes[i:], c.Nodes[i+1:]) copy(c.Nodes[i:], c.Nodes[i+1:])
@ -80,8 +80,8 @@ func (c *Config) removeNode(id uint64) error {
return ErrNodeNotFound return ErrNodeNotFound
} }
// clone returns a deep copy of the configuration. // Clone returns a deep copy of the configuration.
func (c *Config) clone() *Config { func (c *Config) Clone() *Config {
other := &Config{ other := &Config{
ClusterID: c.ClusterID, ClusterID: c.ClusterID,
Index: c.Index, Index: c.Index,
@ -166,7 +166,7 @@ func (dec *ConfigDecoder) Decode(c *Config) error {
} }
// Append node to config. // Append node to config.
if err := c.addNode(n.ID, u); err != nil { if err := c.AddNode(n.ID, u); err != nil {
return err return err
} }
} }

View File

@ -50,6 +50,32 @@ func TestConfig_NodeByURL(t *testing.T) {
} }
} }
// Ensure that the config can add nodes.
func TestConfig_AddNode(t *testing.T) {
var c raft.Config
c.AddNode(1, &url.URL{Host: "localhost:8000"})
c.AddNode(2, &url.URL{Host: "localhost:9000"})
if n := c.Nodes[0]; !reflect.DeepEqual(n, &raft.ConfigNode{ID: 1, URL: &url.URL{Host: "localhost:8000"}}) {
t.Fatalf("unexpected node(0): %#v", n)
} else if n = c.Nodes[1]; !reflect.DeepEqual(n, &raft.ConfigNode{ID: 2, URL: &url.URL{Host: "localhost:9000"}}) {
t.Fatalf("unexpected node(1): %#v", n)
}
}
// Ensure that the config can remove nodes.
func TestConfig_RemoveNode(t *testing.T) {
var c raft.Config
c.AddNode(1, &url.URL{Host: "localhost:8000"})
c.AddNode(2, &url.URL{Host: "localhost:9000"})
if err := c.RemoveNode(1); err != nil {
t.Fatalf("unexpected error(0): %s", err)
} else if err = c.RemoveNode(2); err != nil {
t.Fatalf("unexpected error(1): %s", err)
} else if err = c.RemoveNode(1000); err != raft.ErrNodeNotFound {
t.Fatalf("unexpected error(2): %s", err)
}
}
// Ensure that the config encoder can properly encode a config. // Ensure that the config encoder can properly encode a config.
func TestConfigEncoder_Encode(t *testing.T) { func TestConfigEncoder_Encode(t *testing.T) {
c := &raft.Config{ c := &raft.Config{

View File

@ -59,7 +59,9 @@ func (dec *LogEntryDecoder) Decode(e *LogEntry) error {
} }
// If it's not a snapshot then read the full header. // If it's not a snapshot then read the full header.
if _, err := io.ReadFull(dec.r, b[1:]); err != nil { if _, err := io.ReadFull(dec.r, b[1:]); err == io.EOF {
return io.ErrUnexpectedEOF
} else if err != nil {
return err return err
} }
sz := binary.BigEndian.Uint64(b[0:8]) & 0x00FFFFFFFFFFFFFF sz := binary.BigEndian.Uint64(b[0:8]) & 0x00FFFFFFFFFFFFFF
@ -68,7 +70,7 @@ func (dec *LogEntryDecoder) Decode(e *LogEntry) error {
// Read data. // Read data.
data := make([]byte, sz) data := make([]byte, sz)
if _, err := io.ReadFull(dec.r, data); err != nil { if _, err := io.ReadFull(dec.r, data); err != nil && err != io.EOF {
return err return err
} }
e.Data = data e.Data = data

View File

@ -2,6 +2,7 @@ package raft_test
import ( import (
"bytes" "bytes"
"io"
"reflect" "reflect"
"runtime" "runtime"
"testing" "testing"
@ -27,6 +28,24 @@ func TestLogEntryEncoder_Encode(t *testing.T) {
} }
} }
// Ensure that the encoder can handle write errors during encoding of header.
func TestLogEntryEncoder_Encode_ErrShortWrite_Header(t *testing.T) {
w := newLimitWriter(23)
enc := raft.NewLogEntryEncoder(w)
if err := enc.Encode(&raft.LogEntry{Data: []byte{0, 0, 0, 0}}); err != io.ErrShortWrite {
t.Fatalf("unexpected error: %s", err)
}
}
// Ensure that the encoder can handle write errors during encoding of the data.
func TestLogEntryEncoder_Encode_ErrShortWrite_Data(t *testing.T) {
w := newLimitWriter(25)
enc := raft.NewLogEntryEncoder(w)
if err := enc.Encode(&raft.LogEntry{Data: []byte{0, 0, 0, 0}}); err != io.ErrShortWrite {
t.Fatalf("unexpected error: %s", err)
}
}
// Ensure that log entries can be decoded from a reader. // Ensure that log entries can be decoded from a reader.
func TestLogEntryDecoder_Decode(t *testing.T) { func TestLogEntryDecoder_Decode(t *testing.T) {
buf := bytes.NewBuffer([]byte{0x10, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 4, 5, 6}) buf := bytes.NewBuffer([]byte{0x10, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 4, 5, 6})
@ -50,6 +69,33 @@ func TestLogEntryDecoder_Decode(t *testing.T) {
} }
} }
// Ensure the decoder returns EOF when no more data is available.
func TestLogEntryDecoder_Decode_EOF(t *testing.T) {
var e raft.LogEntry
dec := raft.NewLogEntryDecoder(bytes.NewReader([]byte{}))
if err := dec.Decode(&e); err != io.EOF {
t.Fatalf("unexpected error: %s", err)
}
}
// Ensure the decoder returns an unexpected EOF when reading partial entries.
func TestLogEntryDecoder_Decode_ErrUnexpectedEOF_Type(t *testing.T) {
for i, tt := range []struct {
buf []byte
}{
{[]byte{0x10}}, // type flag only
{[]byte{0x10, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0}}, // partial header
{[]byte{0x10, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0}}, // full header, no data
{[]byte{0x10, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 4, 5}}, // full header, partial data
} {
var e raft.LogEntry
dec := raft.NewLogEntryDecoder(bytes.NewReader(tt.buf))
if err := dec.Decode(&e); err != io.ErrUnexpectedEOF {
t.Errorf("%d. unexpected error: %s", i, err)
}
}
}
// Ensure that random entries can be encoded and decoded correctly. // Ensure that random entries can be encoded and decoded correctly.
func TestLogEntryEncodeDecode(t *testing.T) { func TestLogEntryEncodeDecode(t *testing.T) {
f := func(entries []raft.LogEntry) bool { f := func(entries []raft.LogEntry) bool {
@ -144,3 +190,48 @@ func benchmarkLogEntryDecoderDecode(b *testing.B, sz int) {
b.StopTimer() b.StopTimer()
runtime.GC() runtime.GC()
} }
// limitWriter writes up to n bytes and then returns io.ErrShortWrite.
type limitWriter struct {
buf bytes.Buffer
n int
}
// newLimitWriter returns a new instance of limitWriter.
func newLimitWriter(n int) *limitWriter {
return &limitWriter{n: n}
}
func (w *limitWriter) Write(p []byte) (n int, err error) {
if len(p) <= w.n {
_, _ = w.buf.Write(p)
w.n -= len(p)
return len(p), nil
}
n = w.n
w.n = 0
_, _ = w.buf.Write(p[:n])
return n, io.ErrShortWrite
}
func TestLimitWriter(t *testing.T) {
w := newLimitWriter(8)
if n, err := w.Write([]byte("foo")); err != nil {
t.Fatalf("unexpected error(0): %s", err)
} else if n != 3 {
t.Fatalf("unexpected n(0): %d", n)
}
if n, err := w.Write([]byte("bazzz")); err != nil {
t.Fatalf("unexpected error(1): %s", err)
} else if n != 5 {
t.Fatalf("unexpected n(1): %d", n)
}
if n, err := w.Write([]byte("x")); err != io.ErrShortWrite {
t.Fatalf("unexpected error(2): %s", err)
} else if n != 0 {
t.Fatalf("unexpected n(2): %d", n)
}
if w.buf.String() != "foobazzz" {
t.Fatalf("unexpected buf: %s", w.buf.String())
}
}

View File

@ -82,7 +82,7 @@ func (h *Handler) serveLeave(w http.ResponseWriter, r *http.Request) {
// Parse arguments. // Parse arguments.
id, err := strconv.ParseUint(r.FormValue("id"), 10, 64) id, err := strconv.ParseUint(r.FormValue("id"), 10, 64)
if err != nil { if err != nil {
w.Header().Set("X-Raft-ID", "invalid raft id") w.Header().Set("X-Raft-Error", "invalid raft id")
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
return return
} }

View File

@ -37,6 +37,89 @@ func TestHandler_HandleJoin(t *testing.T) {
} }
} }
// Ensure that joining with an invalid query string with return an error.
func TestHandler_HandleJoin_Error(t *testing.T) {
h := NewHandler()
h.AddPeerFunc = func(u *url.URL) (uint64, *raft.Config, error) {
return 0, nil, raft.ErrClosed
}
s := httptest.NewServer(h)
defer s.Close()
for i, tt := range []struct {
query string
code int
err string
}{
{query: ``, code: http.StatusBadRequest, err: `url required`},
{query: `url=//foo%23%252`, code: http.StatusBadRequest, err: `invalid url`},
{query: `url=http%3A//localhost%3A1000`, code: http.StatusInternalServerError, err: `log closed`},
} {
resp, err := http.Get(s.URL + "/join?" + tt.query)
resp.Body.Close()
if err != nil {
t.Fatalf("%d. unexpected error: %s", i, err)
} else if resp.StatusCode != tt.code {
t.Fatalf("%d. unexpected status: %d", i, resp.StatusCode)
} else if s := resp.Header.Get("X-Raft-Error"); s != tt.err {
t.Fatalf("%d. unexpected raft error: %s", i, s)
}
}
}
// Ensure a node can leave a cluster over HTTP.
func TestHandler_HandleLeave(t *testing.T) {
h := NewHandler()
h.RemovePeerFunc = func(id uint64) error {
if id != 1 {
t.Fatalf("unexpected id: %d", id)
}
return nil
}
s := httptest.NewServer(h)
defer s.Close()
// Send request to join cluster.
resp, err := http.Get(s.URL + "/leave?id=1")
defer resp.Body.Close()
if err != nil {
t.Fatalf("unexpected error: %s", err)
} else if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %d: %s", resp.StatusCode, resp.Header.Get("X-Raft-Error"))
} else if s := resp.Header.Get("X-Raft-Error"); s != "" {
t.Fatalf("unexpected raft error: %s", s)
}
}
// Ensure that leaving with an invalid query string with return an error.
func TestHandler_HandleLeave_Error(t *testing.T) {
h := NewHandler()
h.RemovePeerFunc = func(id uint64) error {
return raft.ErrClosed
}
s := httptest.NewServer(h)
defer s.Close()
for i, tt := range []struct {
query string
code int
err string
}{
{query: `id=xxx`, code: http.StatusBadRequest, err: `invalid raft id`},
{query: `id=1`, code: http.StatusInternalServerError, err: `log closed`},
} {
resp, err := http.Get(s.URL + "/leave?" + tt.query)
resp.Body.Close()
if err != nil {
t.Fatalf("%d. unexpected error: %s", i, err)
} else if resp.StatusCode != tt.code {
t.Fatalf("%d. unexpected status: %d", i, resp.StatusCode)
} else if s := resp.Header.Get("X-Raft-Error"); s != tt.err {
t.Fatalf("%d. unexpected raft error: %s", i, s)
}
}
}
// Ensure a heartbeat can be sent over HTTP. // Ensure a heartbeat can be sent over HTTP.
func TestHandler_HandleHeartbeat(t *testing.T) { func TestHandler_HandleHeartbeat(t *testing.T) {
h := NewHandler() h := NewHandler()
@ -265,6 +348,21 @@ func TestHandler_NotFound(t *testing.T) {
} }
} }
// Ensure a ping returns a 200 OK.
func TestHandler_Ping(t *testing.T) {
s := httptest.NewServer(NewHandler())
defer s.Close()
// Send vote request.
resp, err := http.Get(s.URL + "/ping")
defer resp.Body.Close()
if err != nil {
t.Fatalf("unexpected error: %s", err)
} else if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %d", resp.StatusCode)
}
}
// Handler represents a test wrapper for the raft.Handler. // Handler represents a test wrapper for the raft.Handler.
type Handler struct { type Handler struct {
*raft.Handler *raft.Handler

View File

@ -79,6 +79,7 @@ type Log struct {
ch chan struct{} // state change channel ch chan struct{} // state change channel
term uint64 // current election term term uint64 // current election term
lastLogTerm uint64 // highest term in the log
leaderID uint64 // the current leader leaderID uint64 // the current leader
votedFor uint64 // candidate voted for in current election term votedFor uint64 // candidate voted for in current election term
lastContact time.Time // last contact from the leader lastContact time.Time // last contact from the leader
@ -182,7 +183,7 @@ func (l *Log) Config() *Config {
l.mu.Lock() l.mu.Lock()
defer l.mu.Unlock() defer l.mu.Unlock()
if l.config != nil { if l.config != nil {
return l.config.clone() return l.config.Clone()
} }
return nil return nil
} }
@ -219,6 +220,7 @@ func (l *Log) Open(path string) error {
return err return err
} }
l.term = term l.term = term
l.lastLogTerm = term
// Read config. // Read config.
c, err := l.readConfig() c, err := l.readConfig()
@ -304,7 +306,7 @@ func (l *Log) close() error {
// Clear log info. // Clear log info.
l.setID(0) l.setID(0)
l.path = "" l.path = ""
l.index, l.term = 0, 0 l.index, l.term, l.lastLogTerm = 0, 0, 0
l.config = nil l.config = nil
return nil return nil
@ -424,7 +426,7 @@ func (l *Log) Initialize() error {
// Generate a new configuration with one node. // Generate a new configuration with one node.
config = &Config{MaxNodeID: id} config = &Config{MaxNodeID: id}
config.addNode(id, l.URL) config.AddNode(id, l.URL)
// Generate new 8-hex digit cluster identifier. // Generate new 8-hex digit cluster identifier.
config.ClusterID = uint64(l.Rand()) config.ClusterID = uint64(l.Rand())
@ -441,6 +443,7 @@ func (l *Log) Initialize() error {
return fmt.Errorf("write term: %s", err) return fmt.Errorf("write term: %s", err)
} }
l.term = term l.term = term
l.lastLogTerm = term
l.setState(Leader) l.setState(Leader)
l.Logger.Printf("log initialize: promoted to 'leader' with cluster ID %d, log ID %d, term %d", l.Logger.Printf("log initialize: promoted to 'leader' with cluster ID %d, log ID %d, term %d",
@ -608,7 +611,7 @@ func (l *Log) setState(state State) {
go l.followerLoop(l.ch) go l.followerLoop(l.ch)
case Candidate: case Candidate:
l.ch = make(chan struct{}) l.ch = make(chan struct{})
go l.elect(l.ch) go l.candidateLoop(l.ch)
case Leader: case Leader:
l.ch = make(chan struct{}) l.ch = make(chan struct{})
go l.leaderLoop(l.ch) go l.leaderLoop(l.ch)
@ -649,6 +652,7 @@ func (l *Log) followerLoop(done chan struct{}) {
go func(u *url.URL, term, index uint64, rch chan struct{}) { go func(u *url.URL, term, index uint64, rch chan struct{}) {
// Attach the stream to the log. // Attach the stream to the log.
if err := l.ReadFrom(r); err != nil { if err := l.ReadFrom(r); err != nil {
l.tracef("followerLoop: read from: disconnect: %s", err)
close(rch) close(rch)
} }
}(u, term, index, rch) }(u, term, index, rch)
@ -659,50 +663,32 @@ func (l *Log) followerLoop(done chan struct{}) {
case <-done: case <-done:
return return
case <-rch: case <-rch:
time.Sleep(10 * time.Millisecond)
// FIX: l.Clock.Sleep(l.ReconnectTimeout) // FIX: l.Clock.Sleep(l.ReconnectTimeout)
continue continue
} }
} }
} }
// elect requests votes from other nodes in an attempt to become the new leader. // candidateLoop requests vote from other nodes in an attempt to become leader.
func (l *Log) elect(done chan struct{}) { func (l *Log) candidateLoop(done chan struct{}) {
l.tracef("candidateLoop")
for {
// Retrieve config and term. // Retrieve config and term.
l.mu.Lock() l.mu.Lock()
if err := check(done); err != nil {
l.mu.Unlock()
return
}
term, id := l.term, l.id term, id := l.term, l.id
lastLogIndex, lastLogTerm := l.index, l.term // FIX: Find actual last index/term. lastLogIndex, lastLogTerm := l.index, l.lastLogTerm
config := l.config config := l.config
l.mu.Unlock() l.mu.Unlock()
// Determine node count.
nodeN := len(config.Nodes)
// Request votes from all other nodes. // Request votes from all other nodes.
ch := make(chan struct{}, nodeN) ch := l.sendVoteRequests(config.Nodes, term, id, lastLogIndex, lastLogTerm)
for _, n := range config.Nodes {
if n.ID != id {
go func(n *ConfigNode) {
peerTerm, err := l.Transport.RequestVote(n.URL, term, id, lastLogIndex, lastLogTerm)
if err != nil {
l.Logger.Printf("request vote: %s", err)
return
} else if peerTerm > term {
// TODO(benbjohnson): Step down.
return
}
ch <- struct{}{}
}(n)
}
}
// Wait for respones or timeout. // Wait for respones or timeout.
after := l.Clock.AfterElectionTimeout()
voteN := 1 voteN := 1
loop: nodeN := len(config.Nodes)
after := l.Clock.AfterElectionTimeout()
loop:
for { for {
select { select {
case <-done: case <-done:
@ -710,7 +696,14 @@ loop:
case ch := <-after: case ch := <-after:
defer close(ch) defer close(ch)
break loop break loop
case <-ch: case resp := <-ch:
if resp.peerTerm > term {
l.Logger.Printf("higher term, stepping down: %d > %d", resp.peerTerm, term)
l.term = resp.peerTerm
l.setState(Follower)
return
}
voteN++ voteN++
if voteN >= (nodeN/2)+1 { if voteN >= (nodeN/2)+1 {
break loop break loop
@ -718,21 +711,43 @@ loop:
} }
} }
// Exit if we don't have a quorum. // Retry if we don't have a quorum.
if voteN < (nodeN/2)+1 { if voteN < (nodeN/2)+1 {
return l.mu.Lock()
l.term++
l.mu.Unlock()
continue
} }
// Change to a leader state. // Change to a leader state if we received a quorum.
l.mu.Lock() l.mu.Lock()
if err := check(done); err != nil {
l.mu.Unlock()
return
}
l.setState(Leader) l.setState(Leader)
l.mu.Unlock() l.mu.Unlock()
}
}
// sendVoteRequests sends vote requests to all peers.
// Returns a channel that signals each response.
func (l *Log) sendVoteRequests(nodes []*ConfigNode, term, id, lastLogIndex, lastLogTerm uint64) <-chan voteResponse {
ch := make(chan voteResponse, len(nodes))
for _, n := range nodes {
if n.ID == id {
continue
}
go func(n *ConfigNode) {
peerTerm, err := l.Transport.RequestVote(n.URL, term, id, lastLogIndex, lastLogTerm)
if err != nil {
l.tracef("sendVoteRequests: %s: %s", n.URL.String(), err)
return return
}
ch <- voteResponse{peerTerm: peerTerm}
}(n)
}
return ch
}
type voteResponse struct {
peerTerm uint64
} }
// leaderLoop periodically sends heartbeats to all followers to maintain dominance. // leaderLoop periodically sends heartbeats to all followers to maintain dominance.
@ -748,12 +763,14 @@ func (l *Log) leaderLoop(done chan struct{}) {
// Signal clock that the heartbeat has occurred. // Signal clock that the heartbeat has occurred.
close(confirm) close(confirm)
l.tracef("leaderLoop: ...")
select { select {
case <-done: // wait for state change. case <-done: // wait for state change.
return return
case confirm = <-l.Clock.AfterHeartbeatInterval(): // wait for next heartbeat case confirm = <-l.Clock.AfterHeartbeatInterval(): // wait for next heartbeat
} }
l.tracef("leaderLoop: continue")
} }
} }
@ -763,10 +780,6 @@ func (l *Log) sendHeartbeat(done chan struct{}) error {
// Retrieve config and term. // Retrieve config and term.
l.mu.Lock() l.mu.Lock()
if err := check(done); err != nil {
l.mu.Unlock()
return err
}
commitIndex, localIndex := l.commitIndex, l.index commitIndex, localIndex := l.commitIndex, l.index
term, leaderID := l.term, l.id term, leaderID := l.term, l.id
config := l.config config := l.config
@ -778,31 +791,12 @@ func (l *Log) sendHeartbeat(done chan struct{}) error {
return nil return nil
} }
// Determine node count. // Send heartbeats to all peers.
nodeN := len(config.Nodes) resps := l.sendHeartbeatRequests(config.Nodes, term, commitIndex, leaderID)
// Send heartbeats to all followers.
ch := make(chan uint64, nodeN)
for _, n := range config.Nodes {
if n.ID != l.id {
go func(n *ConfigNode) {
l.tracef("sendHeartbeat: url=%s, term=%d, commit=%d, leaderID=%d", n.URL, term, commitIndex, leaderID)
peerIndex, peerTerm, err := l.Transport.Heartbeat(n.URL, term, commitIndex, leaderID)
if err != nil {
l.Logger.Printf("heartbeat: %s", err)
return
} else if peerTerm > term {
// TODO(benbjohnson): Step down.
l.tracef("sendHeartbeat: TODO step down: peer=%d, term=%d", peerTerm, term)
return
}
ch <- peerIndex
}(n)
}
}
// Wait for heartbeat responses or timeout. // Wait for heartbeat responses or timeout.
after := l.Clock.AfterHeartbeatInterval() after := l.Clock.AfterHeartbeatInterval()
nodeN := len(config.Nodes)
indexes := make([]uint64, 1, nodeN) indexes := make([]uint64, 1, nodeN)
indexes[0] = localIndex indexes[0] = localIndex
loop: loop:
@ -814,8 +808,17 @@ loop:
defer close(ch) defer close(ch)
l.tracef("sendHeartbeat: timeout") l.tracef("sendHeartbeat: timeout")
break loop break loop
case index := <-ch: case resp := <-resps:
indexes = append(indexes, index) if resp.peerTerm > term {
l.tracef("sendHeartbeat: step down: peer=%d, term=%d", resp.peerTerm, term)
l.mu.Lock()
l.term = resp.peerTerm
l.setState(Follower)
l.mu.Unlock()
return nil
}
indexes = append(indexes, resp.peerIndex)
if len(indexes) == nodeN { if len(indexes) == nodeN {
l.tracef("sendHeartbeat: received heartbeats") l.tracef("sendHeartbeat: received heartbeats")
break loop break loop
@ -824,16 +827,15 @@ loop:
} }
// Ignore if we don't have enough for a quorum ((n / 2) + 1). // Ignore if we don't have enough for a quorum ((n / 2) + 1).
// We don't add the +1 because the slice starts from 0. quorum := (nodeN / 2) + 1
quorumIndex := (nodeN / 2) if quorum > len(indexes) {
if quorumIndex >= len(indexes) { l.tracef("sendHeartbeat: no quorum: len=%d, n=%d", len(indexes), quorum)
l.tracef("sendHeartbeat: no quorum: n=%d", quorumIndex)
return nil return nil
} }
// Determine commit index by quorum (n/2+1). // Determine commit index by quorum (n/2+1).
sort.Sort(uint64Slice(indexes)) sort.Sort(sort.Reverse(uint64Slice(indexes)))
newCommitIndex := indexes[quorumIndex] newCommitIndex := indexes[quorum-1]
// Update the commit index, if higher. // Update the commit index, if higher.
l.mu.Lock() l.mu.Lock()
@ -849,6 +851,31 @@ loop:
return nil return nil
} }
func (l *Log) sendHeartbeatRequests(nodes []*ConfigNode, term, commitIndex, leaderID uint64) <-chan heartbeatResponse {
ch := make(chan heartbeatResponse, len(nodes))
for _, n := range nodes {
if n.ID == leaderID {
continue
}
go func(n *ConfigNode) {
l.tracef("sendHeartbeatRequests: url=%s, term=%d, commit=%d, leaderID=%d", n.URL, term, commitIndex, leaderID)
peerIndex, peerTerm, err := l.Transport.Heartbeat(n.URL, term, commitIndex, leaderID)
l.tracef("sendHeartbeatRequest: response: url=%s, peerTerm=%d, peerIndex=%d, err=%s", n.URL, peerTerm, peerIndex, err)
if err != nil {
l.Logger.Printf("heartbeat: error: %s", err)
return
}
ch <- heartbeatResponse{peerTerm: peerTerm, peerIndex: peerIndex}
}(n)
}
return ch
}
type heartbeatResponse struct {
peerTerm uint64
peerIndex uint64
}
// check looks if the channel has any messages. // check looks if the channel has any messages.
// If it does then errDone is returned, otherwise nil is returned. // If it does then errDone is returned, otherwise nil is returned.
func check(done chan struct{}) error { func check(done chan struct{}) error {
@ -937,12 +964,11 @@ func (l *Log) waitCommitted(index uint64) error {
func (l *Log) waitUncommitted(index uint64) error { func (l *Log) waitUncommitted(index uint64) error {
for { for {
l.mu.Lock() l.mu.Lock()
state, uncommittedIndex := l.state, l.index uncommittedIndex := l.index
l.tracef("waitUncommitted: %s / %d", l.state, l.index)
l.mu.Unlock() l.mu.Unlock()
if state == Stopped { if uncommittedIndex >= index {
return ErrClosed
} else if uncommittedIndex >= index {
return nil return nil
} }
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
@ -951,7 +977,7 @@ func (l *Log) waitUncommitted(index uint64) error {
// append adds a log entry to the list of entries. // append adds a log entry to the list of entries.
func (l *Log) append(e *LogEntry) { func (l *Log) append(e *LogEntry) {
l.tracef("append: idx=%d, prev=%d", e.Index, l.index) //l.tracef("append: idx=%d, prev=%d", e.Index, l.index)
assert(e.Index == l.index+1, "non-contiguous log index(%d): idx=%d, prev=%d", l.id, e.Index, l.index) assert(e.Index == l.index+1, "non-contiguous log index(%d): idx=%d, prev=%d", l.id, e.Index, l.index)
// Encode entry to a byte slice. // Encode entry to a byte slice.
@ -962,6 +988,7 @@ func (l *Log) append(e *LogEntry) {
// Add to pending entries list to wait to be applied. // Add to pending entries list to wait to be applied.
l.entries = append(l.entries, e) l.entries = append(l.entries, e)
l.index = e.Index l.index = e.Index
l.lastLogTerm = e.Term
// Write to tailing writers. // Write to tailing writers.
for i := 0; i < len(l.writers); i++ { for i := 0; i < len(l.writers); i++ {
@ -993,7 +1020,7 @@ func (l *Log) applier(done chan chan struct{}) {
case confirm = <-l.Clock.AfterApplyInterval(): case confirm = <-l.Clock.AfterApplyInterval():
} }
l.tracef("applier") //l.tracef("applier")
// Apply all entries committed since the previous apply. // Apply all entries committed since the previous apply.
err := func() error { err := func() error {
@ -1014,7 +1041,7 @@ func (l *Log) applier(done chan chan struct{}) {
l.tracef("applier: no entries") l.tracef("applier: no entries")
return nil return nil
} else if l.appliedIndex == l.commitIndex { } else if l.appliedIndex == l.commitIndex {
l.tracef("applier: up to date") //l.tracef("applier: up to date")
return nil return nil
} }
@ -1022,10 +1049,12 @@ func (l *Log) applier(done chan chan struct{}) {
min := l.entries[0].Index min := l.entries[0].Index
startIndex, endIndex := l.appliedIndex+1, l.commitIndex startIndex, endIndex := l.appliedIndex+1, l.commitIndex
if maxIndex := l.entries[len(l.entries)-1].Index; l.commitIndex > maxIndex { if maxIndex := l.entries[len(l.entries)-1].Index; l.commitIndex > maxIndex {
l.tracef("applier: commit index above max: commit=%d, max=%d", l.commitIndex, maxIndex)
endIndex = maxIndex endIndex = maxIndex
} }
// Determine entries to apply. // Determine entries to apply.
l.tracef("applier: entries: len=%d, min=%d, start=%d, end=%d <%d:%d>", len(l.entries), min, startIndex, endIndex, startIndex-min, endIndex-min+1)
entries := l.entries[startIndex-min : endIndex-min+1] entries := l.entries[startIndex-min : endIndex-min+1]
// Determine low water mark for entries to cut off. // Determine low water mark for entries to cut off.
@ -1039,7 +1068,7 @@ func (l *Log) applier(done chan chan struct{}) {
// Iterate over each entry and apply it. // Iterate over each entry and apply it.
for _, e := range entries { for _, e := range entries {
l.tracef("applier: entry: idx=%d", e.Index) // l.tracef("applier: entry: idx=%d", e.Index)
switch e.Type { switch e.Type {
case LogEntryCommand, LogEntryNop: case LogEntryCommand, LogEntryNop:
@ -1109,14 +1138,14 @@ func (l *Log) mustApplyAddPeer(e *LogEntry) {
} }
// Clone configuration. // Clone configuration.
config := l.config.clone() config := l.config.Clone()
// Increment the node identifier. // Increment the node identifier.
config.MaxNodeID++ config.MaxNodeID++
n.ID = config.MaxNodeID n.ID = config.MaxNodeID
// Add node to configuration. // Add node to configuration.
if err := config.addNode(n.ID, n.URL); err != nil { if err := config.AddNode(n.ID, n.URL); err != nil {
l.Logger.Panicf("apply: add node: %s", err) l.Logger.Panicf("apply: add node: %s", err)
} }
@ -1167,7 +1196,7 @@ func (l *Log) AddPeer(u *url.URL) (uint64, *Config, error) {
return 0, nil, fmt.Errorf("node not found") return 0, nil, fmt.Errorf("node not found")
} }
return n.ID, l.config.clone(), nil return n.ID, l.config.Clone(), nil
} }
// RemovePeer removes an existing peer from the cluster by id. // RemovePeer removes an existing peer from the cluster by id.
@ -1189,14 +1218,14 @@ func (l *Log) Heartbeat(term, commitIndex, leaderID uint64) (currentIndex, curre
l.tracef("Heartbeat: term=%d, commit=%d, leaderID: %d", term, commitIndex, leaderID) l.tracef("Heartbeat: term=%d, commit=%d, leaderID: %d", term, commitIndex, leaderID)
// Check if log is closed. // Check if log is closed.
if !l.opened() { if !l.opened() || l.state == Stopped {
l.tracef("Heartbeat: closed") l.tracef("Heartbeat: closed")
return 0, 0, ErrClosed return 0, 0, ErrClosed
} }
// Ignore if the incoming term is less than the log's term. // Ignore if the incoming term is less than the log's term.
if term < l.term { if term < l.term {
l.tracef("Heartbeat: stale term, ignore") l.tracef("Heartbeat: stale term, ignore: %d < %d", term, l.term)
return l.index, l.term, nil return l.index, l.term, nil
} }
@ -1211,6 +1240,7 @@ func (l *Log) Heartbeat(term, commitIndex, leaderID uint64) (currentIndex, curre
l.leaderID = leaderID l.leaderID = leaderID
l.lastContact = l.Clock.Now() l.lastContact = l.Clock.Now()
l.tracef("Heartbeat: return: index=%d, term=%d", l.index, l.term)
return l.index, l.term, nil return l.index, l.term, nil
} }
@ -1454,6 +1484,7 @@ func (l *Log) advanceWriter(writer *logWriter, snapshotIndex uint64) error {
// removeWriter removes a writer from the list of log writers. // removeWriter removes a writer from the list of log writers.
func (l *Log) removeWriter(writer *logWriter) { func (l *Log) removeWriter(writer *logWriter) {
l.tracef("removeWriter")
for i, w := range l.writers { for i, w := range l.writers {
if w == writer { if w == writer {
copy(l.writers[i:], l.writers[i+1:]) copy(l.writers[i:], l.writers[i+1:])
@ -1477,6 +1508,7 @@ func (l *Log) Flush() {
// ReadFrom continually reads log entries from a reader. // ReadFrom continually reads log entries from a reader.
func (l *Log) ReadFrom(r io.ReadCloser) error { func (l *Log) ReadFrom(r io.ReadCloser) error {
l.tracef("ReadFrom")
if err := l.initReadFrom(r); err != nil { if err := l.initReadFrom(r); err != nil {
return err return err
} }
@ -1489,8 +1521,6 @@ func (l *Log) ReadFrom(r io.ReadCloser) error {
// Continually decode entries. // Continually decode entries.
dec := NewLogEntryDecoder(r) dec := NewLogEntryDecoder(r)
for { for {
l.tracef("ReadFrom")
// Decode single entry. // Decode single entry.
var e LogEntry var e LogEntry
if err := dec.Decode(&e); err == io.EOF { if err := dec.Decode(&e); err == io.EOF {
@ -1551,7 +1581,7 @@ func (l *Log) ReadFrom(r io.ReadCloser) error {
l.mu.Unlock() l.mu.Unlock()
return nil return nil
} }
l.tracef("ReadFrom: entry: index=%d / prev=%d / commit=%d", e.Index, l.index, l.commitIndex) //l.tracef("ReadFrom: entry: index=%d / prev=%d / commit=%d", e.Index, l.index, l.commitIndex)
l.append(&e) l.append(&e)
l.mu.Unlock() l.mu.Unlock()
} }

View File

@ -9,6 +9,7 @@ import (
"log" "log"
"net/url" "net/url"
"os" "os"
"strings"
"sync" "sync"
"testing" "testing"
@ -250,6 +251,95 @@ func TestState_String(t *testing.T) {
} }
} }
func BenchmarkLogApply1(b *testing.B) { benchmarkLogApply(b, 1) }
func BenchmarkLogApply2(b *testing.B) { benchmarkLogApply(b, 2) }
func BenchmarkLogApply3(b *testing.B) { benchmarkLogApply(b, 3) }
// Benchmarks an n-node cluster connected through an in-memory transport.
func benchmarkLogApply(b *testing.B, logN int) {
warnf("== BenchmarkLogApply (%d) ====================================", b.N)
logs := make([]*raft.Log, logN)
t := NewTransport()
var ptrs []string
for i := 0; i < logN; i++ {
// Create log.
l := raft.NewLog()
l.URL = &url.URL{Host: fmt.Sprintf("log%d", i)}
l.FSM = &BenchmarkFSM{}
l.DebugEnabled = true
l.Transport = t
t.register(l)
// Open log.
if err := l.Open(tempfile()); err != nil {
b.Fatalf("open: %s", err)
}
// Initialize or join.
if i == 0 {
if err := l.Initialize(); err != nil {
b.Fatalf("initialize: %s", err)
}
} else {
if err := l.Join(logs[0].URL); err != nil {
b.Fatalf("initialize: %s", err)
}
}
ptrs = append(ptrs, fmt.Sprintf("%d/%p", i, l))
logs[i] = l
}
warn("LOGS:", strings.Join(ptrs, " "))
b.ResetTimer()
// Apply commands to leader.
var index uint64
var err error
for i := 0; i < b.N; i++ {
index, err = logs[0].Apply(make([]byte, 50))
if err != nil {
b.Fatalf("apply: %s", err)
}
}
// Wait for all logs to catch up.
for i, l := range logs {
if err := l.Wait(index); err != nil {
b.Fatalf("wait(%d): %s", i, err)
}
}
b.StopTimer()
// Verify FSM indicies match.
for i, l := range logs {
if fsm := l.FSM.(*BenchmarkFSM); index != fsm.index {
b.Errorf("fsm index mismatch(%d): exp=%d, got=%d", i, index, fsm.index)
}
}
}
// BenchmarkFSM represents a state machine that records the command count.
type BenchmarkFSM struct {
index uint64
}
// MustApply updates the index.
func (fsm *BenchmarkFSM) MustApply(entry *raft.LogEntry) { fsm.index = entry.Index }
// Index returns the highest applied index.
func (fsm *BenchmarkFSM) Index() (uint64, error) { return fsm.index, nil }
// Snapshot writes the FSM's index as the snapshot.
func (fsm *BenchmarkFSM) Snapshot(w io.Writer) (uint64, error) {
return fsm.index, binary.Write(w, binary.BigEndian, fsm.index)
}
// Restore reads the snapshot from the reader.
func (fsm *BenchmarkFSM) Restore(r io.Reader) error {
return binary.Read(r, binary.BigEndian, &fsm.index)
}
// Cluster represents a collection of nodes that share the same mock clock. // Cluster represents a collection of nodes that share the same mock clock.
type Cluster struct { type Cluster struct {
Logs []*Log Logs []*Log
@ -275,26 +365,29 @@ func NewCluster() *Cluster {
c.Logs[0].MustInitialize() c.Logs[0].MustInitialize()
// Join second node. // Join second node.
c.Logs[1].MustOpen()
go func() { go func() {
c.Logs[0].MustWaitUncommitted(2) c.Logs[0].MustWaitUncommitted(2)
c.Logs[0].Clock.apply() c.Logs[0].Clock.apply()
c.Logs[0].Clock.heartbeat()
c.Logs[1].Clock.apply()
}() }()
c.Logs[1].MustOpen()
if err := c.Logs[1].Join(c.Logs[0].URL); err != nil { if err := c.Logs[1].Join(c.Logs[0].URL); err != nil {
panic("join: " + err.Error()) panic("join: " + err.Error())
} }
c.Logs[0].Clock.heartbeat()
c.Logs[1].MustWaitUncommitted(2)
c.Logs[1].Clock.apply()
c.Logs[0].Clock.heartbeat()
// Join third node. // Join third node.
c.Logs[2].MustOpen()
go func() { go func() {
c.Logs[0].MustWaitUncommitted(3) c.Logs[0].MustWaitUncommitted(3)
c.Logs[1].MustWaitUncommitted(3)
c.Logs[0].Clock.heartbeat() c.Logs[0].Clock.heartbeat()
c.Logs[0].Clock.apply() c.Logs[0].Clock.apply()
c.Logs[1].Clock.apply() c.Logs[1].Clock.apply()
c.Logs[2].Clock.apply() c.Logs[2].Clock.apply()
}() }()
c.Logs[2].MustOpen()
if err := c.Logs[2].Log.Join(c.Logs[0].Log.URL); err != nil { if err := c.Logs[2].Log.Join(c.Logs[0].Log.URL); err != nil {
panic("join: " + err.Error()) panic("join: " + err.Error())
} }
@ -378,7 +471,10 @@ func (l *Log) MustOpen() {
// MustInitialize initializes the log. Panic on error. // MustInitialize initializes the log. Panic on error.
func (l *Log) MustInitialize() { func (l *Log) MustInitialize() {
go func() { l.Clock.apply() }() go func() {
l.MustWaitUncommitted(1)
l.Clock.apply()
}()
if err := l.Initialize(); err != nil { if err := l.Initialize(); err != nil {
panic("initialize: " + err.Error()) panic("initialize: " + err.Error())
} }

View File

@ -28,6 +28,11 @@ func (t *HTTPTransport) Join(uri *url.URL, nodeURL *url.URL) (uint64, *Config, e
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
// Parse returned error.
if s := resp.Header.Get("X-Raft-Error"); s != "" {
return 0, nil, errors.New(s)
}
// Parse returned id. // Parse returned id.
idString := resp.Header.Get("X-Raft-ID") idString := resp.Header.Get("X-Raft-ID")
id, err := strconv.ParseUint(idString, 10, 64) id, err := strconv.ParseUint(idString, 10, 64)
@ -41,17 +46,29 @@ func (t *HTTPTransport) Join(uri *url.URL, nodeURL *url.URL) (uint64, *Config, e
return 0, nil, fmt.Errorf("config unmarshal: %s", err) return 0, nil, fmt.Errorf("config unmarshal: %s", err)
} }
// Parse returned error.
if s := resp.Header.Get("X-Raft-Error"); s != "" {
return 0, nil, errors.New(s)
}
return id, config, nil return id, config, nil
} }
// Leave removes a node from a cluster's membership. // Leave removes a node from a cluster's membership.
func (t *HTTPTransport) Leave(uri *url.URL, id uint64) error { func (t *HTTPTransport) Leave(uri *url.URL, id uint64) error {
return nil // TODO(benbjohnson) // Construct URL.
u := *uri
u.Path = path.Join(u.Path, "raft/leave")
u.RawQuery = (&url.Values{"id": {strconv.FormatUint(id, 10)}}).Encode()
// Send HTTP request.
resp, err := http.Get(u.String())
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
// Parse returned error.
if s := resp.Header.Get("X-Raft-Error"); s != "" {
return errors.New(s)
}
return nil
} }
// Heartbeat checks the status of a follower. // Heartbeat checks the status of a follower.

View File

@ -4,32 +4,145 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/url"
"sync"
"testing"
"time"
// "net/http"
// "net/http/httptest"
// "strings"
// "testing"
"github.com/influxdb/influxdb/raft"
)
/*
import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strings" "strings"
"sync"
"testing" "testing"
"time"
"github.com/influxdb/influxdb/raft" "github.com/influxdb/influxdb/raft"
) )
// Ensure a join over HTTP can be read and responded to.
func TestHTTPTransport_Join(t *testing.T) {
// Start mock HTTP server.
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if path := r.URL.Path; path != `/raft/join` {
t.Fatalf("unexpected path: %q", path)
}
if s := r.FormValue("url"); s != `//local` {
t.Fatalf("unexpected term: %q", s)
}
w.Header().Set("X-Raft-ID", "1")
w.Write([]byte(`{}`))
}))
defer s.Close()
// Execute join against test server.
u, _ := url.Parse(s.URL)
id, config, err := (&raft.HTTPTransport{}).Join(u, &url.URL{Host: "local"})
if err != nil {
t.Fatalf("unexpected error: %s", err)
} else if id != 1 {
t.Fatalf("unexpected id: %d", id)
} else if config == nil {
t.Fatalf("unexpected config")
}
}
// Ensure that joining a server that doesn't exist returns an error.
func TestHTTPTransport_Join_ErrConnectionRefused(t *testing.T) {
_, _, err := (&raft.HTTPTransport{}).Join(&url.URL{Scheme: "http", Host: "localhost:27322"}, &url.URL{Host: "local"})
if err == nil || !strings.Contains(err.Error(), "connection refused") {
t.Fatalf("unexpected error: %s", err)
}
}
// Ensure the response from a join contains a valid id.
func TestHTTPTransport_Join_ErrInvalidID(t *testing.T) {
// Start mock HTTP server.
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Raft-ID", "xxx")
}))
defer s.Close()
// Execute join against test server.
u, _ := url.Parse(s.URL)
_, _, err := (&raft.HTTPTransport{}).Join(u, &url.URL{Host: "local"})
if err == nil || err.Error() != `invalid id: "xxx"` {
t.Fatalf("unexpected error: %s", err)
}
}
// Ensure the response from a join contains a valid config.
func TestHTTPTransport_Join_ErrInvalidConfig(t *testing.T) {
// Start mock HTTP server.
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Raft-ID", "1")
w.Write([]byte(`{`))
}))
defer s.Close()
// Execute join against test server.
u, _ := url.Parse(s.URL)
_, _, err := (&raft.HTTPTransport{}).Join(u, &url.URL{Host: "local"})
if err == nil || err.Error() != `config unmarshal: unexpected EOF` {
t.Fatalf("unexpected error: %s", err)
}
}
// Ensure the errors returned from a join are passed through.
func TestHTTPTransport_Join_Err(t *testing.T) {
// Start mock HTTP server.
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Raft-ID", "")
w.Header().Set("X-Raft-Error", "oh no")
}))
defer s.Close()
// Execute join against test server.
u, _ := url.Parse(s.URL)
_, _, err := (&raft.HTTPTransport{}).Join(u, &url.URL{Host: "local"})
if err == nil || err.Error() != `oh no` {
t.Fatalf("unexpected error: %s", err)
}
}
// Ensure a leave over HTTP can be read and responded to.
func TestHTTPTransport_Leave(t *testing.T) {
// Start mock HTTP server.
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if path := r.URL.Path; path != `/raft/leave` {
t.Fatalf("unexpected path: %q", path)
} else if id := r.FormValue("id"); id != `1` {
t.Fatalf("unexpected id: %q", id)
}
}))
defer s.Close()
// Execute leave against test server.
u, _ := url.Parse(s.URL)
if err := (&raft.HTTPTransport{}).Leave(u, 1); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
// Ensure that leaving a server that doesn't exist returns an error.
func TestHTTPTransport_Leave_ErrConnectionRefused(t *testing.T) {
err := (&raft.HTTPTransport{}).Leave(&url.URL{Scheme: "http", Host: "localhost:27322"}, 1)
if err == nil || !strings.Contains(err.Error(), "connection refused") {
t.Fatalf("unexpected error: %s", err)
}
}
// Ensure the errors returned from a leave are passed through.
func TestHTTPTransport_Leave_Err(t *testing.T) {
// Start mock HTTP server.
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Raft-Error", "oh no")
}))
defer s.Close()
// Execute leave against test server.
u, _ := url.Parse(s.URL)
err := (&raft.HTTPTransport{}).Leave(u, 1)
if err == nil || err.Error() != `oh no` {
t.Fatalf("unexpected error: %s", err)
}
}
// Ensure a heartbeat over HTTP can be read and responded to. // Ensure a heartbeat over HTTP can be read and responded to.
func TestHTTPTransport_Heartbeat(t *testing.T) { func TestHTTPTransport_Heartbeat(t *testing.T) {
@ -55,7 +168,7 @@ func TestHTTPTransport_Heartbeat(t *testing.T) {
// Execute heartbeat against test server. // Execute heartbeat against test server.
u, _ := url.Parse(s.URL) u, _ := url.Parse(s.URL)
newIndex, newTerm, err := raft.DefaultTransport.Heartbeat(u, 1, 2, 3) newIndex, newTerm, err := (&raft.HTTPTransport{}).Heartbeat(u, 1, 2, 3)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %s", err) t.Fatalf("unexpected error: %s", err)
} else if newIndex != 4 { } else if newIndex != 4 {
@ -87,7 +200,7 @@ func TestHTTPTransport_Heartbeat_Err(t *testing.T) {
})) }))
u, _ := url.Parse(s.URL) u, _ := url.Parse(s.URL)
_, _, err := raft.DefaultTransport.Heartbeat(u, 1, 2, 3) _, _, err := (&raft.HTTPTransport{}).Heartbeat(u, 1, 2, 3)
if err == nil { if err == nil {
t.Errorf("%d. expected error", i) t.Errorf("%d. expected error", i)
} else if tt.err != err.Error() { } else if tt.err != err.Error() {
@ -100,7 +213,7 @@ func TestHTTPTransport_Heartbeat_Err(t *testing.T) {
// Ensure an HTTP heartbeat to a stopped server returns an error. // Ensure an HTTP heartbeat to a stopped server returns an error.
func TestHTTPTransport_Heartbeat_ErrConnectionRefused(t *testing.T) { func TestHTTPTransport_Heartbeat_ErrConnectionRefused(t *testing.T) {
u, _ := url.Parse("http://localhost:41932") u, _ := url.Parse("http://localhost:41932")
_, _, err := raft.DefaultTransport.Heartbeat(u, 0, 0, 0) _, _, err := (&raft.HTTPTransport{}).Heartbeat(u, 0, 0, 0)
if err == nil { if err == nil {
t.Fatal("expected error") t.Fatal("expected error")
} else if !strings.Contains(err.Error(), `connection refused`) { } else if !strings.Contains(err.Error(), `connection refused`) {
@ -130,7 +243,7 @@ func TestHTTPTransport_ReadFrom(t *testing.T) {
// Execute stream against test server. // Execute stream against test server.
u, _ := url.Parse(s.URL) u, _ := url.Parse(s.URL)
r, err := raft.DefaultTransport.ReadFrom(u, 1, 2, 3) r, err := (&raft.HTTPTransport{}).ReadFrom(u, 1, 2, 3)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %s", err) t.Fatalf("unexpected error: %s", err)
} }
@ -150,7 +263,7 @@ func TestHTTPTransport_ReadFrom_Err(t *testing.T) {
// Execute stream against test server. // Execute stream against test server.
u, _ := url.Parse(s.URL) u, _ := url.Parse(s.URL)
r, err := raft.DefaultTransport.ReadFrom(u, 0, 0, 0) r, err := (&raft.HTTPTransport{}).ReadFrom(u, 0, 0, 0)
if err == nil { if err == nil {
t.Fatalf("expected error") t.Fatalf("expected error")
} else if err.Error() != `bad stream` { } else if err.Error() != `bad stream` {
@ -163,7 +276,7 @@ func TestHTTPTransport_ReadFrom_Err(t *testing.T) {
// Ensure an streaming over HTTP to a stopped server returns an error. // Ensure an streaming over HTTP to a stopped server returns an error.
func TestHTTPTransport_ReadFrom_ErrConnectionRefused(t *testing.T) { func TestHTTPTransport_ReadFrom_ErrConnectionRefused(t *testing.T) {
u, _ := url.Parse("http://localhost:41932") u, _ := url.Parse("http://localhost:41932")
_, err := raft.DefaultTransport.ReadFrom(u, 0, 0, 0) _, err := (&raft.HTTPTransport{}).ReadFrom(u, 0, 0, 0)
if err == nil { if err == nil {
t.Fatal("expected error") t.Fatal("expected error")
} else if !strings.Contains(err.Error(), `connection refused`) { } else if !strings.Contains(err.Error(), `connection refused`) {
@ -197,7 +310,7 @@ func TestHTTPTransport_RequestVote(t *testing.T) {
// Execute heartbeat against test server. // Execute heartbeat against test server.
u, _ := url.Parse(s.URL) u, _ := url.Parse(s.URL)
newTerm, err := raft.DefaultTransport.RequestVote(u, 1, 2, 3, 4) newTerm, err := (&raft.HTTPTransport{}).RequestVote(u, 1, 2, 3, 4)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %s", err) t.Fatalf("unexpected error: %s", err)
} else if newTerm != 5 { } else if newTerm != 5 {
@ -214,7 +327,7 @@ func TestHTTPTransport_RequestVote_ErrInvalidTerm(t *testing.T) {
defer s.Close() defer s.Close()
u, _ := url.Parse(s.URL) u, _ := url.Parse(s.URL)
_, err := raft.DefaultTransport.RequestVote(u, 0, 0, 0, 0) _, err := (&raft.HTTPTransport{}).RequestVote(u, 0, 0, 0, 0)
if err == nil { if err == nil {
t.Errorf("expected error") t.Errorf("expected error")
} else if err.Error() != `invalid term: "xxx"` { } else if err.Error() != `invalid term: "xxx"` {
@ -232,7 +345,7 @@ func TestHTTPTransport_RequestVote_Error(t *testing.T) {
defer s.Close() defer s.Close()
u, _ := url.Parse(s.URL) u, _ := url.Parse(s.URL)
_, err := raft.DefaultTransport.RequestVote(u, 0, 0, 0, 0) _, err := (&raft.HTTPTransport{}).RequestVote(u, 0, 0, 0, 0)
if err == nil { if err == nil {
t.Errorf("expected error") t.Errorf("expected error")
} else if err.Error() != `already voted` { } else if err.Error() != `already voted` {
@ -243,14 +356,13 @@ func TestHTTPTransport_RequestVote_Error(t *testing.T) {
// Ensure that requesting a vote over HTTP to a stopped server returns an error. // Ensure that requesting a vote over HTTP to a stopped server returns an error.
func TestHTTPTransport_RequestVote_ErrConnectionRefused(t *testing.T) { func TestHTTPTransport_RequestVote_ErrConnectionRefused(t *testing.T) {
u, _ := url.Parse("http://localhost:41932") u, _ := url.Parse("http://localhost:41932")
_, err := raft.DefaultTransport.RequestVote(u, 0, 0, 0, 0) _, err := (&raft.HTTPTransport{}).RequestVote(u, 0, 0, 0, 0)
if err == nil { if err == nil {
t.Fatal("expected error") t.Fatal("expected error")
} else if !strings.Contains(err.Error(), `connection refused`) { } else if !strings.Contains(err.Error(), `connection refused`) {
t.Fatalf("unexpected error: %s", err) t.Fatalf("unexpected error: %s", err)
} }
} }
*/
// Transport represents a test transport that directly calls another log. // Transport represents a test transport that directly calls another log.
// Logs are looked up by hostname only. // Logs are looked up by hostname only.
@ -313,7 +425,7 @@ func (t *Transport) ReadFrom(u *url.URL, id, term, index uint64) (io.ReadCloser,
// Create a streaming buffer that will hang until Close() is called. // Create a streaming buffer that will hang until Close() is called.
buf := newStreamingBuffer() buf := newStreamingBuffer()
go func() { go func() {
if err := l.WriteEntriesTo(buf.buf, id, term, index); err != nil { if err := l.WriteEntriesTo(buf, id, term, index); err != nil {
warnf("Transport.ReadFrom: error: %s", err) warnf("Transport.ReadFrom: error: %s", err)
} }
_ = buf.Close() _ = buf.Close()
@ -360,7 +472,10 @@ func (b *streamingBuffer) Closed() bool {
func (b *streamingBuffer) Read(p []byte) (n int, err error) { func (b *streamingBuffer) Read(p []byte) (n int, err error) {
for { for {
b.mu.Lock()
n, err = b.buf.Read(p) n, err = b.buf.Read(p)
b.mu.Unlock()
if err == io.EOF && n > 0 { // hit EOF, read data if err == io.EOF && n > 0 { // hit EOF, read data
return n, nil return n, nil
} else if err == io.EOF { // hit EOF, no data } else if err == io.EOF { // hit EOF, no data
@ -379,6 +494,12 @@ func (b *streamingBuffer) Read(p []byte) (n int, err error) {
} }
} }
func (b *streamingBuffer) Write(p []byte) (n int, err error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.buf.Write(p)
}
// Ensure the streaming buffer will continue to stream data, if available, after it's closed. // Ensure the streaming buffer will continue to stream data, if available, after it's closed.
// This is primarily a santity check to make sure our test buffer isn't causing problems. // This is primarily a santity check to make sure our test buffer isn't causing problems.
func TestStreamingBuffer(t *testing.T) { func TestStreamingBuffer(t *testing.T) {