diff --git a/command.go b/command.go index 3578b471ae..b1f73d140f 100644 --- a/command.go +++ b/command.go @@ -9,4 +9,6 @@ package raft // A command represents an action to be taken on the replicated state machine. type Command interface { CommandName() string + Validate(server *Server) error + Apply(server *Server) } diff --git a/join_command.go b/join_command.go index 6c751bdd25..68d3d9a7bf 100644 --- a/join_command.go +++ b/join_command.go @@ -45,10 +45,9 @@ func (c *JoinCommand) Validate(server *Server) error { } // Updates the state machine to join the server to the cluster. -func (c *JoinCommand) Apply(server *Server) error { +func (c *JoinCommand) Apply(server *Server) { if server.name != c.Name { peer := &Peer{name: c.Name} server.peers[peer.name] = peer } - return nil } diff --git a/log.go b/log.go index 8d847ab3c8..ec9418ebb2 100644 --- a/log.go +++ b/log.go @@ -18,6 +18,7 @@ import ( // A log is a collection of log entries that are persisted to durable storage. type Log struct { + ApplyFunc func(Command) file *os.File entries []*LogEntry commitIndex uint64 @@ -204,28 +205,57 @@ func (l *Log) CreateEntry(term uint64, command Command) *LogEntry { return NewLogEntry(l, l.NextIndex(), term, command) } -// Updates the commit index and writes entries after that index to the stable -// storage. +//-------------------------------------- +// Commit +//-------------------------------------- + +// Retrieves the last index and term that has been committed to the log. +func (l *Log) CommitInfo() (index uint64, term uint64) { + l.mutex.Lock() + defer l.mutex.Unlock() + + // If we don't have any entries then just return zeros. + if l.commitIndex == 0 { + return 0, 0 + } + + // Return the last index & term from the last committed entry. + lastCommitEntry := l.entries[l.commitIndex-1] + return lastCommitEntry.index, lastCommitEntry.term +} + +// Updates the commit index and writes entries after that index to the stable storage. func (l *Log) SetCommitIndex(index uint64) error { l.mutex.Lock() defer l.mutex.Unlock() + // Panic if we don't have any way to apply commands. + if l.ApplyFunc == nil { + panic("raft.Log: Apply function not set") + } + // Do not allow previous indices to be committed again. if index < l.commitIndex { return fmt.Errorf("raft.Log: Commit index (%d) ahead of requested commit index (%d)", l.commitIndex, index) } + if index > uint64(len(l.entries)) { + return fmt.Errorf("raft.Log: Commit index (%d) out of range (%d)", index, len(l.entries)) + } // Find all entries whose index is between the previous index and the current index. - for _, entry := range l.entries { - if entry.index > l.commitIndex && entry.index <= index { - // Write to storage. - if err := entry.Encode(l.file); err != nil { - return err - } + for i := l.commitIndex + 1; i <= index; i++ { + entry := l.entries[i-1] - // Update commit index. - l.commitIndex = entry.index + // Write to storage. + if err := entry.Encode(l.file); err != nil { + return err } + + // Apply the changes to the state machine. + l.ApplyFunc(entry.command) + + // Update commit index. + l.commitIndex = entry.index } return nil diff --git a/log_test.go b/log_test.go index bff6ca20f9..0f737555b6 100644 --- a/log_test.go +++ b/log_test.go @@ -17,6 +17,7 @@ import ( func TestLogNewLog(t *testing.T) { path := getLogPath() log := NewLog() + log.ApplyFunc = func(c Command) {} log.AddCommandType(&TestCommand1{}) log.AddCommandType(&TestCommand2{}) if err := log.Open(path); err != nil { @@ -46,6 +47,9 @@ func TestLogNewLog(t *testing.T) { if string(actual) != expected { t.Fatalf("Unexpected buffer:\nexp:\n%s\ngot:\n%s", expected, string(actual)) } + if index, term := log.CommitInfo(); index != 2 || term != 1 { + t.Fatalf("Invalid commit info [IDX=%v, TERM=%v]", index, term) + } // Full commit. if err := log.SetCommitIndex(3); err != nil { @@ -59,6 +63,9 @@ func TestLogNewLog(t *testing.T) { if string(actual) != expected { t.Fatalf("Unexpected buffer:\nexp:\n%s\ngot:\n%s", expected, string(actual)) } + if index, term := log.CommitInfo(); index != 3 || term != 2 { + t.Fatalf("Invalid commit info [IDX=%v, TERM=%v]", index, term) + } } // Ensure that we can decode and encode to an existing log. @@ -100,6 +107,7 @@ func TestLogRecovery(t *testing.T) { `4c08d91f 0000000000000002 0000000000000001 cmd_2 {"x":100}` + "\n" + `6ac5807c 0000000000000003 00000000000`) log := NewLog() + log.ApplyFunc = func(c Command) {} log.AddCommandType(&TestCommand1{}) log.AddCommandType(&TestCommand2{}) if err := log.Open(path); err != nil { diff --git a/server.go b/server.go index c877611af7..5720867da1 100644 --- a/server.go +++ b/server.go @@ -43,7 +43,7 @@ type Server struct { leader *Peer peers map[string]*Peer mutex sync.Mutex - ElectionTimeout int + electionTimer *ElectionTimer DoHandler func(*Server, *Peer, Command) error AppendEntriesHandler func(*Server, *AppendEntriesRequest) (*AppendEntriesResponse, error) } @@ -315,7 +315,7 @@ func (s *Server) RequestVote(req *RequestVoteRequest) *RequestVoteResponse { if req.Term > s.currentTerm { s.currentTerm = req.Term s.votedFor = "" - s.resign() + s.state = Follower } // If we've already voted for a different candidate then don't vote for this candidate. @@ -325,7 +325,8 @@ func (s *Server) RequestVote(req *RequestVoteRequest) *RequestVoteResponse { // If the candidate's log is not at least as up-to-date as our committed log then don't vote. /* - if s.log.CommitIndex() > req.LastLogIndex || s.log.CommitTerm() > req.LastLogTerm { + lastCommitIndex, lastCommitTerm := s.log.LastCommitInfo() + if lastCommitIndex > req.LastLogIndex || lastCommitTerm > req.LastLogTerm { return NewRequestVoteResponse(s.currentTerm, false) } @@ -336,11 +337,6 @@ func (s *Server) RequestVote(req *RequestVoteRequest) *RequestVoteResponse { return NewRequestVoteResponse(s.currentTerm, true) } -// Resign the server to a follower if the server is a candidate or leader. -func (s *Server) resign() { - s.state = Follower -} - //-------------------------------------- // Membership //-------------------------------------- diff --git a/test.go b/test.go index bb88374171..8fb5007f5c 100644 --- a/test.go +++ b/test.go @@ -54,6 +54,13 @@ func (c TestCommand1) CommandName() string { return "cmd_1" } +func (c TestCommand1) Validate(server *Server) error { + return nil +} + +func (c TestCommand1) Apply(server *Server) { +} + //-------------------------------------- // Command2 //-------------------------------------- @@ -65,3 +72,10 @@ type TestCommand2 struct { func (c TestCommand2) CommandName() string { return "cmd_2" } + +func (c TestCommand2) Validate(server *Server) error { + return nil +} + +func (c TestCommand2) Apply(server *Server) { +}