Merge branch 'master' of https://github.com/xiangli-cmu/go-raft into xiangli-cmu-master

Conflicts:
	log.go
pull/820/head
Ben Johnson 2013-06-07 17:53:27 -04:00
commit 3bcf91a39f
11 changed files with 734 additions and 53 deletions


@ -8,7 +8,6 @@ package raft
// The request sent to a server to append entries to the log.
type AppendEntriesRequest struct {
peer *Peer
Term uint64 `json:"term"`
LeaderName string `json:"leaderName"`
PrevLogIndex uint64 `json:"prevLogIndex"`
@ -19,7 +18,6 @@ type AppendEntriesRequest struct {
// The response returned from a server appending entries to the log.
type AppendEntriesResponse struct {
peer *Peer
Term uint64 `json:"term"`
Success bool `json:"success"`
CommitIndex uint64 `json:"commitIndex"`

log.go (128 changed lines)

@ -19,10 +19,13 @@ import (
type Log struct {
ApplyFunc func(Command) error
file *os.File
path string
entries []*LogEntry
errors []error
commitIndex uint64
mutex sync.Mutex
startIndex uint64 // the index before the first entry in the Log entries
startTerm uint64
}
//------------------------------------------------------------------------------
@ -42,6 +45,17 @@ func NewLog() *Log {
//
//------------------------------------------------------------------------------
func (l *Log) SetStartIndex(i uint64) {
l.startIndex = i
}
func (l *Log) StartIndex() uint64 {
return l.startIndex
}
func (l *Log) SetStartTerm(t uint64) {
l.startTerm = t
}
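
The startIndex/startTerm pair is what lets a compacted log keep answering queries in absolute Raft indices while the in-memory slice only holds entries recorded after the snapshot. Below is a minimal, standalone sketch of that index translation; the type and field names are illustrative, not the ones from this change.

package main

import "fmt"

// compactedLog mirrors the idea above: startIndex/startTerm describe the last
// entry discarded by compaction, entries holds everything after it.
type compactedLog struct {
	startIndex uint64
	startTerm  uint64
	entries    []uint64 // entry terms, indexed by slice position
}

// position maps an absolute Raft index to a slice offset, or -1 if the entry
// was compacted away or does not exist yet.
func (l *compactedLog) position(index uint64) int {
	if index <= l.startIndex || index > l.startIndex+uint64(len(l.entries)) {
		return -1
	}
	return int(index - l.startIndex - 1)
}

func main() {
	l := &compactedLog{startIndex: 4, startTerm: 1, entries: []uint64{1, 2}} // entries 5 and 6
	fmt.Println(l.position(5), l.position(6), l.position(4)) // 0 1 -1
}
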
//--------------------------------------
// Log Indices
//--------------------------------------
@ -52,7 +66,15 @@ func (l *Log) CurrentIndex() uint64 {
defer l.mutex.Unlock()
if len(l.entries) == 0 {
return 0
return l.startIndex
}
return l.entries[len(l.entries)-1].Index
}
// The current index in the log without locking
func (l *Log) internalCurrentIndex() uint64 {
if len(l.entries) == 0 {
return l.startIndex
}
return l.entries[len(l.entries)-1].Index
}
@ -96,7 +118,7 @@ func (l *Log) CurrentTerm() uint64 {
defer l.mutex.Unlock()
if len(l.entries) == 0 {
return 0
return l.startTerm
}
return l.entries[len(l.entries)-1].Term
}
@ -165,7 +187,7 @@ func (l *Log) Open(path string) error {
if err != nil {
return err
}
l.path = path
return nil
}
@ -196,7 +218,7 @@ func (l *Log) ContainsEntry(index uint64, term uint64) bool {
l.mutex.Lock()
defer l.mutex.Unlock()
if index == 0 || index > uint64(len(l.entries)) {
if index <= l.startIndex || index > (l.startIndex + uint64(len(l.entries))) {
return false
}
return (l.entries[index-1].Term == term)
@ -206,18 +228,17 @@ func (l *Log) ContainsEntry(index uint64, term uint64) bool {
// the term of the index provided.
func (l *Log) GetEntriesAfter(index uint64) ([]*LogEntry, uint64) {
// Return an error if the index doesn't exist.
if index > uint64(len(l.entries)) {
if index > (uint64(len(l.entries)) + l.startIndex) {
panic(fmt.Sprintf("raft.Log: Index is beyond end of log: %v", index))
}
// If we're going from the beginning of the log then return the whole log.
if index == 0 {
return l.entries, 0
if index == l.startIndex {
return l.entries, l.startTerm
}
// Determine the term at the given entry and return a subslice.
term := l.entries[index-1].Term
return l.entries[index:], term
term := l.entries[index - 1 - l.startIndex].Term
return l.entries[index - l.startIndex:], term
}
// Retrieves the error returned from an entry. The error can only exist after
@ -250,11 +271,24 @@ func (l *Log) CommitInfo() (index uint64, term uint64) {
return 0, 0
}
// no entries have been committed since the snapshot
if l.commitIndex == l.startIndex {
return l.startIndex, l.startTerm
}
// Return the last index & term from the last committed entry.
lastCommitEntry := l.entries[l.commitIndex-1]
lastCommitEntry := l.entries[l.commitIndex - 1 - l.startIndex]
return lastCommitEntry.Index, lastCommitEntry.Term
}
// Updates the commit index
func (l *Log) UpdateCommitIndex(index uint64) {
l.mutex.Lock()
defer l.mutex.Unlock()
l.commitIndex = index
}
// Updates the commit index and writes entries after that index to the stable storage.
func (l *Log) SetCommitIndex(index uint64) error {
l.mutex.Lock()
@ -264,13 +298,13 @@ func (l *Log) SetCommitIndex(index uint64) error {
if index < l.commitIndex {
return fmt.Errorf("raft.Log: Commit index (%d) ahead of requested commit index (%d)", l.commitIndex, index)
}
if index > uint64(len(l.entries)) {
if index > l.startIndex + uint64(len(l.entries)) {
return fmt.Errorf("raft.Log: Commit index (%d) out of range (%d)", index, len(l.entries))
}
// Find all entries whose index is between the previous index and the current index.
for i := l.commitIndex + 1; i <= index; i++ {
entryIndex := i-1
entryIndex := i - 1 - l.startIndex
entry := l.entries[entryIndex]
// Write to storage.
@ -304,23 +338,23 @@ func (l *Log) Truncate(index uint64, term uint64) error {
}
// Do not truncate past end of entries.
if index > uint64(len(l.entries)) {
if index > l.startIndex + uint64(len(l.entries)) {
return fmt.Errorf("raft.Log: Entry index does not exist (MAX=%v): (IDX=%v, TERM=%v)", len(l.entries), index, term)
}
// If we're truncating everything then just clear the entries.
if index == 0 {
if index == l.startIndex {
l.entries = []*LogEntry{}
} else {
// Do not truncate if the entry at index does not have the matching term.
entry := l.entries[index-1]
entry := l.entries[index - l.startIndex - 1]
if len(l.entries) > 0 && entry.Term != term {
return fmt.Errorf("raft.Log: Entry at index does not have matching term (%v): (IDX=%v, TERM=%v)", entry.Term, index, term)
}
// Otherwise truncate up to the desired entry.
if index < uint64(len(l.entries)) {
l.entries = l.entries[0:index]
if index < l.startIndex + uint64(len(l.entries)) {
l.entries = l.entries[0:index - l.startIndex]
}
}
@ -378,3 +412,61 @@ func (l *Log) appendEntry(entry *LogEntry) error {
return nil
}
//--------------------------------------
// Log compaction
//--------------------------------------
// compact the log through the given index
func (l *Log) Compact(index uint64, term uint64) error {
var entries []*LogEntry
l.mutex.Lock()
defer l.mutex.Unlock()
// nothing to compact
// the index may be greater than the current index if
// we have just recovered from a snapshot
if index >= l.internalCurrentIndex() {
entries = make([]*LogEntry, 0)
} else {
// get all log entries after index
entries = l.entries[index - l.startIndex:]
}
// create a new log file and add all the entries
file, err := os.OpenFile(l.path + ".new", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return err
}
for _, entry := range entries {
err = entry.Encode(file)
if err != nil {
return err
}
}
// close the current log file
l.file.Close()
// remove the current log file
err = os.Remove(l.path)
if err != nil {
return err
}
// rename the new log file
err = os.Rename(l.path + ".new", l.path)
if err != nil {
return err
}
l.file = file
// compact the in-memory log
l.entries = entries
l.startIndex = index
l.startTerm = term
return nil
}
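
Compact rewrites the surviving entries into a temporary file and then swaps it in for the old log. The sketch below shows the same write-new-then-rename pattern in isolation, with a hypothetical helper name; on POSIX systems the final rename is atomic, so the sketch omits the separate Remove step used above.

package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// replaceFile writes data to path+".new", syncs it, and renames it over path,
// so a crash mid-rewrite never leaves a half-written log behind.
func replaceFile(path string, data []byte) error {
	tmp := path + ".new"
	f, err := os.OpenFile(tmp, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}
	if _, err = f.Write(data); err != nil {
		f.Close()
		return err
	}
	if err = f.Sync(); err != nil { // flush before the rename makes it visible
		f.Close()
		return err
	}
	if err = f.Close(); err != nil {
		return err
	}
	return os.Rename(tmp, path)
}

func main() {
	path := filepath.Join(os.TempDir(), "raft.log")
	if err := replaceFile(path, []byte("entry5\nentry6\n")); err != nil {
		fmt.Println("compaction failed:", err)
		return
	}
	fmt.Println("log rewritten at", path)
}
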

peer.go (85 changed lines)

@ -99,13 +99,71 @@ func (p *Peer) stop() {
// Flush
//--------------------------------------
// Sends an AppendEntries RPC but does not obtain a lock on the server. This
// method should only be called from the server.
func (p *Peer) internalFlush() (uint64, bool, error) {
// If internal is true, sends an AppendEntries RPC without obtaining a lock
// on the server.
func (p *Peer) flush(internal bool) (uint64, bool, error) {
// Retrieve the peer data within a lock that is separate from the
// server lock when creating the request. Otherwise a deadlock can
// occur.
p.mutex.Lock()
server, prevLogIndex := p.server, p.prevLogIndex
p.mutex.Unlock()
var req *AppendEntriesRequest
snapShotNeeded := false
// We need to hold the log lock while building the AppendEntriesRequest so that
// a concurrent snapshot cannot compact away the entries we are about to send.
server.log.mutex.Lock()
if prevLogIndex >= server.log.StartIndex() {
if internal {
req = server.createInternalAppendEntriesRequest(prevLogIndex)
} else {
req = server.createAppendEntriesRequest(prevLogIndex)
}
} else {
snapShotNeeded = true
}
server.log.mutex.Unlock()
p.mutex.Lock()
defer p.mutex.Unlock()
req := p.server.createInternalAppendEntriesRequest(p.prevLogIndex)
return p.sendFlushRequest(req)
if snapShotNeeded {
req := server.createSnapshotRequest()
return p.sendSnapshotRequest(req)
} else {
return p.sendFlushRequest(req)
}
}
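
The new branch in flush captures the core compaction trade-off: once a follower's prevLogIndex falls behind the leader's start index, the entries it needs no longer exist in the log, so the leader must send a snapshot instead of an AppendEntries RPC. A toy sketch of just that decision, with hypothetical names:

package main

import "fmt"

// nextRPC picks the RPC a leader should send to a follower whose log is known
// up to prevLogIndex, given the leader's compacted start index.
func nextRPC(prevLogIndex, startIndex uint64) string {
	if prevLogIndex >= startIndex {
		return "AppendEntries" // the follower's next entries are still in the log
	}
	return "Snapshot" // the needed entries were compacted; replay state instead
}

func main() {
	fmt.Println(nextRPC(7, 4)) // AppendEntries
	fmt.Println(nextRPC(2, 4)) // Snapshot
}
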
// Sends a snapshot request through the server's transport.
func (p *Peer) sendSnapshotRequest(req *SnapshotRequest) (uint64, bool, error) {
// Ignore any null requests.
if req == nil {
return 0, false, errors.New("raft.Peer: Request required")
}
// Send the snapshot request through the user-provided handler and process
// the result.
resp, err := p.server.transporter.SendSnapshotRequest(p.server, p, req)
p.heartbeatTimer.Reset()
if resp == nil {
return 0, false, err
}
// If successful then update the previous log index. A rejected snapshot
// is not expected here, so fail loudly.
if resp.Success {
p.prevLogIndex = req.LastIndex
} else {
panic(resp)
}
return resp.Term, resp.Success, err
}
// Flushes a request through the server's transport.
@ -119,6 +177,7 @@ func (p *Peer) sendFlushRequest(req *AppendEntriesRequest) (uint64, bool, error)
// log. Send the request through the user-provided handler and process the
// result.
resp, err := p.server.transporter.SendAppendEntriesRequest(p.server, p, req)
p.heartbeatTimer.Reset()
if resp == nil {
return 0, false, err
@ -153,10 +212,11 @@ func (p *Peer) sendFlushRequest(req *AppendEntriesRequest) (uint64, bool, error)
// Listens to the heartbeat timeout and flushes an AppendEntries RPC.
func (p *Peer) heartbeatTimeoutFunc(startChannel chan bool) {
startChannel <- true
for {
// Grab the current timer channel.
p.mutex.Lock()
var c chan time.Time
if p.heartbeatTimer != nil {
c = p.heartbeatTimer.C()
@ -171,19 +231,8 @@ func (p *Peer) heartbeatTimeoutFunc(startChannel chan bool) {
// Flush the peer when we get a heartbeat timeout. If the channel is
// closed then the peer is getting cleaned up and we should exit.
if _, ok := <-c; ok {
// Retrieve the peer data within a lock that is separate from the
// server lock when creating the request. Otherwise a deadlock can
// occur.
p.mutex.Lock()
server, prevLogIndex := p.server, p.prevLogIndex
p.mutex.Unlock()
p.flush(false)
// Lock the server to create a request.
req := server.createAppendEntriesRequest(prevLogIndex)
p.mutex.Lock()
p.sendFlushRequest(req)
p.mutex.Unlock()
} else {
break
}

server.go (205 changed lines)

@ -5,6 +5,10 @@ import (
"fmt"
"sync"
"time"
"os"
"sort"
"path"
"io/ioutil"
)
//------------------------------------------------------------------------------
@ -56,6 +60,9 @@ type Server struct {
mutex sync.Mutex
electionTimer *Timer
heartbeatTimeout time.Duration
currentSnapshot *Snapshot
lastSnapshot *Snapshot
stateMachine StateMachine
}
//------------------------------------------------------------------------------
@ -248,6 +255,12 @@ func (s *Server) Start() error {
return errors.New("raft.Server: Server already running")
}
// create the snapshot directory if it does not exist
os.Mkdir(s.path + "/snapshot", 0700)
// ## on open, recover from the newest snapshot
//s.LoadSnapshot()
// Initialize the log and load it up.
if err := s.log.Open(s.LogPath()); err != nil {
s.unload()
@ -356,6 +369,11 @@ func (s *Server) do(command Command) error {
// Capture the term that this command is executing within.
currentTerm := s.currentTerm
// // TEMP to solve issue 18
// for _, peer := range s.peers {
// peer.pause()
// }
// Add a new entry to the log.
entry := s.log.CreateEntry(s.currentTerm, command)
if err := s.log.AppendEntry(entry); err != nil {
@ -367,18 +385,22 @@ func (s *Server) do(command Command) error {
for _, _peer := range s.peers {
peer := _peer
go func() {
term, success, err := peer.internalFlush()
term, success, err := peer.flush(true)
// Demote if we encounter a higher term.
if err != nil {
return
} else if term > currentTerm {
s.mutex.Lock()
s.setCurrentTerm(term)
if s.electionTimer != nil {
s.electionTimer.Reset()
}
s.mutex.Unlock()
return
}
@ -397,6 +419,9 @@ loop:
// If we received enough votes then stop waiting for more votes.
if responseCount >= s.QuorumSize() {
committed = true
// for _, peer := range s.peers {
// peer.resume()
// }
break
}
@ -409,6 +434,9 @@ loop:
}
responseCount++
case <-afterBetween(s.ElectionTimeout(), s.ElectionTimeout()*2):
// for _, peer := range s.peers {
// peer.resume()
// }
break loop
}
}
@ -430,7 +458,6 @@ loop:
func (s *Server) AppendEntries(req *AppendEntriesRequest) (*AppendEntriesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
// If the server is stopped then reject it.
if !s.Running() {
return NewAppendEntriesResponse(s.currentTerm, false, 0), fmt.Errorf("raft.Server: Server stopped")
@ -449,7 +476,7 @@ func (s *Server) AppendEntries(req *AppendEntriesRequest) (*AppendEntriesRespons
if s.electionTimer != nil {
s.electionTimer.Reset()
}
// Reject if log doesn't contain a matching previous entry.
if err := s.log.Truncate(req.PrevLogIndex, req.PrevLogTerm); err != nil {
return NewAppendEntriesResponse(s.currentTerm, false, s.log.CommitIndex()), err
@ -463,7 +490,7 @@ func (s *Server) AppendEntries(req *AppendEntriesRequest) (*AppendEntriesRespons
// Commit up to the commit index.
if err := s.log.SetCommitIndex(req.CommitIndex); err != nil {
return NewAppendEntriesResponse(s.currentTerm, false, s.log.CommitIndex()), err
}
}
return NewAppendEntriesResponse(s.currentTerm, true, s.log.CommitIndex()), nil
}
@ -494,6 +521,7 @@ func (s *Server) createInternalAppendEntriesRequest(prevLogIndex uint64) *Append
// server is elected then true is returned. If another server is elected then
// false is returned.
func (s *Server) promote() (bool, error) {
for {
// Start a new election.
term, lastLogIndex, lastLogTerm, err := s.promoteToCandidate()
@ -589,7 +617,6 @@ func (s *Server) promoteToCandidate() (uint64, uint64, uint64, error) {
s.leader = ""
// Pause the election timer while we're a candidate.
s.electionTimer.Pause()
// Return server state so we can check for it during leader promotion.
lastLogIndex, lastLogTerm := s.log.CommitInfo()
return s.currentTerm, lastLogIndex, lastLogTerm, nil
@ -634,7 +661,6 @@ func (s *Server) promoteToLeader(term uint64, lastLogIndex uint64, lastLogTerm u
func (s *Server) RequestVote(req *RequestVoteRequest) (*RequestVoteResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
// Fail if the server is not running.
if !s.Running() {
return NewRequestVoteResponse(s.currentTerm, false), fmt.Errorf("raft.Server: Server is stopped")
@ -725,8 +751,8 @@ func (s *Server) AddPeer(name string) error {
peer.resume()
}
s.peers[peer.name] = peer
}
}
return nil
}
@ -756,3 +782,168 @@ func (s *Server) RemovePeer(name string) error {
return nil
}
//--------------------------------------
// Log compaction
//--------------------------------------
// Creates a snapshot request.
func (s *Server) createSnapshotRequest() *SnapshotRequest {
s.mutex.Lock()
defer s.mutex.Unlock()
return NewSnapshotRequest(s.name, s.lastSnapshot)
}
// The background snapshot function
func (s *Server) Snapshot() {
for {
s.takeSnapshot()
// TODO: change this... to something reasonable
time.Sleep(5000 * time.Millisecond)
}
}
func (s *Server) takeSnapshot() error {
// TODO: protect this with a snapshot mutex
if s.currentSnapshot != nil {
return errors.New("handling snapshot")
}
lastIndex, lastTerm := s.log.CommitInfo()
if lastIndex == 0 || lastTerm == 0 {
return errors.New("No logs")
}
path := s.SnapshotPath(lastIndex, lastTerm)
state, err := s.stateMachine.Save()
if err != nil {
return err
}
s.currentSnapshot = &Snapshot{lastIndex, lastTerm, state, path}
s.saveSnapshot()
s.log.Compact(lastIndex, lastTerm)
return nil
}
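
takeSnapshot runs three steps in a fixed order: save the state machine, persist the snapshot, then compact the log. The stubbed sketch below only illustrates that ordering (the function names are placeholders, not server methods); compacting only after the snapshot is durable means a crash in between still leaves the full log for recovery.

package main

import "fmt"

// Placeholder steps standing in for the state machine save, snapshot write,
// and log compaction performed by takeSnapshot().
func saveStateMachine() ([]byte, error) { return []byte("kv-state"), nil }

func persistSnapshot(state []byte, lastIndex, lastTerm uint64) error {
	fmt.Printf("snapshot %d_%d saved (%d bytes)\n", lastTerm, lastIndex, len(state))
	return nil
}

func compactLog(lastIndex, lastTerm uint64) error {
	fmt.Printf("log compacted through index %d, term %d\n", lastIndex, lastTerm)
	return nil
}

func main() {
	lastIndex, lastTerm := uint64(4), uint64(1)

	state, err := saveStateMachine()
	if err != nil {
		panic(err)
	}
	if err := persistSnapshot(state, lastIndex, lastTerm); err != nil {
		panic(err)
	}
	// Only compact once the snapshot is safely on disk.
	if err := compactLog(lastIndex, lastTerm); err != nil {
		panic(err)
	}
}
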
// Saves the current snapshot to disk and retires the previous one.
func (s *Server) saveSnapshot() error {
if s.currentSnapshot == nil {
return errors.New("no snapshot to save")
}
err := s.currentSnapshot.Save()
if err != nil {
return err
}
tmp := s.lastSnapshot
s.lastSnapshot = s.currentSnapshot
// delete the previous snapshot if it differs from the new one
if tmp != nil && !(tmp.lastIndex == s.lastSnapshot.lastIndex && tmp.lastTerm == s.lastSnapshot.lastTerm) {
tmp.Remove()
}
s.currentSnapshot = nil
return nil
}
// Retrieves the snapshot path for the server.
func (s *Server) SnapshotPath(lastIndex uint64, lastTerm uint64) string {
return path.Join(s.path, "snapshot", fmt.Sprintf("%v_%v.ss", lastTerm, lastIndex))
}
func (s *Server) SnapshotRecovery(req *SnapshotRequest) (*SnapshotResponse, error) {
// Recover the state machine and log metadata from the leader's snapshot.
s.mutex.Lock()
defer s.mutex.Unlock()
s.stateMachine.Recovery(req.State)
// update the term and commit index
s.currentTerm = req.LastTerm
s.log.UpdateCommitIndex(req.LastIndex)
snapshotPath := s.SnapshotPath(req.LastIndex, req.LastTerm)
s.currentSnapshot = &Snapshot{req.LastIndex, req.LastTerm, req.State, snapshotPath}
s.saveSnapshot()
s.log.Compact(req.LastIndex, req.LastTerm)
return NewSnapshotResponse(req.LastTerm, true, req.LastIndex), nil
}
// Load a snapshot at restart
func (s *Server) LoadSnapshot() error {
dir, err := os.OpenFile(path.Join(s.path, "snapshot"), os.O_RDONLY, 0)
if err != nil {
dir.Close()
panic(err)
}
filenames, err := dir.Readdirnames(-1)
if err != nil {
dir.Close()
panic(err)
}
dir.Close()
if len(filenames) == 0 {
return errors.New("no snapshot")
}
// not sure how many snapshots we should keep
sort.Strings(filenames)
snapshotPath := path.Join(s.path, "snapshot", filenames[len(filenames) - 1])
// should not fail
file, err := os.OpenFile(snapshotPath, os.O_RDONLY, 0)
defer file.Close()
if err != nil {
panic(err)
}
// TODO: verify the checksum first
// TODO: recover the state machine
var state []byte
var checksum, lastIndex, lastTerm uint64
n, err := fmt.Fscanf(file, "%08x\n%v\n%v\n", &checksum, &lastIndex, &lastTerm)
if err != nil {
return err
}
if n != 3 {
return errors.New("Bad snapshot file")
}
state, err = ioutil.ReadAll(file)
if err != nil {
return err
}
s.lastSnapshot = &Snapshot{lastIndex, lastTerm, state, snapshotPath}
err = s.stateMachine.Recovery(state)
s.log.SetStartTerm(lastTerm)
s.log.SetStartIndex(lastIndex)
s.log.UpdateCommitIndex(lastIndex)
return err
}

server_test.go

@ -389,16 +389,12 @@ func TestServerMultiNode(t *testing.T) {
time.Sleep(100 * time.Millisecond)
leader.Stop()
time.Sleep(100 * time.Millisecond)
// Check that either server 2 or 3 is the leader now.
mutex.Lock()
if servers["2"].State() != Leader && servers["3"].State() != Leader {
t.Fatalf("Expected leader re-election: 2=%v, 3=%v", servers["2"].state, servers["3"].state)
t.Fatalf("Expected leader re-election: 2=%v, 3=%v\n", servers["2"].state, servers["3"].state)
}
mutex.Unlock()
// Stop the servers.
for _, server := range servers {
server.Stop()
}
}

snapshot.go (new file, 68 lines)

@ -0,0 +1,68 @@
package raft
import (
"hash/crc32"
"fmt"
"syscall"
"bytes"
"os"
)
//------------------------------------------------------------------------------
//
// Typedefs
//
//------------------------------------------------------------------------------
// the in-memory Snapshot struct
// TODO add cluster configuration
type Snapshot struct {
lastIndex uint64
lastTerm uint64
// cluster configuration.
state []byte
path string
}
// Save the snapshot to a file
func (ss *Snapshot) Save() error {
// Write the machine state to a temporary buffer so it can be checksummed.
var b bytes.Buffer
if _, err := b.Write(ss.state); err != nil {
return err
}
// Generate checksum.
checksum := crc32.ChecksumIEEE(b.Bytes())
// open file
file, err := os.OpenFile(ss.path, os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return err
}
defer file.Close()
// Write snapshot with checksum.
if _, err = fmt.Fprintf(file, "%08x\n%v\n%v\n", checksum, ss.lastIndex,
ss.lastTerm); err != nil {
return err
}
if _, err = file.Write(ss.state); err != nil {
return err
}
// force the changes to be written to disk
syscall.Fsync(int(file.Fd()))
return err
}
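
Save writes a small text header, a CRC32 checksum plus the last index and term on separate lines, followed by the raw state bytes, and LoadSnapshot parses the same layout back with fmt.Fscanf (checksum verification is still a TODO). A self-contained round-trip of that layout against an in-memory buffer:

package main

import (
	"bytes"
	"fmt"
	"hash/crc32"
	"io/ioutil"
)

func main() {
	state := []byte{0x60, 0x61, 0x62}
	var lastIndex, lastTerm uint64 = 4, 1

	// Encode: header lines first, then the raw state machine bytes.
	var buf bytes.Buffer
	checksum := crc32.ChecksumIEEE(state)
	fmt.Fprintf(&buf, "%08x\n%v\n%v\n", checksum, lastIndex, lastTerm)
	buf.Write(state)

	// Decode: parse the header, then read whatever remains as the state.
	var gotSum uint32
	var gotIndex, gotTerm uint64
	if _, err := fmt.Fscanf(&buf, "%08x\n%v\n%v\n", &gotSum, &gotIndex, &gotTerm); err != nil {
		panic(err)
	}
	gotState, _ := ioutil.ReadAll(&buf)
	fmt.Println(gotSum == crc32.ChecksumIEEE(gotState), gotIndex, gotTerm, gotState)
}
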
// remove the snapshot file
func (ss *Snapshot) Remove() error {
err := os.Remove(ss.path)
return err
}

snapshot_request.go (new file, 41 lines)

@ -0,0 +1,41 @@
package raft
// The request sent to a server to start from the snapshot.
type SnapshotRequest struct {
LeaderName string `json:"leaderName"`
LastIndex uint64 `json:"lastIndex"`
LastTerm uint64 `json:"lastTerm"`
State []byte `json:"state"`
}
// The response returned from a server after installing a snapshot.
type SnapshotResponse struct {
Term uint64 `json:"term"`
Success bool `json:"success"`
CommitIndex uint64 `json:"commitIndex"`
}
//------------------------------------------------------------------------------
//
// Constructors
//
//------------------------------------------------------------------------------
// Creates a new Snapshot request.
func NewSnapshotRequest(leaderName string, snapshot *Snapshot) *SnapshotRequest {
return &SnapshotRequest{
LeaderName: leaderName,
LastIndex: snapshot.lastIndex,
LastTerm: snapshot.lastTerm,
State: snapshot.state,
}
}
// Creates a new Snapshot response.
func NewSnapshotResponse(term uint64, success bool, commitIndex uint64) *SnapshotResponse {
return &SnapshotResponse{
Term: term,
Success: success,
CommitIndex: commitIndex,
}
}
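
For reference, a standalone sketch of how such a request serializes over a JSON transport, using a local copy of the struct purely for illustration; encoding/json base64-encodes the []byte state field.

package main

import (
	"encoding/json"
	"fmt"
)

// Local copy of the request shape, for illustration only.
type snapshotRequest struct {
	LeaderName string `json:"leaderName"`
	LastIndex  uint64 `json:"lastIndex"`
	LastTerm   uint64 `json:"lastTerm"`
	State      []byte `json:"state"`
}

func main() {
	b, err := json.Marshal(snapshotRequest{
		LeaderName: "1",
		LastIndex:  4,
		LastTerm:   1,
		State:      []byte{0x08},
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b))
	// {"leaderName":"1","lastIndex":4,"lastTerm":1,"state":"CA=="}
}
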

snapshot_test.go (new file, 212 lines)

@ -0,0 +1,212 @@
package raft
import (
"sync"
"testing"
"time"
"bytes"
)
// test taking a snapshot and sending it to a newly joined node
func TestTakeAndSendSnapshot(t *testing.T) {
// Initialize the servers.
var mutex sync.Mutex
//fmt.Println("---Snapshot Test---")
names := []string{"1", "2", "3"}
servers := map[string]*Server{}
for _, server := range servers {
defer server.Stop()
}
transporter := &testTransporter{}
transporter.sendVoteRequestFunc = func(server *Server, peer *Peer, req *RequestVoteRequest) (*RequestVoteResponse, error) {
mutex.Lock()
s := servers[peer.name]
mutex.Unlock()
resp, err := s.RequestVote(req)
return resp, err
}
transporter.sendAppendEntriesRequestFunc = func(server *Server, peer *Peer, req *AppendEntriesRequest) (*AppendEntriesResponse, error) {
mutex.Lock()
s := servers[peer.name]
mutex.Unlock()
resp, err := s.AppendEntries(req)
return resp, err
}
transporter.sendSnapshotRequestFunc = func(server *Server, peer *Peer, req *SnapshotRequest) (*SnapshotResponse, error) {
mutex.Lock()
s := servers[peer.name]
mutex.Unlock()
resp, err := s.SnapshotRecovery(req)
return resp, err
}
stateMachine := &testStateMachine{}
stateMachine.saveFunc = func() ([]byte,error) {
return []byte{0x8},nil
}
stateMachine.recoveryFunc = func(state []byte) error {
return nil
}
var leader *Server
for _, name := range names {
server := newTestServer(name, transporter)
server.stateMachine = stateMachine
server.SetElectionTimeout(testElectionTimeout)
server.SetHeartbeatTimeout(testHeartbeatTimeout)
if err := server.Start(); err != nil {
t.Fatalf("Unable to start server[%s]: %v", name, err)
}
if name == "1" {
leader = server
if err := server.Initialize(); err != nil {
t.Fatalf("Unable to initialize server[%s]: %v", name, err)
}
}
if err := leader.Do(&joinCommand{Name:name}); err != nil {
t.Fatalf("Unable to join server[%s]: %v", name, err)
}
mutex.Lock()
servers[name] = server
mutex.Unlock()
}
time.Sleep(100 * time.Millisecond)
// Check that two peers exist on leader.
mutex.Lock()
if leader.MemberCount() != 3 {
t.Fatalf("Expected member count to be 3, got %v", leader.MemberCount())
}
mutex.Unlock()
// commit single entry.
err := leader.Do(&TestCommand1{"foo", 10})
if err != nil {
t.Fatal(err)
}
index, term := leader.log.CommitInfo()
// three join commands and one test command
if !(index == 4 && term == 1) {
t.Fatalf("Invalid commit info [IDX=%v, TERM=%v]", index, term)
}
leader.takeSnapshot()
logLen := len(leader.log.entries)
if logLen != 0 {
t.Fatalf("Invalid logLen [Len=%v]", logLen)
}
if leader.log.startIndex != 4 || leader.log.startTerm != 1 {
t.Fatalf("Invalid log info [StartIndex=%v, StartTERM=%v]",
leader.log.startIndex, leader.log.startTerm)
}
// test sending a snapshot to a new node
// (the snapshot is sent from the heartbeat path)
newServer := newTestServer("4", transporter)
newServer.stateMachine = stateMachine
if err := newServer.Start(); err != nil {
t.Fatalf("Unable to start server[4]: %v", err)
}
if err := leader.Do(&joinCommand{Name:"4"}); err != nil {
t.Fatalf("Unable to join server[4]: %v", err)
}
mutex.Lock()
servers["4"] = newServer
mutex.Unlock()
// wait for heartbeat :P
time.Sleep(100 * time.Millisecond)
if leader.log.startIndex != 4 || leader.log.startTerm != 1 {
t.Fatalf("Invalid log info [StartIndex=%v, StartTERM=%v]",
leader.log.startIndex, leader.log.startTerm)
}
time.Sleep(100 * time.Millisecond)
}
func TestStartFormSnapshot(t *testing.T) {
server := newTestServer("1", &testTransporter{})
stateMachine := &testStateMachine{}
stateMachine.saveFunc = func() ([]byte,error) {
return []byte{0x60,0x61,0x62,0x63,0x64,0x65},nil
}
stateMachine.recoveryFunc = func(state []byte) error {
expect := []byte{0x60,0x61,0x62,0x63,0x64,0x65}
if !(bytes.Equal(state, expect)) {
t.Fatalf("Invalid State [Expcet=%v, Actual=%v]", expect, state)
}
return nil
}
server.stateMachine = stateMachine
oldPath := server.path
server.Start()
server.Initialize()
// commit single entry.
err := server.Do(&TestCommand1{"foo", 10})
if err != nil {
t.Fatal(err)
}
server.takeSnapshot()
logLen := len(server.log.entries)
if logLen != 0 {
t.Fatalf("Invalid logLen [Len=%v]", logLen)
}
if server.log.startIndex != 1 || server.log.startTerm != 1 {
t.Fatalf("Invalid log info [StartIndex=%v, StartTERM=%v]",
server.log.startIndex, server.log.startTerm)
}
server.Stop()
server = newTestServer("1", &testTransporter{})
server.stateMachine = stateMachine
// reset the oldPath
server.path = oldPath
server.Start()
logLen = len(server.log.entries)
if logLen != 0 {
t.Fatalf("Invalid logLen [Len=%v]", logLen)
}
if index, term := server.log.CommitInfo(); !(index == 0 && term == 0) {
t.Fatalf("Invalid commit info [IDX=%v, TERM=%v]", index, term)
}
if server.log.startIndex != 0 || server.log.startTerm != 0 {
t.Fatalf("Invalid log info [StartIndex=%v, StartTERM=%v]",
server.log.startIndex, server.log.startTerm)
}
server.LoadSnapshot()
if server.log.startIndex != 1 || server.log.startTerm != 1 {
t.Fatalf("Invalid log info [StartIndex=%v, StartTERM=%v]",
server.log.startIndex, server.log.startTerm)
}
}

statemachine.go (new file, 14 lines)

@ -0,0 +1,14 @@
package raft
//------------------------------------------------------------------------------
//
// Typedefs
//
//------------------------------------------------------------------------------
// StateMachine is the interface for allowing the host application to save and
// recover the state machine
type StateMachine interface {
Save() ([]byte, error)
Recovery([]byte) error
}
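
A hypothetical StateMachine implementation, a tiny JSON-backed key/value map that is not part of this change, showing what the two methods are expected to do: Save serializes the application state to bytes and Recovery restores it from those bytes.

package main

import (
	"encoding/json"
	"fmt"
)

type kvStateMachine struct {
	data map[string]string
}

// Save serializes the map so the server can write it into a snapshot.
func (sm *kvStateMachine) Save() ([]byte, error) {
	return json.Marshal(sm.data)
}

// Recovery restores the map from snapshot bytes received from a leader.
func (sm *kvStateMachine) Recovery(state []byte) error {
	return json.Unmarshal(state, &sm.data)
}

func main() {
	sm := &kvStateMachine{data: map[string]string{"foo": "10"}}
	snap, _ := sm.Save()

	restored := &kvStateMachine{}
	if err := restored.Recovery(snap); err != nil {
		panic(err)
	}
	fmt.Println(restored.data["foo"]) // 10
}
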

test.go (21 changed lines)

@ -4,7 +4,7 @@ import (
"fmt"
"io/ioutil"
"os"
"time"
"time"
)
const (
@ -98,6 +98,7 @@ func newTestCluster(names []string, transporter Transporter, lookup map[string]*
type testTransporter struct {
sendVoteRequestFunc func(server *Server, peer *Peer, req *RequestVoteRequest) (*RequestVoteResponse, error)
sendAppendEntriesRequestFunc func(server *Server, peer *Peer, req *AppendEntriesRequest) (*AppendEntriesResponse, error)
sendSnapshotRequestFunc func(server *Server, peer *Peer, req *SnapshotRequest) (*SnapshotResponse, error)
}
func (t *testTransporter) SendVoteRequest(server *Server, peer *Peer, req *RequestVoteRequest) (*RequestVoteResponse, error) {
@ -108,6 +109,24 @@ func (t *testTransporter) SendAppendEntriesRequest(server *Server, peer *Peer, r
return t.sendAppendEntriesRequestFunc(server, peer, req)
}
func (t *testTransporter) SendSnapshotRequest(server *Server, peer *Peer, req *SnapshotRequest) (*SnapshotResponse, error) {
return t.sendSnapshotRequestFunc(server, peer, req)
}
type testStateMachine struct {
saveFunc func() ([]byte, error)
recoveryFunc func([]byte) error
}
func (sm *testStateMachine) Save() ([]byte, error) {
return sm.saveFunc()
}
func (sm *testStateMachine) Recovery(state []byte) error {
return sm.recoveryFunc(state)
}
//--------------------------------------
// Join Command

transporter.go

@ -11,4 +11,5 @@ package raft
type Transporter interface {
SendVoteRequest(server *Server, peer *Peer, req *RequestVoteRequest) (*RequestVoteResponse, error)
SendAppendEntriesRequest(server *Server, peer *Peer, req *AppendEntriesRequest) (*AppendEntriesResponse, error)
SendSnapshotRequest(server *Server, peer *Peer, req *SnapshotRequest) (*SnapshotResponse, error)
}