Merge branch 'master' of https://github.com/xiangli-cmu/go-raft into xiangli-cmu-master

Conflicts:
    log.go

commit 3bcf91a39f
@@ -8,7 +8,6 @@ package raft

 // The request sent to a server to append entries to the log.
 type AppendEntriesRequest struct {
-    peer *Peer
     Term         uint64 `json:"term"`
     LeaderName   string `json:"leaderName"`
     PrevLogIndex uint64 `json:"prevLogIndex"`

@@ -19,7 +18,6 @@ type AppendEntriesRequest struct {

 // The response returned from a server appending entries to the log.
 type AppendEntriesResponse struct {
-    peer *Peer
     Term        uint64 `json:"term"`
     Success     bool   `json:"success"`
     CommitIndex uint64 `json:"commitIndex"`

log.go (128 changes)

@@ -19,10 +19,13 @@ import (

 type Log struct {
     ApplyFunc   func(Command) error
     file        *os.File
     path        string
     entries     []*LogEntry
     errors      []error
     commitIndex uint64
     mutex       sync.Mutex
+    startIndex  uint64 // the index before the first entry in the Log entries
+    startTerm   uint64
 }

 //------------------------------------------------------------------------------

@@ -42,6 +45,17 @@ func NewLog() *Log {
 //
 //------------------------------------------------------------------------------

+func (l *Log) SetStartIndex(i uint64) {
+    l.startIndex = i
+}
+
+func (l *Log) StartIndex() uint64 {
+    return l.startIndex
+}
+
+func (l *Log) SetStartTerm(t uint64) {
+    l.startTerm = t
+}
 //--------------------------------------
 // Log Indices
 //--------------------------------------

@@ -52,7 +66,15 @@ func (l *Log) CurrentIndex() uint64 {
     defer l.mutex.Unlock()

     if len(l.entries) == 0 {
-        return 0
+        return l.startIndex
     }
     return l.entries[len(l.entries)-1].Index
 }

+// The current index in the log without locking
+func (l *Log) internalCurrentIndex() uint64 {
+    if len(l.entries) == 0 {
+        return l.startIndex
+    }
+    return l.entries[len(l.entries)-1].Index
+}

@@ -96,7 +118,7 @@ func (l *Log) CurrentTerm() uint64 {
     defer l.mutex.Unlock()

     if len(l.entries) == 0 {
-        return 0
+        return l.startTerm
     }
     return l.entries[len(l.entries)-1].Term
 }

@@ -165,7 +187,7 @@ func (l *Log) Open(path string) error {
    if err != nil {
        return err
    }

    l.path = path
    return nil
}

@@ -196,7 +218,7 @@ func (l *Log) ContainsEntry(index uint64, term uint64) bool {
     l.mutex.Lock()
     defer l.mutex.Unlock()

-    if index == 0 || index > uint64(len(l.entries)) {
+    if index <= l.startIndex || index > (l.startIndex + uint64(len(l.entries))) {
         return false
     }
     return (l.entries[index-1].Term == term)

@@ -206,18 +228,17 @@ func (l *Log) ContainsEntry(index uint64, term uint64) bool {
 // the term of the index provided.
 func (l *Log) GetEntriesAfter(index uint64) ([]*LogEntry, uint64) {
     // Return an error if the index doesn't exist.
-    if index > uint64(len(l.entries)) {
+    if index > (uint64(len(l.entries)) + l.startIndex) {
         panic(fmt.Sprintf("raft.Log: Index is beyond end of log: %v", index))
     }

     // If we're going from the beginning of the log then return the whole log.
-    if index == 0 {
-        return l.entries, 0
+    if index == l.startIndex {
+        return l.entries, l.startTerm
     }

     // Determine the term at the given entry and return a subslice.
-    term := l.entries[index-1].Term
-    return l.entries[index:], term
+    term := l.entries[index-1-l.startIndex].Term
+    return l.entries[index-l.startIndex:], term
 }

 // Retrieves the error returned from an entry. The error can only exist after

@@ -250,11 +271,24 @@ func (l *Log) CommitInfo() (index uint64, term uint64) {
         return 0, 0
     }

+    // no new commit log after snapshot
+    if l.commitIndex == l.startIndex {
+        return l.startIndex, l.startTerm
+    }
+
     // Return the last index & term from the last committed entry.
-    lastCommitEntry := l.entries[l.commitIndex-1]
+    lastCommitEntry := l.entries[l.commitIndex-1-l.startIndex]
     return lastCommitEntry.Index, lastCommitEntry.Term
 }

+// Updates the commit index
+func (l *Log) UpdateCommitIndex(index uint64) {
+    l.mutex.Lock()
+    defer l.mutex.Unlock()
+    l.commitIndex = index
+}
+
 // Updates the commit index and writes entries after that index to the stable storage.
 func (l *Log) SetCommitIndex(index uint64) error {
     l.mutex.Lock()

@@ -264,13 +298,13 @@ func (l *Log) SetCommitIndex(index uint64) error {
     if index < l.commitIndex {
         return fmt.Errorf("raft.Log: Commit index (%d) ahead of requested commit index (%d)", l.commitIndex, index)
     }
-    if index > uint64(len(l.entries)) {
+    if index > l.startIndex+uint64(len(l.entries)) {
         return fmt.Errorf("raft.Log: Commit index (%d) out of range (%d)", index, len(l.entries))
     }

     // Find all entries whose index is between the previous index and the current index.
     for i := l.commitIndex + 1; i <= index; i++ {
-        entryIndex := i - 1
+        entryIndex := i - 1 - l.startIndex
         entry := l.entries[entryIndex]

         // Write to storage.

@@ -304,23 +338,23 @@ func (l *Log) Truncate(index uint64, term uint64) error {
     }

     // Do not truncate past end of entries.
-    if index > uint64(len(l.entries)) {
+    if index > l.startIndex+uint64(len(l.entries)) {
         return fmt.Errorf("raft.Log: Entry index does not exist (MAX=%v): (IDX=%v, TERM=%v)", len(l.entries), index, term)
     }

     // If we're truncating everything then just clear the entries.
-    if index == 0 {
+    if index == l.startIndex {
         l.entries = []*LogEntry{}
     } else {
         // Do not truncate if the entry at index does not have the matching term.
-        entry := l.entries[index-1]
+        entry := l.entries[index-l.startIndex-1]
         if len(l.entries) > 0 && entry.Term != term {
             return fmt.Errorf("raft.Log: Entry at index does not have matching term (%v): (IDX=%v, TERM=%v)", entry.Term, index, term)
         }

         // Otherwise truncate up to the desired entry.
-        if index < uint64(len(l.entries)) {
-            l.entries = l.entries[0:index]
+        if index < l.startIndex+uint64(len(l.entries)) {
+            l.entries = l.entries[0 : index-l.startIndex]
         }
     }

@@ -378,3 +412,61 @@ func (l *Log) appendEntry(entry *LogEntry) error {

     return nil
 }

+//--------------------------------------
+// Log compaction
+//--------------------------------------
+
+// Compacts the log up to the given index.
+func (l *Log) Compact(index uint64, term uint64) error {
+    var entries []*LogEntry
+
+    l.mutex.Lock()
+    defer l.mutex.Unlock()
+
+    // Nothing to compact. The index may be greater than the current index
+    // if we have just recovered from a snapshot.
+    if index >= l.internalCurrentIndex() {
+        entries = make([]*LogEntry, 0)
+    } else {
+        // Keep all log entries after the given index.
+        entries = l.entries[index-l.startIndex:]
+    }
+
+    // Create a new log file and write the retained entries to it.
+    file, err := os.OpenFile(l.path+".new", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
+    if err != nil {
+        return err
+    }
+    for _, entry := range entries {
+        err = entry.Encode(file)
+        if err != nil {
+            return err
+        }
+    }
+    // Close the current log file.
+    l.file.Close()
+
+    // Remove the current log file.
+    err = os.Remove(l.path)
+    if err != nil {
+        return err
+    }
+
+    // Rename the new log file into place.
+    err = os.Rename(l.path+".new", l.path)
+    if err != nil {
+        return err
+    }
+    l.file = file
+
+    // Compact the in-memory log.
+    l.entries = entries
+    l.startIndex = index
+    l.startTerm = term
+    return nil
+}

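Note (illustration only, not part of this patch): every slice access changed in log.go above follows the same translation between Raft log indices and positions in the in-memory entries slice once compaction has dropped the prefix: position = index - 1 - startIndex, with indices at or below startIndex living only in the snapshot. A minimal, self-contained sketch of that arithmetic, using a hypothetical strippedLog type and entryAt helper (names are not from the patch):

package main

import "fmt"

// strippedLog mirrors only the fields of raft.Log that matter for the
// index arithmetic introduced by this change.
type strippedLog struct {
    startIndex uint64   // index of the last entry covered by the snapshot
    entries    []uint64 // stand-ins for *LogEntry; value = that entry's term
}

// entryAt returns the term stored at a Raft log index, or false if the
// index falls inside the compacted prefix or beyond the end of the log.
func (l *strippedLog) entryAt(index uint64) (uint64, bool) {
    if index <= l.startIndex || index > l.startIndex+uint64(len(l.entries)) {
        return 0, false
    }
    return l.entries[index-1-l.startIndex], true
}

func main() {
    // After compacting through index 4, entries holds Raft indices 5..7.
    l := &strippedLog{startIndex: 4, entries: []uint64{1, 1, 2}}
    for _, i := range []uint64{3, 5, 7, 8} {
        term, ok := l.entryAt(i)
        fmt.Printf("index %d -> term %d, ok=%v\n", i, term, ok)
    }
}
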
peer.go (85 changes)

@@ -99,13 +99,71 @@ func (p *Peer) stop() {
 // Flush
 //--------------------------------------

-// Sends an AppendEntries RPC but does not obtain a lock on the server. This
-// method should only be called from the server.
-func (p *Peer) internalFlush() (uint64, bool, error) {
+// If internal is set true, sends an AppendEntries RPC but does not obtain a lock
+// on the server.
+func (p *Peer) flush(internal bool) (uint64, bool, error) {
+    // Retrieve the peer data within a lock that is separate from the
+    // server lock when creating the request. Otherwise a deadlock can
+    // occur.
+    p.mutex.Lock()
+    server, prevLogIndex := p.server, p.prevLogIndex
+    p.mutex.Unlock()
+
+    var req *AppendEntriesRequest
+    snapShotNeeded := false
+
+    // We need to hold the log lock while creating the AppendEntriesRequest
+    // so that a snapshot cannot delete the desired entries before it is built.
+    server.log.mutex.Lock()
+    if prevLogIndex >= server.log.StartIndex() {
+        if internal {
+            req = server.createInternalAppendEntriesRequest(prevLogIndex)
+        } else {
+            req = server.createAppendEntriesRequest(prevLogIndex)
+        }
+    } else {
+        snapShotNeeded = true
+    }
+    server.log.mutex.Unlock()
+
     p.mutex.Lock()
     defer p.mutex.Unlock()
-    req := p.server.createInternalAppendEntriesRequest(p.prevLogIndex)
-    return p.sendFlushRequest(req)
+    if snapShotNeeded {
+        req := server.createSnapshotRequest()
+        return p.sendSnapshotRequest(req)
+    } else {
+        return p.sendFlushRequest(req)
+    }

 }

+// Sends a snapshot request to the peer.
+func (p *Peer) sendSnapshotRequest(req *SnapshotRequest) (uint64, bool, error) {
+    // Ignore any null requests.
+    if req == nil {
+        return 0, false, errors.New("raft.Peer: Request required")
+    }
+
+    // Generate a snapshot request based on the state of the server and
+    // log. Send the request through the user-provided handler and process the
+    // result.
+    resp, err := p.server.transporter.SendSnapshotRequest(p.server, p, req)
+    p.heartbeatTimer.Reset()
+    if resp == nil {
+        return 0, false, err
+    }
+
+    // If successful then update the previous log index. If it was
+    // unsuccessful then decrement the previous log index and we'll try again
+    // next time.
+    if resp.Success {
+        p.prevLogIndex = req.LastIndex
+    } else {
+        panic(resp)
+    }
+
+    return resp.Term, resp.Success, err
+}
+
 // Flushes a request through the server's transport.

@@ -119,6 +177,7 @@ func (p *Peer) sendFlushRequest(req *AppendEntriesRequest) (uint64, bool, error)
    // log. Send the request through the user-provided handler and process the
    // result.
    resp, err := p.server.transporter.SendAppendEntriesRequest(p.server, p, req)

    p.heartbeatTimer.Reset()
    if resp == nil {
        return 0, false, err

@@ -153,10 +212,11 @@ func (p *Peer) sendFlushRequest(req *AppendEntriesRequest) (uint64, bool, error)
// Listens to the heartbeat timeout and flushes an AppendEntries RPC.
func (p *Peer) heartbeatTimeoutFunc(startChannel chan bool) {
    startChannel <- true

    for {
        // Grab the current timer channel.
        p.mutex.Lock()

        var c chan time.Time
        if p.heartbeatTimer != nil {
            c = p.heartbeatTimer.C()

@@ -171,19 +231,8 @@ func (p *Peer) heartbeatTimeoutFunc(startChannel chan bool) {
         // Flush the peer when we get a heartbeat timeout. If the channel is
         // closed then the peer is getting cleaned up and we should exit.
         if _, ok := <-c; ok {
-            // Retrieve the peer data within a lock that is separate from the
-            // server lock when creating the request. Otherwise a deadlock can
-            // occur.
-            p.mutex.Lock()
-            server, prevLogIndex := p.server, p.prevLogIndex
-            p.mutex.Unlock()
-
-            // Lock the server to create a request.
-            req := server.createAppendEntriesRequest(prevLogIndex)
-
-            p.mutex.Lock()
-            p.sendFlushRequest(req)
-            p.mutex.Unlock()
+            p.flush(false)
         } else {
             break
         }

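Note (illustration only, not part of this patch): the new flush path picks between an ordinary AppendEntries RPC and a snapshot based on whether the peer's prevLogIndex still falls inside the leader's retained log. A tiny sketch of that rule with concrete numbers; needsSnapshot is a hypothetical helper, not a function from the patch:

package main

import "fmt"

// needsSnapshot mirrors the check in Peer.flush: once the leader's log has
// been compacted past the follower's prevLogIndex, the retained entries can
// no longer bring it up to date, so a snapshot must be sent instead.
func needsSnapshot(prevLogIndex, logStartIndex uint64) bool {
    return prevLogIndex < logStartIndex
}

func main() {
    const startIndex = 4 // leader compacted entries 1..4 into a snapshot
    for _, prev := range []uint64{2, 4, 6} {
        if needsSnapshot(prev, startIndex) {
            fmt.Printf("peer at prevLogIndex=%d: send snapshot\n", prev)
        } else {
            fmt.Printf("peer at prevLogIndex=%d: send AppendEntries\n", prev)
        }
    }
}
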
server.go (205 changes)

@@ -5,6 +5,10 @@ import (
     "fmt"
     "sync"
     "time"
+    "os"
+    "sort"
+    "path"
+    "io/ioutil"
 )

 //------------------------------------------------------------------------------

@@ -56,6 +60,9 @@ type Server struct {
     mutex            sync.Mutex
     electionTimer    *Timer
     heartbeatTimeout time.Duration
+    currentSnapshot  *Snapshot
+    lastSnapshot     *Snapshot
+    stateMachine     StateMachine
 }

 //------------------------------------------------------------------------------

@@ -248,6 +255,12 @@ func (s *Server) Start() error {
         return errors.New("raft.Server: Server already running")
     }

+    // Create the snapshot directory if it does not exist.
+    os.Mkdir(s.path+"/snapshot", 0700)
+
+    // ## open recovery from the newest snapshot
+    //s.LoadSnapshot()
+
     // Initialize the log and load it up.
     if err := s.log.Open(s.LogPath()); err != nil {
         s.unload()

@@ -356,6 +369,11 @@ func (s *Server) do(command Command) error {
     // Capture the term that this command is executing within.
     currentTerm := s.currentTerm

+    // // TEMP to solve the issue 18
+    // for _, peer := range s.peers {
+    //     peer.pause()
+    // }
+
     // Add a new entry to the log.
     entry := s.log.CreateEntry(s.currentTerm, command)
     if err := s.log.AppendEntry(entry); err != nil {

@@ -367,18 +385,22 @@ func (s *Server) do(command Command) error {
     for _, _peer := range s.peers {
         peer := _peer
         go func() {
-            term, success, err := peer.internalFlush()
+            term, success, err := peer.flush(true)

             // Demote if we encounter a higher term.
             if err != nil {
                 return
             } else if term > currentTerm {
                 s.mutex.Lock()
                 s.setCurrentTerm(term)

                 if s.electionTimer != nil {
                     s.electionTimer.Reset()
                 }
                 s.mutex.Unlock()

                 return
             }

@@ -397,6 +419,9 @@ loop:
         // If we received enough votes then stop waiting for more votes.
         if responseCount >= s.QuorumSize() {
             committed = true
+            // for _, peer := range s.peers {
+            //     peer.resume()
+            // }
             break
         }

@@ -409,6 +434,9 @@ loop:
             }
             responseCount++
         case <-afterBetween(s.ElectionTimeout(), s.ElectionTimeout()*2):
+            // for _, peer := range s.peers {
+            //     peer.resume()
+            // }
             break loop
         }
     }

@@ -430,7 +458,6 @@
func (s *Server) AppendEntries(req *AppendEntriesRequest) (*AppendEntriesResponse, error) {
    s.mutex.Lock()
    defer s.mutex.Unlock()

    // If the server is stopped then reject it.
    if !s.Running() {
        return NewAppendEntriesResponse(s.currentTerm, false, 0), fmt.Errorf("raft.Server: Server stopped")

@@ -449,7 +476,7 @@ func (s *Server) AppendEntries(req *AppendEntriesRequest) (*AppendEntriesRespons
    if s.electionTimer != nil {
        s.electionTimer.Reset()
    }

    // Reject if log doesn't contain a matching previous entry.
    if err := s.log.Truncate(req.PrevLogIndex, req.PrevLogTerm); err != nil {
        return NewAppendEntriesResponse(s.currentTerm, false, s.log.CommitIndex()), err

@@ -463,7 +490,7 @@ func (s *Server) AppendEntries(req *AppendEntriesRequest) (*AppendEntriesRespons
        // Commit up to the commit index.
        if err := s.log.SetCommitIndex(req.CommitIndex); err != nil {
            return NewAppendEntriesResponse(s.currentTerm, false, s.log.CommitIndex()), err
        }
    }

    return NewAppendEntriesResponse(s.currentTerm, true, s.log.CommitIndex()), nil
}

@@ -494,6 +521,7 @@ func (s *Server) createInternalAppendEntriesRequest(prevLogIndex uint64) *Append
// server is elected then true is returned. If another server is elected then
// false is returned.
func (s *Server) promote() (bool, error) {

    for {
        // Start a new election.
        term, lastLogIndex, lastLogTerm, err := s.promoteToCandidate()

@@ -589,7 +617,6 @@ func (s *Server) promoteToCandidate() (uint64, uint64, uint64, error) {
    s.leader = ""
    // Pause the election timer while we're a candidate.
    s.electionTimer.Pause()

    // Return server state so we can check for it during leader promotion.
    lastLogIndex, lastLogTerm := s.log.CommitInfo()
    return s.currentTerm, lastLogIndex, lastLogTerm, nil

@@ -634,7 +661,6 @@ func (s *Server) promoteToLeader(term uint64, lastLogIndex uint64, lastLogTerm u
func (s *Server) RequestVote(req *RequestVoteRequest) (*RequestVoteResponse, error) {
    s.mutex.Lock()
    defer s.mutex.Unlock()

    // Fail if the server is not running.
    if !s.Running() {
        return NewRequestVoteResponse(s.currentTerm, false), fmt.Errorf("raft.Server: Server is stopped")

@@ -725,8 +751,8 @@ func (s *Server) AddPeer(name string) error {
            peer.resume()
        }
        s.peers[peer.name] = peer
    }

    }
    return nil
}

@@ -756,3 +782,168 @@ func (s *Server) RemovePeer(name string) error {

    return nil
}

//--------------------------------------
// Log compaction
//--------------------------------------

// Creates a snapshot request.
func (s *Server) createSnapshotRequest() *SnapshotRequest {
    s.mutex.Lock()
    defer s.mutex.Unlock()
    return NewSnapshotRequest(s.name, s.lastSnapshot)
}

// The background snapshot function.
func (s *Server) Snapshot() {
    for {
        s.takeSnapshot()

        // TODO: change this... to something reasonable
        time.Sleep(5000 * time.Millisecond)
    }
}

func (s *Server) takeSnapshot() error {
    // TODO: put a snapshot mutex
    if s.currentSnapshot != nil {
        return errors.New("handling snapshot")
    }

    lastIndex, lastTerm := s.log.CommitInfo()

    if lastIndex == 0 || lastTerm == 0 {
        return errors.New("No logs")
    }

    path := s.SnapshotPath(lastIndex, lastTerm)

    state, err := s.stateMachine.Save()

    if err != nil {
        return err
    }

    s.currentSnapshot = &Snapshot{lastIndex, lastTerm, state, path}

    s.saveSnapshot()

    s.log.Compact(lastIndex, lastTerm)

    return nil
}

// Saves the current snapshot and retires the previous one.
func (s *Server) saveSnapshot() error {
    if s.currentSnapshot == nil {
        return errors.New("no snapshot to save")
    }

    err := s.currentSnapshot.Save()

    if err != nil {
        return err
    }

    tmp := s.lastSnapshot
    s.lastSnapshot = s.currentSnapshot

    // Delete the previous snapshot if there is any change.
    if tmp != nil && !(tmp.lastIndex == s.lastSnapshot.lastIndex && tmp.lastTerm == s.lastSnapshot.lastTerm) {
        tmp.Remove()
    }
    s.currentSnapshot = nil
    return nil
}

// Retrieves the snapshot path for the server.
func (s *Server) SnapshotPath(lastIndex uint64, lastTerm uint64) string {
    return path.Join(s.path, "snapshot", fmt.Sprintf("%v_%v.ss", lastTerm, lastIndex))
}

func (s *Server) SnapshotRecovery(req *SnapshotRequest) (*SnapshotResponse, error) {
    s.mutex.Lock()
    defer s.mutex.Unlock()

    s.stateMachine.Recovery(req.State)

    // Update the term and commit index.
    s.currentTerm = req.LastTerm
    s.log.UpdateCommitIndex(req.LastIndex)
    snapshotPath := s.SnapshotPath(req.LastIndex, req.LastTerm)
    s.currentSnapshot = &Snapshot{req.LastIndex, req.LastTerm, req.State, snapshotPath}
    s.saveSnapshot()
    s.log.Compact(req.LastIndex, req.LastTerm)

    return NewSnapshotResponse(req.LastTerm, true, req.LastIndex), nil
}

// Load a snapshot at restart.
func (s *Server) LoadSnapshot() error {
    dir, err := os.OpenFile(path.Join(s.path, "snapshot"), os.O_RDONLY, 0)
    if err != nil {
        dir.Close()
        panic(err)
    }

    filenames, err := dir.Readdirnames(-1)

    if err != nil {
        dir.Close()
        panic(err)
    }

    dir.Close()
    if len(filenames) == 0 {
        return errors.New("no snapshot")
    }

    // Not sure how many snapshots we should keep; use the newest one.
    sort.Strings(filenames)
    snapshotPath := path.Join(s.path, "snapshot", filenames[len(filenames)-1])

    // Should not fail.
    file, err := os.OpenFile(snapshotPath, os.O_RDONLY, 0)
    if err != nil {
        panic(err)
    }
    defer file.Close()

    // TODO: check the checksum first
    // TODO: recover the state machine

    var state []byte
    var checksum, lastIndex, lastTerm uint64

    n, err := fmt.Fscanf(file, "%08x\n%v\n%v", &checksum, &lastIndex, &lastTerm)

    if err != nil {
        return err
    }

    if n != 3 {
        return errors.New("Bad snapshot file")
    }

    state, err = ioutil.ReadAll(file)

    if err != nil {
        return err
    }

    s.lastSnapshot = &Snapshot{lastIndex, lastTerm, state, snapshotPath}
    err = s.stateMachine.Recovery(state)

    s.log.SetStartTerm(lastTerm)
    s.log.SetStartIndex(lastIndex)
    s.log.UpdateCommitIndex(lastIndex)

    return err
}

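Note (illustration only, not part of this patch): saveSnapshot and LoadSnapshot above agree on a simple on-disk layout: a header with the CRC32 checksum as eight hex digits, then the last index and last term on their own lines, followed by the raw state bytes. A small sketch that builds and re-parses such a header with the same format strings the patch uses in Fprintf/Fscanf (the sample values are made up):

package main

import (
    "fmt"
    "hash/crc32"
)

func main() {
    state := []byte{0x60, 0x61, 0x62}
    lastIndex, lastTerm := uint64(4), uint64(1)

    // Build the header the way Snapshot.Save writes it.
    checksum := crc32.ChecksumIEEE(state)
    header := fmt.Sprintf("%08x\n%v\n%v\n", checksum, lastIndex, lastTerm)
    fmt.Printf("header: %q\n", header)

    // Parse it back the way Server.LoadSnapshot reads it.
    var gotChecksum, gotIndex, gotTerm uint64
    n, err := fmt.Sscanf(header, "%08x\n%v\n%v", &gotChecksum, &gotIndex, &gotTerm)
    if err != nil || n != 3 {
        panic(fmt.Sprintf("bad snapshot header: n=%d err=%v", n, err))
    }
    fmt.Println("parsed:", gotChecksum, gotIndex, gotTerm)
}
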
@@ -389,16 +389,12 @@ func TestServerMultiNode(t *testing.T) {
     time.Sleep(100 * time.Millisecond)
     leader.Stop()
     time.Sleep(100 * time.Millisecond)

     // Check that either server 2 or 3 is the leader now.
     mutex.Lock()
     if servers["2"].State() != Leader && servers["3"].State() != Leader {
-        t.Fatalf("Expected leader re-election: 2=%v, 3=%v", servers["2"].state, servers["3"].state)
+        t.Fatalf("Expected leader re-election: 2=%v, 3=%v\n", servers["2"].state, servers["3"].state)
     }
     mutex.Unlock()

     // Stop the servers.
     for _, server := range servers {
         server.Stop()
     }
 }

@@ -0,0 +1,68 @@
package raft

import (
    "hash/crc32"
    "fmt"
    "syscall"
    "bytes"
    "os"
)

//------------------------------------------------------------------------------
//
// Typedefs
//
//------------------------------------------------------------------------------

// The in-memory snapshot struct.
// TODO add cluster configuration
type Snapshot struct {
    lastIndex uint64
    lastTerm  uint64
    // cluster configuration.
    state []byte
    path  string
}

// Save the snapshot to a file.
func (ss *Snapshot) Save() error {
    // Write machine state to temporary buffer.
    var b bytes.Buffer

    if _, err := b.Write(ss.state); err != nil {
        return err
    }

    // Generate checksum.
    checksum := crc32.ChecksumIEEE(b.Bytes())

    // Open the snapshot file.
    file, err := os.OpenFile(ss.path, os.O_CREATE|os.O_WRONLY, 0600)
    if err != nil {
        return err
    }

    defer file.Close()

    // Write snapshot with checksum.
    if _, err = fmt.Fprintf(file, "%08x\n%v\n%v\n", checksum, ss.lastIndex,
        ss.lastTerm); err != nil {
        return err
    }

    if _, err = file.Write(ss.state); err != nil {
        return err
    }

    // Force the change to be written to disk.
    syscall.Fsync(int(file.Fd()))
    return err
}

// Remove the snapshot file.
func (ss *Snapshot) Remove() error {
    err := os.Remove(ss.path)
    return err
}

@@ -0,0 +1,41 @@
package raft

// The request sent to a server to start from the snapshot.
type SnapshotRequest struct {
    LeaderName string `json:"leaderName"`
    LastIndex  uint64 `json:"lastIndex"`
    LastTerm   uint64 `json:"lastTerm"`
    State      []byte `json:"state"`
}

// The response returned from a server after handling a snapshot request.
type SnapshotResponse struct {
    Term        uint64 `json:"term"`
    Success     bool   `json:"success"`
    CommitIndex uint64 `json:"commitIndex"`
}

//------------------------------------------------------------------------------
//
// Constructors
//
//------------------------------------------------------------------------------

// Creates a new Snapshot request.
func NewSnapshotRequest(leaderName string, snapshot *Snapshot) *SnapshotRequest {
    return &SnapshotRequest{
        LeaderName: leaderName,
        LastIndex:  snapshot.lastIndex,
        LastTerm:   snapshot.lastTerm,
        State:      snapshot.state,
    }
}

// Creates a new Snapshot response.
func NewSnapshotResponse(term uint64, success bool, commitIndex uint64) *SnapshotResponse {
    return &SnapshotResponse{
        Term:        term,
        Success:     success,
        CommitIndex: commitIndex,
    }
}

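Note (illustration only, not part of this patch): with the struct tags above, a SnapshotRequest serializes with lowercase JSON keys, and encoding/json base64-encodes the State byte slice. A quick, self-contained sketch of what goes over the wire; the struct is copied locally and the field values are made up:

package main

import (
    "encoding/json"
    "fmt"
)

// Local copy of the wire struct so the sketch compiles on its own.
type SnapshotRequest struct {
    LeaderName string `json:"leaderName"`
    LastIndex  uint64 `json:"lastIndex"`
    LastTerm   uint64 `json:"lastTerm"`
    State      []byte `json:"state"`
}

func main() {
    req := &SnapshotRequest{
        LeaderName: "1",
        LastIndex:  4,
        LastTerm:   1,
        State:      []byte{0x08},
    }
    b, err := json.Marshal(req)
    if err != nil {
        panic(err)
    }
    fmt.Println(string(b))
    // {"leaderName":"1","lastIndex":4,"lastTerm":1,"state":"CA=="}
}
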
@@ -0,0 +1,212 @@
package raft

import (
    "sync"
    "testing"
    "time"
    "bytes"
)

// Test taking and sending a snapshot.
func TestTakeAndSendSnapshot(t *testing.T) {
    // Initialize the servers.
    var mutex sync.Mutex
    //fmt.Println("---Snapshot Test---")
    names := []string{"1", "2", "3"}
    servers := map[string]*Server{}
    for _, server := range servers {
        defer server.Stop()
    }
    transporter := &testTransporter{}
    transporter.sendVoteRequestFunc = func(server *Server, peer *Peer, req *RequestVoteRequest) (*RequestVoteResponse, error) {
        mutex.Lock()
        s := servers[peer.name]
        mutex.Unlock()
        resp, err := s.RequestVote(req)
        return resp, err
    }
    transporter.sendAppendEntriesRequestFunc = func(server *Server, peer *Peer, req *AppendEntriesRequest) (*AppendEntriesResponse, error) {
        mutex.Lock()
        s := servers[peer.name]
        mutex.Unlock()
        resp, err := s.AppendEntries(req)
        return resp, err
    }

    transporter.sendSnapshotRequestFunc = func(server *Server, peer *Peer, req *SnapshotRequest) (*SnapshotResponse, error) {
        mutex.Lock()
        s := servers[peer.name]
        mutex.Unlock()
        resp, err := s.SnapshotRecovery(req)
        return resp, err
    }

    stateMachine := &testStateMachine{}

    stateMachine.saveFunc = func() ([]byte, error) {
        return []byte{0x8}, nil
    }

    stateMachine.recoveryFunc = func(state []byte) error {
        return nil
    }

    var leader *Server
    for _, name := range names {
        server := newTestServer(name, transporter)
        server.stateMachine = stateMachine
        server.SetElectionTimeout(testElectionTimeout)
        server.SetHeartbeatTimeout(testHeartbeatTimeout)
        if err := server.Start(); err != nil {
            t.Fatalf("Unable to start server[%s]: %v", name, err)
        }
        if name == "1" {
            leader = server
            if err := server.Initialize(); err != nil {
                t.Fatalf("Unable to initialize server[%s]: %v", name, err)
            }
        }
        if err := leader.Do(&joinCommand{Name: name}); err != nil {
            t.Fatalf("Unable to join server[%s]: %v", name, err)
        }

        mutex.Lock()
        servers[name] = server
        mutex.Unlock()
    }
    time.Sleep(100 * time.Millisecond)

    // Check that all three members exist on the leader.
    mutex.Lock()
    if leader.MemberCount() != 3 {
        t.Fatalf("Expected member count to be 3, got %v", leader.MemberCount())
    }
    mutex.Unlock()

    // Commit a single entry.
    err := leader.Do(&TestCommand1{"foo", 10})

    if err != nil {
        t.Fatal(err)
    }

    index, term := leader.log.CommitInfo()

    // Three joins and one test command.
    if !(index == 4 && term == 1) {
        t.Fatalf("Invalid commit info [IDX=%v, TERM=%v]", index, term)
    }

    leader.takeSnapshot()

    logLen := len(leader.log.entries)

    if logLen != 0 {
        t.Fatalf("Invalid logLen [Len=%v]", logLen)
    }

    if leader.log.startIndex != 4 || leader.log.startTerm != 1 {
        t.Fatalf("Invalid log info [StartIndex=%v, StartTERM=%v]",
            leader.log.startIndex, leader.log.startTerm)
    }

    // Test sending a snapshot to a new node.
    // The snapshot is sent from the heartbeat.
    newServer := newTestServer("4", transporter)
    newServer.stateMachine = stateMachine
    if err := newServer.Start(); err != nil {
        t.Fatalf("Unable to start server[4]: %v", err)
    }

    if err := leader.Do(&joinCommand{Name: "4"}); err != nil {
        t.Fatalf("Unable to join server[4]: %v", err)
    }

    mutex.Lock()
    servers["4"] = newServer
    mutex.Unlock()

    // Wait for a heartbeat.
    time.Sleep(100 * time.Millisecond)

    if leader.log.startIndex != 4 || leader.log.startTerm != 1 {
        t.Fatalf("Invalid log info [StartIndex=%v, StartTERM=%v]",
            leader.log.startIndex, leader.log.startTerm)
    }
    time.Sleep(100 * time.Millisecond)

}

func TestStartFromSnapshot(t *testing.T) {
    server := newTestServer("1", &testTransporter{})

    stateMachine := &testStateMachine{}
    stateMachine.saveFunc = func() ([]byte, error) {
        return []byte{0x60, 0x61, 0x62, 0x63, 0x64, 0x65}, nil
    }

    stateMachine.recoveryFunc = func(state []byte) error {
        expect := []byte{0x60, 0x61, 0x62, 0x63, 0x64, 0x65}
        if !(bytes.Equal(state, expect)) {
            t.Fatalf("Invalid state [Expect=%v, Actual=%v]", expect, state)
        }
        return nil
    }
    server.stateMachine = stateMachine
    oldPath := server.path
    server.Start()

    server.Initialize()

    // Commit a single entry.
    err := server.Do(&TestCommand1{"foo", 10})

    if err != nil {
        t.Fatal(err)
    }

    server.takeSnapshot()

    logLen := len(server.log.entries)

    if logLen != 0 {
        t.Fatalf("Invalid logLen [Len=%v]", logLen)
    }

    if server.log.startIndex != 1 || server.log.startTerm != 1 {
        t.Fatalf("Invalid log info [StartIndex=%v, StartTERM=%v]",
            server.log.startIndex, server.log.startTerm)
    }

    server.Stop()

    server = newTestServer("1", &testTransporter{})
    server.stateMachine = stateMachine
    // Reset to the old path.
    server.path = oldPath

    server.Start()

    logLen = len(server.log.entries)

    if logLen != 0 {
        t.Fatalf("Invalid logLen [Len=%v]", logLen)
    }

    if index, term := server.log.CommitInfo(); !(index == 0 && term == 0) {
        t.Fatalf("Invalid commit info [IDX=%v, TERM=%v]", index, term)
    }

    if server.log.startIndex != 0 || server.log.startTerm != 0 {
        t.Fatalf("Invalid log info [StartIndex=%v, StartTERM=%v]",
            server.log.startIndex, server.log.startTerm)
    }

    server.LoadSnapshot()
    if server.log.startIndex != 1 || server.log.startTerm != 1 {
        t.Fatalf("Invalid log info [StartIndex=%v, StartTERM=%v]",
            server.log.startIndex, server.log.startTerm)
    }

}

@@ -0,0 +1,14 @@
package raft

//------------------------------------------------------------------------------
//
// Typedefs
//
//------------------------------------------------------------------------------

// StateMachine is the interface that allows the host application to save and
// recover the state of its state machine.
type StateMachine interface {
    Save() ([]byte, error)
    Recovery([]byte) error
}

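Note (illustration only, not part of this patch): a host application satisfies this interface by serializing whatever state it owns. A minimal sketch, here a hypothetical kvStore key-value state machine that snapshots itself as JSON; the type and its fields are assumptions for the example, not part of the library:

package main

import (
    "encoding/json"
    "fmt"
    "sync"
)

// kvStore is a hypothetical application state machine.
type kvStore struct {
    mu   sync.Mutex
    data map[string]string
}

// Save serializes the current state, as required by raft.StateMachine.
func (s *kvStore) Save() ([]byte, error) {
    s.mu.Lock()
    defer s.mu.Unlock()
    return json.Marshal(s.data)
}

// Recovery restores state from a snapshot, as required by raft.StateMachine.
func (s *kvStore) Recovery(state []byte) error {
    s.mu.Lock()
    defer s.mu.Unlock()
    s.data = map[string]string{}
    return json.Unmarshal(state, &s.data)
}

func main() {
    a := &kvStore{data: map[string]string{"foo": "bar"}}
    snap, _ := a.Save()

    b := &kvStore{}
    if err := b.Recovery(snap); err != nil {
        panic(err)
    }
    fmt.Println(b.data["foo"]) // bar
}
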
test.go (21 changes)

@@ -4,7 +4,7 @@ import (
    "fmt"
    "io/ioutil"
    "os"
    "time"
)

const (

@@ -98,6 +98,7 @@ func newTestCluster(names []string, transporter Transporter, lookup map[string]*
 type testTransporter struct {
     sendVoteRequestFunc          func(server *Server, peer *Peer, req *RequestVoteRequest) (*RequestVoteResponse, error)
     sendAppendEntriesRequestFunc func(server *Server, peer *Peer, req *AppendEntriesRequest) (*AppendEntriesResponse, error)
+    sendSnapshotRequestFunc      func(server *Server, peer *Peer, req *SnapshotRequest) (*SnapshotResponse, error)
 }

 func (t *testTransporter) SendVoteRequest(server *Server, peer *Peer, req *RequestVoteRequest) (*RequestVoteResponse, error) {

@@ -108,6 +109,24 @@ func (t *testTransporter) SendAppendEntriesRequest(server *Server, peer *Peer, r
     return t.sendAppendEntriesRequestFunc(server, peer, req)
 }

+func (t *testTransporter) SendSnapshotRequest(server *Server, peer *Peer, req *SnapshotRequest) (*SnapshotResponse, error) {
+    return t.sendSnapshotRequestFunc(server, peer, req)
+}
+
+type testStateMachine struct {
+    saveFunc     func() ([]byte, error)
+    recoveryFunc func([]byte) error
+}
+
+func (sm *testStateMachine) Save() ([]byte, error) {
+    return sm.saveFunc()
+}
+
+func (sm *testStateMachine) Recovery(state []byte) error {
+    return sm.recoveryFunc(state)
+}
+
 //--------------------------------------
 // Join Command

@@ -11,4 +11,5 @@ package raft
 type Transporter interface {
     SendVoteRequest(server *Server, peer *Peer, req *RequestVoteRequest) (*RequestVoteResponse, error)
     SendAppendEntriesRequest(server *Server, peer *Peer, req *AppendEntriesRequest) (*AppendEntriesResponse, error)
+    SendSnapshotRequest(server *Server, peer *Peer, req *SnapshotRequest) (*SnapshotResponse, error)
 }