Merge remote-tracking branch 'raft/master'

pull/820/head
John Shahid 2014-08-12 14:21:17 -04:00
commit 57d29e8bed
10 changed files with 262 additions and 83 deletions

17
_vendor/raft/Makefile Normal file
View File

@ -0,0 +1,17 @@
COVERPROFILE=cover.out
default: test
cover:
go test -coverprofile=$(COVERPROFILE) .
go tool cover -html=$(COVERPROFILE)
rm $(COVERPROFILE)
dependencies:
go get -d .
test:
go test -i ./...
go test -v ./...
.PHONY: cover dependencies test

View File

@ -56,7 +56,9 @@ func newCommand(name string, data []byte) (Command, error) {
return nil, err return nil, err
} }
} else { } else {
json.NewDecoder(bytes.NewReader(data)).Decode(copy) if err := json.NewDecoder(bytes.NewReader(data)).Decode(copy); err != nil {
return nil, err
}
} }
} }

View File

@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"path" "path"
"time"
) )
// Parts from this transporter were heavily influenced by Peter Bougon's // Parts from this transporter were heavily influenced by Peter Bougon's
@ -42,7 +43,7 @@ type HTTPMuxer interface {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
// Creates a new HTTP transporter with the given path prefix. // Creates a new HTTP transporter with the given path prefix.
func NewHTTPTransporter(prefix string) *HTTPTransporter { func NewHTTPTransporter(prefix string, timeout time.Duration) *HTTPTransporter {
t := &HTTPTransporter{ t := &HTTPTransporter{
DisableKeepAlives: false, DisableKeepAlives: false,
prefix: prefix, prefix: prefix,
@ -53,6 +54,7 @@ func NewHTTPTransporter(prefix string) *HTTPTransporter {
Transport: &http.Transport{DisableKeepAlives: false}, Transport: &http.Transport{DisableKeepAlives: false},
} }
t.httpClient.Transport = t.Transport t.httpClient.Transport = t.Transport
t.Transport.ResponseHeaderTimeout = timeout
return t return t
} }
@ -120,7 +122,6 @@ func (t *HTTPTransporter) SendAppendEntriesRequest(server Server, peer *Peer, re
url := joinPath(peer.ConnectionString, t.AppendEntriesPath()) url := joinPath(peer.ConnectionString, t.AppendEntriesPath())
traceln(server.Name(), "POST", url) traceln(server.Name(), "POST", url)
t.Transport.ResponseHeaderTimeout = server.ElectionTimeout()
httpResp, err := t.httpClient.Post(url, "application/protobuf", &b) httpResp, err := t.httpClient.Post(url, "application/protobuf", &b)
if httpResp == nil || err != nil { if httpResp == nil || err != nil {
traceln("transporter.ae.response.error:", err) traceln("transporter.ae.response.error:", err)
@ -243,6 +244,10 @@ func (t *HTTPTransporter) appendEntriesHandler(server Server) http.HandlerFunc {
} }
resp := server.AppendEntries(req) resp := server.AppendEntries(req)
if resp == nil {
http.Error(w, "Failed creating response.", http.StatusInternalServerError)
return
}
if _, err := resp.Encode(w); err != nil { if _, err := resp.Encode(w); err != nil {
http.Error(w, "", http.StatusInternalServerError) http.Error(w, "", http.StatusInternalServerError)
return return
@ -262,6 +267,10 @@ func (t *HTTPTransporter) requestVoteHandler(server Server) http.HandlerFunc {
} }
resp := server.RequestVote(req) resp := server.RequestVote(req)
if resp == nil {
http.Error(w, "Failed creating response.", http.StatusInternalServerError)
return
}
if _, err := resp.Encode(w); err != nil { if _, err := resp.Encode(w); err != nil {
http.Error(w, "", http.StatusInternalServerError) http.Error(w, "", http.StatusInternalServerError)
return return
@ -281,6 +290,10 @@ func (t *HTTPTransporter) snapshotHandler(server Server) http.HandlerFunc {
} }
resp := server.RequestSnapshot(req) resp := server.RequestSnapshot(req)
if resp == nil {
http.Error(w, "Failed creating response.", http.StatusInternalServerError)
return
}
if _, err := resp.Encode(w); err != nil { if _, err := resp.Encode(w); err != nil {
http.Error(w, "", http.StatusInternalServerError) http.Error(w, "", http.StatusInternalServerError)
return return
@ -300,6 +313,10 @@ func (t *HTTPTransporter) snapshotRecoveryHandler(server Server) http.HandlerFun
} }
resp := server.SnapshotRecoveryRequest(req) resp := server.SnapshotRecoveryRequest(req)
if resp == nil {
http.Error(w, "Failed creating response.", http.StatusInternalServerError)
return
}
if _, err := resp.Encode(w); err != nil { if _, err := resp.Encode(w); err != nil {
http.Error(w, "", http.StatusInternalServerError) http.Error(w, "", http.StatusInternalServerError)
return return

View File

@ -11,7 +11,7 @@ import (
// Ensure that we can start several servers and have them communicate. // Ensure that we can start several servers and have them communicate.
func TestHTTPTransporter(t *testing.T) { func TestHTTPTransporter(t *testing.T) {
transporter := NewHTTPTransporter("/raft") transporter := NewHTTPTransporter("/raft", testElectionTimeout)
transporter.DisableKeepAlives = true transporter.DisableKeepAlives = true
servers := []Server{} servers := []Server{}
@ -91,7 +91,7 @@ func runTestHttpServers(t *testing.T, servers *[]Server, transporter *HTTPTransp
func BenchmarkSpeed(b *testing.B) { func BenchmarkSpeed(b *testing.B) {
transporter := NewHTTPTransporter("/raft") transporter := NewHTTPTransporter("/raft", testElectionTimeout)
transporter.DisableKeepAlives = true transporter.DisableKeepAlives = true
servers := []Server{} servers := []Server{}

View File

@ -27,6 +27,7 @@ type Log struct {
mutex sync.RWMutex mutex sync.RWMutex
startIndex uint64 // the index before the first entry in the Log entries startIndex uint64 // the index before the first entry in the Log entries
startTerm uint64 startTerm uint64
initialized bool
} }
// The results of the applying a log entry. // The results of the applying a log entry.
@ -147,7 +148,9 @@ func (l *Log) open(path string) error {
if os.IsNotExist(err) { if os.IsNotExist(err) {
l.file, err = os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0600) l.file, err = os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0600)
debugln("log.open.create ", path) debugln("log.open.create ", path)
if err == nil {
l.initialized = true
}
return err return err
} }
return err return err
@ -187,6 +190,7 @@ func (l *Log) open(path string) error {
readBytes += int64(n) readBytes += int64(n)
} }
debugln("open.log.recovery number of log ", len(l.entries)) debugln("open.log.recovery number of log ", len(l.entries))
l.initialized = true
return nil return nil
} }
@ -375,6 +379,15 @@ func (l *Log) setCommitIndex(index uint64) error {
entry.event.returnValue = returnValue entry.event.returnValue = returnValue
entry.event.c <- err entry.event.c <- err
} }
_, isJoinCommand := command.(JoinCommand)
// we can only commit up to the most recent join command
// if there is a join in this batch of commands.
// after this commit, we need to recalculate the majority.
if isJoinCommand {
return nil
}
} }
return nil return nil
} }

View File

@ -29,7 +29,9 @@ func newLogEntry(log *Log, event *ev, index uint64, term uint64, command Command
return nil, err return nil, err
} }
} else { } else {
json.NewEncoder(&buf).Encode(command) if err := json.NewEncoder(&buf).Encode(command); err != nil {
return nil, err
}
} }
} }

View File

@ -17,10 +17,10 @@ type Peer struct {
Name string `json:"name"` Name string `json:"name"`
ConnectionString string `json:"connectionString"` ConnectionString string `json:"connectionString"`
prevLogIndex uint64 prevLogIndex uint64
mutex sync.RWMutex
stopChan chan bool stopChan chan bool
heartbeatInterval time.Duration heartbeatInterval time.Duration
lastActivity time.Time lastActivity time.Time
sync.RWMutex
} }
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
@ -56,18 +56,24 @@ func (p *Peer) setHeartbeatInterval(duration time.Duration) {
// Retrieves the previous log index. // Retrieves the previous log index.
func (p *Peer) getPrevLogIndex() uint64 { func (p *Peer) getPrevLogIndex() uint64 {
p.mutex.RLock() p.RLock()
defer p.mutex.RUnlock() defer p.RUnlock()
return p.prevLogIndex return p.prevLogIndex
} }
// Sets the previous log index. // Sets the previous log index.
func (p *Peer) setPrevLogIndex(value uint64) { func (p *Peer) setPrevLogIndex(value uint64) {
p.mutex.Lock() p.Lock()
defer p.mutex.Unlock() defer p.Unlock()
p.prevLogIndex = value p.prevLogIndex = value
} }
func (p *Peer) setLastActivity(now time.Time) {
p.Lock()
defer p.Unlock()
p.lastActivity = now
}
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
// //
// Methods // Methods
@ -82,17 +88,28 @@ func (p *Peer) setPrevLogIndex(value uint64) {
func (p *Peer) startHeartbeat() { func (p *Peer) startHeartbeat() {
p.stopChan = make(chan bool) p.stopChan = make(chan bool)
c := make(chan bool) c := make(chan bool)
go p.heartbeat(c)
p.setLastActivity(time.Now())
p.server.routineGroup.Add(1)
go func() {
defer p.server.routineGroup.Done()
p.heartbeat(c)
}()
<-c <-c
} }
// Stops the peer heartbeat. // Stops the peer heartbeat.
func (p *Peer) stopHeartbeat(flush bool) { func (p *Peer) stopHeartbeat(flush bool) {
p.setLastActivity(time.Time{})
p.stopChan <- flush p.stopChan <- flush
} }
// LastActivity returns the last time any response was received from the peer. // LastActivity returns the last time any response was received from the peer.
func (p *Peer) LastActivity() time.Time { func (p *Peer) LastActivity() time.Time {
p.RLock()
defer p.RUnlock()
return p.lastActivity return p.lastActivity
} }
@ -103,8 +120,8 @@ func (p *Peer) LastActivity() time.Time {
// Clones the state of the peer. The clone is not attached to a server and // Clones the state of the peer. The clone is not attached to a server and
// the heartbeat timer will not exist. // the heartbeat timer will not exist.
func (p *Peer) clone() *Peer { func (p *Peer) clone() *Peer {
p.mutex.Lock() p.Lock()
defer p.mutex.Unlock() defer p.Unlock()
return &Peer{ return &Peer{
Name: p.Name, Name: p.Name,
ConnectionString: p.ConnectionString, ConnectionString: p.ConnectionString,
@ -181,9 +198,9 @@ func (p *Peer) sendAppendEntriesRequest(req *AppendEntriesRequest) {
} }
traceln("peer.append.resp: ", p.server.Name(), "<-", p.Name) traceln("peer.append.resp: ", p.server.Name(), "<-", p.Name)
p.setLastActivity(time.Now())
// If successful then update the previous log index. // If successful then update the previous log index.
p.mutex.Lock() p.Lock()
p.lastActivity = time.Now()
if resp.Success() { if resp.Success() {
if len(req.Entries) > 0 { if len(req.Entries) > 0 {
p.prevLogIndex = req.Entries[len(req.Entries)-1].GetIndex() p.prevLogIndex = req.Entries[len(req.Entries)-1].GetIndex()
@ -229,7 +246,7 @@ func (p *Peer) sendAppendEntriesRequest(req *AppendEntriesRequest) {
debugln("peer.append.resp.decrement: ", p.Name, "; idx =", p.prevLogIndex) debugln("peer.append.resp.decrement: ", p.Name, "; idx =", p.prevLogIndex)
} }
} }
p.mutex.Unlock() p.Unlock()
// Attach the peer to resp, thus server can know where it comes from // Attach the peer to resp, thus server can know where it comes from
resp.peer = p.Name resp.peer = p.Name
@ -251,7 +268,8 @@ func (p *Peer) sendSnapshotRequest(req *SnapshotRequest) {
// If successful, the peer should have been to snapshot state // If successful, the peer should have been to snapshot state
// Send it the snapshot! // Send it the snapshot!
p.lastActivity = time.Now() p.setLastActivity(time.Now())
if resp.Success { if resp.Success {
p.sendSnapshotRecoveryRequest() p.sendSnapshotRecoveryRequest()
} else { } else {
@ -272,7 +290,7 @@ func (p *Peer) sendSnapshotRecoveryRequest() {
return return
} }
p.lastActivity = time.Now() p.setLastActivity(time.Now())
if resp.Success { if resp.Success {
p.prevLogIndex = req.LastIndex p.prevLogIndex = req.LastIndex
} else { } else {
@ -293,7 +311,7 @@ func (p *Peer) sendVoteRequest(req *RequestVoteRequest, c chan *RequestVoteRespo
req.peer = p req.peer = p
if resp := p.server.Transporter().SendVoteRequest(p.server, p, req); resp != nil { if resp := p.server.Transporter().SendVoteRequest(p.server, p, req); resp != nil {
debugln("peer.vote.recv: ", p.server.Name(), "<-", p.Name) debugln("peer.vote.recv: ", p.server.Name(), "<-", p.Name)
p.lastActivity = time.Now() p.setLastActivity(time.Now())
resp.peer = p resp.peer = p
c <- resp c <- resp
} else { } else {

View File

@ -21,6 +21,7 @@ import (
const ( const (
Stopped = "stopped" Stopped = "stopped"
Initialized = "initialized"
Follower = "follower" Follower = "follower"
Candidate = "candidate" Candidate = "candidate"
Leader = "leader" Leader = "leader"
@ -54,6 +55,7 @@ const ElectionTimeoutThresholdPercent = 0.8
var NotLeaderError = errors.New("raft.Server: Not current leader") var NotLeaderError = errors.New("raft.Server: Not current leader")
var DuplicatePeerError = errors.New("raft.Server: Duplicate peer") var DuplicatePeerError = errors.New("raft.Server: Duplicate peer")
var CommandTimeoutError = errors.New("raft: Command timeout") var CommandTimeoutError = errors.New("raft: Command timeout")
var StopError = errors.New("raft: Has been stopped")
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
// //
@ -94,6 +96,7 @@ type Server interface {
AddPeer(name string, connectiongString string) error AddPeer(name string, connectiongString string) error
RemovePeer(name string) error RemovePeer(name string) error
Peers() map[string]*Peer Peers() map[string]*Peer
Init() error
Start() error Start() error
Stop() Stop()
Running() bool Running() bool
@ -121,7 +124,7 @@ type server struct {
mutex sync.RWMutex mutex sync.RWMutex
syncedPeer map[string]bool syncedPeer map[string]bool
stopped chan chan bool stopped chan bool
c chan *ev c chan *ev
electionTimeout time.Duration electionTimeout time.Duration
heartbeatInterval time.Duration heartbeatInterval time.Duration
@ -138,6 +141,8 @@ type server struct {
maxLogEntriesPerRequest uint64 maxLogEntriesPerRequest uint64
connectionString string connectionString string
routineGroup sync.WaitGroup
} }
// An internal event to be processed by the server's event loop. // An internal event to be processed by the server's event loop.
@ -175,7 +180,6 @@ func NewServer(name string, path string, transporter Transporter, stateMachine S
state: Stopped, state: Stopped,
peers: make(map[string]*Peer), peers: make(map[string]*Peer),
log: newLog(), log: newLog(),
stopped: make(chan chan bool),
c: make(chan *ev, 256), c: make(chan *ev, 256),
electionTimeout: DefaultElectionTimeout, electionTimeout: DefaultElectionTimeout,
heartbeatInterval: DefaultHeartbeatInterval, heartbeatInterval: DefaultHeartbeatInterval,
@ -330,6 +334,8 @@ func (s *server) IsLogEmpty() bool {
// A list of all the log entries. This should only be used for debugging purposes. // A list of all the log entries. This should only be used for debugging purposes.
func (s *server) LogEntries() []*LogEntry { func (s *server) LogEntries() []*LogEntry {
s.log.mutex.RLock()
defer s.log.mutex.RUnlock()
return s.log.entries return s.log.entries
} }
@ -356,8 +362,8 @@ func (s *server) promotable() bool {
// Retrieves the number of member servers in the consensus. // Retrieves the number of member servers in the consensus.
func (s *server) MemberCount() int { func (s *server) MemberCount() int {
s.mutex.Lock() s.mutex.RLock()
defer s.mutex.Unlock() defer s.mutex.RUnlock()
return len(s.peers) + 1 return len(s.peers) + 1
} }
@ -423,35 +429,24 @@ func init() {
RegisterCommand(&DefaultLeaveCommand{}) RegisterCommand(&DefaultLeaveCommand{})
} }
// Start as follow // Start the raft server
// If log entries exist then allow promotion to candidate if no AEs received. // If log entries exist then allow promotion to candidate if no AEs received.
// If no log entries exist then wait for AEs from another node. // If no log entries exist then wait for AEs from another node.
// If no log entries exist and a self-join command is issued then // If no log entries exist and a self-join command is issued then
// immediately become leader and commit entry. // immediately become leader and commit entry.
func (s *server) Start() error { func (s *server) Start() error {
// Exit if the server is already running. // Exit if the server is already running.
if s.State() != Stopped { if s.Running() {
return errors.New("raft.Server: Server already running") return fmt.Errorf("raft.Server: Server already running[%v]", s.state)
} }
// Create snapshot directory if not exist if err := s.Init(); err != nil {
os.Mkdir(path.Join(s.path, "snapshot"), 0700) return err
if err := s.readConf(); err != nil {
s.debugln("raft: Conf file error: ", err)
return fmt.Errorf("raft: Initialization error: %s", err)
} }
// Initialize the log and load it up. // stopped needs to be allocated each time server starts
if err := s.log.open(s.LogPath()); err != nil { // because it is closed at `Stop`.
s.debugln("raft: Log error: ", err) s.stopped = make(chan bool)
return fmt.Errorf("raft: Initialization error: %s", err)
}
// Update the term to the last term in the log.
_, s.currentTerm = s.log.lastInfo()
s.setState(Follower) s.setState(Follower)
// If no log entries exist then // If no log entries exist then
@ -469,27 +464,76 @@ func (s *server) Start() error {
debugln(s.GetState()) debugln(s.GetState())
go s.loop() s.routineGroup.Add(1)
go func() {
defer s.routineGroup.Done()
s.loop()
}()
return nil return nil
} }
// Init initializes the raft server.
// If there is no previous log file under the given path, Init() will create an empty log file.
// Otherwise, Init() will load in the log entries from the log file.
func (s *server) Init() error {
if s.Running() {
return fmt.Errorf("raft.Server: Server already running[%v]", s.state)
}
// Server has been initialized or server was stopped after initialized
// If log has been initialized, we know that the server was stopped after
// running.
if s.state == Initialized || s.log.initialized {
s.state = Initialized
return nil
}
// Create snapshot directory if it does not exist
err := os.Mkdir(path.Join(s.path, "snapshot"), 0700)
if err != nil && !os.IsExist(err) {
s.debugln("raft: Snapshot dir error: ", err)
return fmt.Errorf("raft: Initialization error: %s", err)
}
if err := s.readConf(); err != nil {
s.debugln("raft: Conf file error: ", err)
return fmt.Errorf("raft: Initialization error: %s", err)
}
// Initialize the log and load it up.
if err := s.log.open(s.LogPath()); err != nil {
s.debugln("raft: Log error: ", err)
return fmt.Errorf("raft: Initialization error: %s", err)
}
// Update the term to the last term in the log.
_, s.currentTerm = s.log.lastInfo()
s.state = Initialized
return nil
}
// Shuts down the server. // Shuts down the server.
func (s *server) Stop() { func (s *server) Stop() {
stop := make(chan bool) if s.State() == Stopped {
s.stopped <- stop return
s.state = Stopped }
close(s.stopped)
// make sure all goroutines have stopped before we close the log
s.routineGroup.Wait()
// make sure the server has stopped before we close the log
<-stop
s.log.close() s.log.close()
s.setState(Stopped)
} }
// Checks if the server is currently running. // Checks if the server is currently running.
func (s *server) Running() bool { func (s *server) Running() bool {
s.mutex.RLock() s.mutex.RLock()
defer s.mutex.RUnlock() defer s.mutex.RUnlock()
return s.state != Stopped return (s.state != Stopped && s.state != Initialized)
} }
//-------------------------------------- //--------------------------------------
@ -502,8 +546,6 @@ func (s *server) updateCurrentTerm(term uint64, leaderName string) {
_assert(term > s.currentTerm, _assert(term > s.currentTerm,
"upadteCurrentTerm: update is called when term is not larger than currentTerm") "upadteCurrentTerm: update is called when term is not larger than currentTerm")
s.mutex.Lock()
defer s.mutex.Unlock()
// Store previous values temporarily. // Store previous values temporarily.
prevTerm := s.currentTerm prevTerm := s.currentTerm
prevLeader := s.leader prevLeader := s.leader
@ -511,21 +553,20 @@ func (s *server) updateCurrentTerm(term uint64, leaderName string) {
// set currentTerm = T, convert to follower (§5.1) // set currentTerm = T, convert to follower (§5.1)
// stop heartbeats before step-down // stop heartbeats before step-down
if s.state == Leader { if s.state == Leader {
s.mutex.Unlock()
for _, peer := range s.peers { for _, peer := range s.peers {
peer.stopHeartbeat(false) peer.stopHeartbeat(false)
} }
s.mutex.Lock()
} }
// update the term and clear vote for // update the term and clear vote for
if s.state != Follower { if s.state != Follower {
s.mutex.Unlock()
s.setState(Follower) s.setState(Follower)
s.mutex.Lock()
} }
s.mutex.Lock()
s.currentTerm = term s.currentTerm = term
s.leader = leaderName s.leader = leaderName
s.votedFor = "" s.votedFor = ""
s.mutex.Unlock()
// Dispatch change events. // Dispatch change events.
s.DispatchEvent(newEvent(TermChangeEventType, s.currentTerm, prevTerm)) s.DispatchEvent(newEvent(TermChangeEventType, s.currentTerm, prevTerm))
@ -555,9 +596,9 @@ func (s *server) updateCurrentTerm(term uint64, leaderName string) {
func (s *server) loop() { func (s *server) loop() {
defer s.debugln("server.loop.end") defer s.debugln("server.loop.end")
for s.state != Stopped { state := s.State()
state := s.State()
for state != Stopped {
s.debugln("server.loop.run ", state) s.debugln("server.loop.run ", state)
switch state { switch state {
case Follower: case Follower:
@ -569,19 +610,36 @@ func (s *server) loop() {
case Snapshotting: case Snapshotting:
s.snapshotLoop() s.snapshotLoop()
} }
state = s.State()
} }
} }
// Sends an event to the event loop to be processed. The function will wait // Sends an event to the event loop to be processed. The function will wait
// until the event is actually processed before returning. // until the event is actually processed before returning.
func (s *server) send(value interface{}) (interface{}, error) { func (s *server) send(value interface{}) (interface{}, error) {
if !s.Running() {
return nil, StopError
}
event := &ev{target: value, c: make(chan error, 1)} event := &ev{target: value, c: make(chan error, 1)}
s.c <- event select {
err := <-event.c case s.c <- event:
return event.returnValue, err case <-s.stopped:
return nil, StopError
}
select {
case <-s.stopped:
return nil, StopError
case err := <-event.c:
return event.returnValue, err
}
} }
func (s *server) sendAsync(value interface{}) { func (s *server) sendAsync(value interface{}) {
if !s.Running() {
return
}
event := &ev{target: value, c: make(chan error, 1)} event := &ev{target: value, c: make(chan error, 1)}
// try a non-blocking send first // try a non-blocking send first
// in most cases, this should not be blocking // in most cases, this should not be blocking
@ -592,8 +650,13 @@ func (s *server) sendAsync(value interface{}) {
default: default:
} }
s.routineGroup.Add(1)
go func() { go func() {
s.c <- event defer s.routineGroup.Done()
select {
case s.c <- event:
case <-s.stopped:
}
}() }()
} }
@ -611,9 +674,8 @@ func (s *server) followerLoop() {
var err error var err error
update := false update := false
select { select {
case stop := <-s.stopped: case <-s.stopped:
s.setState(Stopped) s.setState(Stopped)
stop <- true
return return
case e := <-s.c: case e := <-s.c:
@ -688,7 +750,11 @@ func (s *server) candidateLoop() {
// Send RequestVote RPCs to all other servers. // Send RequestVote RPCs to all other servers.
respChan = make(chan *RequestVoteResponse, len(s.peers)) respChan = make(chan *RequestVoteResponse, len(s.peers))
for _, peer := range s.peers { for _, peer := range s.peers {
go peer.sendVoteRequest(newRequestVoteRequest(s.currentTerm, s.name, lastLogIndex, lastLogTerm), respChan) s.routineGroup.Add(1)
go func(peer *Peer) {
defer s.routineGroup.Done()
peer.sendVoteRequest(newRequestVoteRequest(s.currentTerm, s.name, lastLogIndex, lastLogTerm), respChan)
}(peer)
} }
// Wait for either: // Wait for either:
@ -711,9 +777,8 @@ func (s *server) candidateLoop() {
// Collect votes from peers. // Collect votes from peers.
select { select {
case stop := <-s.stopped: case <-s.stopped:
s.setState(Stopped) s.setState(Stopped)
stop <- true
return return
case resp := <-respChan: case resp := <-respChan:
@ -757,19 +822,22 @@ func (s *server) leaderLoop() {
// "Upon election: send initial empty AppendEntries RPCs (heartbeat) to // "Upon election: send initial empty AppendEntries RPCs (heartbeat) to
// each server; repeat during idle periods to prevent election timeouts // each server; repeat during idle periods to prevent election timeouts
// (§5.2)". The heartbeats started above do the "idle" period work. // (§5.2)". The heartbeats started above do the "idle" period work.
go s.Do(NOPCommand{}) s.routineGroup.Add(1)
go func() {
defer s.routineGroup.Done()
s.Do(NOPCommand{})
}()
// Begin to collect response from followers // Begin to collect response from followers
for s.State() == Leader { for s.State() == Leader {
var err error var err error
select { select {
case stop := <-s.stopped: case <-s.stopped:
// Stop all peers before stop // Stop all peers before stop
for _, peer := range s.peers { for _, peer := range s.peers {
peer.stopHeartbeat(false) peer.stopHeartbeat(false)
} }
s.setState(Stopped) s.setState(Stopped)
stop <- true
return return
case e := <-s.c: case e := <-s.c:
@ -797,9 +865,8 @@ func (s *server) snapshotLoop() {
for s.State() == Snapshotting { for s.State() == Snapshotting {
var err error var err error
select { select {
case stop := <-s.stopped: case <-s.stopped:
s.setState(Stopped) s.setState(Stopped)
stop <- true
return return
case e := <-s.c: case e := <-s.c:
@ -878,9 +945,14 @@ func (s *server) processAppendEntriesRequest(req *AppendEntriesRequest) (*Append
} }
if req.Term == s.currentTerm { if req.Term == s.currentTerm {
_assert(s.state != Leader, "leader.elected.at.same.term.%d\n", s.currentTerm) _assert(s.State() != Leader, "leader.elected.at.same.term.%d\n", s.currentTerm)
// change state to follower
s.state = Follower // step-down to follower when it is a candidate
if s.state == Candidate {
// change state to follower
s.setState(Follower)
}
// discover new leader when candidate // discover new leader when candidate
// save leader name when follower // save leader name when follower
s.leader = req.LeaderName s.leader = req.LeaderName
@ -1080,7 +1152,11 @@ func (s *server) RemovePeer(name string) error {
// So we might be holding log lock and waiting for log lock, // So we might be holding log lock and waiting for log lock,
// which lead to a deadlock. // which lead to a deadlock.
// TODO(xiangli) refactor log lock // TODO(xiangli) refactor log lock
go peer.stopHeartbeat(true) s.routineGroup.Add(1)
go func() {
defer s.routineGroup.Done()
peer.stopHeartbeat(true)
}()
} }
delete(s.peers, name) delete(s.peers, name)

View File

@ -2,6 +2,7 @@ package raft
import ( import (
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
@ -26,11 +27,43 @@ func TestSnapshot(t *testing.T) {
// Restart server. // Restart server.
s.Stop() s.Stop()
s.Start()
// Recover from snapshot. // Recover from snapshot.
err = s.LoadSnapshot() err = s.LoadSnapshot()
assert.NoError(t, err) assert.NoError(t, err)
s.Start()
})
}
// Ensure that a new server can recover from previous snapshot with log
func TestSnapshotRecovery(t *testing.T) {
runServerWithMockStateMachine(Leader, func(s Server, m *mock.Mock) {
m.On("Save").Return([]byte("foo"), nil)
m.On("Recovery", []byte("foo")).Return(nil)
s.Do(&testCommand1{})
err := s.TakeSnapshot()
assert.NoError(t, err)
assert.Equal(t, s.(*server).snapshot.LastIndex, uint64(2))
// Repeat to make sure new snapshot gets created.
s.Do(&testCommand1{})
// Stop the old server
s.Stop()
// create a new server with previous log and snapshot
newS, err := NewServer("1", s.Path(), &testTransporter{}, s.StateMachine(), nil, "")
// Recover from snapshot.
err = newS.LoadSnapshot()
assert.NoError(t, err)
newS.Start()
defer newS.Stop()
// wait for it to become leader
time.Sleep(time.Second)
// ensure server load the previous log
assert.Equal(t, len(newS.LogEntries()), 3, "")
}) })
} }

View File

@ -25,15 +25,16 @@ func writeFileSynced(filename string, data []byte, perm os.FileMode) error {
if err != nil { if err != nil {
return err return err
} }
defer f.Close() // Idempotent
n, err := f.Write(data) n, err := f.Write(data)
if n < len(data) { if err == nil && n < len(data) {
f.Close()
return io.ErrShortWrite return io.ErrShortWrite
} else if err != nil {
return err
} }
err = f.Sync() if err = f.Sync(); err != nil {
if err != nil {
return err return err
} }