diff --git a/append_entries_request_test.go b/append_entries_request_test.go index 97107022ed..ef6732fc46 100644 --- a/append_entries_request_test.go +++ b/append_entries_request_test.go @@ -27,7 +27,7 @@ func BenchmarkAppendEntriesRequestDecoding(b *testing.B) { func createTestAppendEntriesRequest(entryCount int) (*AppendEntriesRequest, []byte) { entries := make([]*LogEntry, 0) for i := 0; i < entryCount; i++ { - command := &JoinCommand{Name: "localhost:1000"} + command := &DefaultJoinCommand{Name: "localhost:1000"} entry, _ := newLogEntry(nil, 1, 2, command) entries = append(entries, entry) } diff --git a/http_transporter_test.go b/http_transporter_test.go index d869509a8b..1334a91060 100644 --- a/http_transporter_test.go +++ b/http_transporter_test.go @@ -68,7 +68,7 @@ func runTestHttpServers(t *testing.T, servers *[]*Server, transporter *HTTPTrans // Setup configuration. for _, server := range *servers { - if _, err := (*servers)[0].Do(&JoinCommand{Name: server.Name()}); err != nil { + if _, err := (*servers)[0].Do(&DefaultJoinCommand{Name: server.Name()}); err != nil { t.Fatalf("Server %s unable to join: %v", server.Name(), err) } } diff --git a/join_command.go b/join_command.go index d960f01d8a..74e14239db 100644 --- a/join_command.go +++ b/join_command.go @@ -1,17 +1,28 @@ package raft +// Join command interface +type JoinCommand interface { + CommandName() string + Apply(server *Server) (interface{}, error) + NodeName() string +} + // Join command -type JoinCommand struct { +type DefaultJoinCommand struct { Name string `json:"name"` } // The name of the Join command in the log -func (c *JoinCommand) CommandName() string { +func (c *DefaultJoinCommand) CommandName() string { return "raft:join" } -func (c *JoinCommand) Apply(server *Server) (interface{}, error) { +func (c *DefaultJoinCommand) Apply(server *Server) (interface{}, error) { err := server.AddPeer(c.Name) return []byte("join"), err } + +func (c *DefaultJoinCommand) NodeName() string { + return c.Name +} diff --git a/server.go b/server.go index c3a13e6757..ca0e49b49a 100644 --- a/server.go +++ b/server.go @@ -313,7 +313,7 @@ func (s *Server) SetHeartbeatTimeout(duration time.Duration) { // Reg the NOPCommand func init() { RegisterCommand(&NOPCommand{}) - RegisterCommand(&JoinCommand{}) + RegisterCommand(&DefaultJoinCommand{}) RegisterCommand(&LeaveCommand{}) } @@ -441,7 +441,7 @@ func (s *Server) setCurrentTerm(term uint64, leaderName string, append bool) { return } - // discover new leader when candidate + // discover new leader when candidate // save leader name when follower if term == s.currentTerm && s.state != Leader && append { s.state = Follower @@ -524,19 +524,13 @@ func (s *Server) followerLoop() { case e := <-s.c: if e.target == &stopValue { s.setState(Stopped) - } else if command, ok := e.target.(Command); ok { - - if command, ok := command.(*JoinCommand); ok { - - //If no log entries exist and a self-join command is issued - //then immediately become leader and commit entry. - if s.log.currentIndex() == 0 && command.Name == s.Name() { - s.debugln("selfjoin and promote to leader") - s.setState(Leader) - s.processCommand(command, e) - } else { - err = NotLeaderError - } + } else if command, ok := e.target.(JoinCommand); ok { + //If no log entries exist and a self-join command is issued + //then immediately become leader and commit entry. + if s.log.currentIndex() == 0 && command.NodeName() == s.Name() { + s.debugln("selfjoin and promote to leader") + s.setState(Leader) + s.processCommand(command, e) } else { err = NotLeaderError } @@ -546,6 +540,8 @@ func (s *Server) followerLoop() { e.returnValue, update = s.processRequestVoteRequest(req) } else if req, ok := e.target.(*SnapshotRequest); ok { e.returnValue = s.processSnapshotRequest(req) + } else { + err = NotLeaderError } // Callback to event. diff --git a/server_test.go b/server_test.go index 14bc36f468..829b5a3030 100644 --- a/server_test.go +++ b/server_test.go @@ -24,7 +24,7 @@ func TestServerRequestVote(t *testing.T) { server := newTestServer("1", &testTransporter{}) server.Start() - if _, err := server.Do(&JoinCommand{Name: server.Name()}); err != nil { + if _, err := server.Do(&DefaultJoinCommand{Name: server.Name()}); err != nil { t.Fatalf("Server %s unable to join: %v", server.Name(), err) } @@ -40,7 +40,7 @@ func TestServerRequestVoteDeniedForStaleTerm(t *testing.T) { server := newTestServer("1", &testTransporter{}) server.Start() - if _, err := server.Do(&JoinCommand{Name: server.Name()}); err != nil { + if _, err := server.Do(&DefaultJoinCommand{Name: server.Name()}); err != nil { t.Fatalf("Server %s unable to join: %v", server.Name(), err) } @@ -60,7 +60,7 @@ func TestServerRequestVoteDeniedIfAlreadyVoted(t *testing.T) { server := newTestServer("1", &testTransporter{}) server.Start() - if _, err := server.Do(&JoinCommand{Name: server.Name()}); err != nil { + if _, err := server.Do(&DefaultJoinCommand{Name: server.Name()}); err != nil { t.Fatalf("Server %s unable to join: %v", server.Name(), err) } @@ -81,7 +81,7 @@ func TestServerRequestVoteApprovedIfAlreadyVotedInOlderTerm(t *testing.T) { server := newTestServer("1", &testTransporter{}) server.Start() - if _, err := server.Do(&JoinCommand{Name: server.Name()}); err != nil { + if _, err := server.Do(&DefaultJoinCommand{Name: server.Name()}); err != nil { t.Fatalf("Server %s unable to join: %v", server.Name(), err) } @@ -331,7 +331,7 @@ func TestServerSingleNode(t *testing.T) { time.Sleep(50 * time.Millisecond) // Join the server to itself. - if _, err := server.Do(&JoinCommand{Name: "1"}); err != nil { + if _, err := server.Do(&DefaultJoinCommand{Name: "1"}); err != nil { t.Fatalf("Unable to join: %v", err) } debugln("finish command") @@ -403,7 +403,7 @@ func TestServerMultiNode(t *testing.T) { server.Start() time.Sleep(10 * time.Millisecond) } - if _, err := leader.Do(&JoinCommand{Name: name}); err != nil { + if _, err := leader.Do(&DefaultJoinCommand{Name: name}); err != nil { t.Fatalf("Unable to join server[%s]: %v", name, err) }