check checksum before load snapshot
parent
b0eaf972e6
commit
4f24fb775f
33
server.go
33
server.go
|
@ -4,10 +4,12 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
@ -1048,9 +1050,9 @@ func (s *Server) LoadSnapshot() error {
|
|||
// TODO check checksum first
|
||||
|
||||
var snapshotBytes []byte
|
||||
var checksum []byte
|
||||
var checksumByte []byte
|
||||
|
||||
n, err := fmt.Fscanf(file, "%08x\n", &checksum)
|
||||
n, err := fmt.Fscanf(file, "%v\n", &checksumByte)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -1060,17 +1062,36 @@ func (s *Server) LoadSnapshot() error {
|
|||
return errors.New("Bad snapshot file")
|
||||
}
|
||||
|
||||
snapshotBytes, _ = ioutil.ReadAll(file)
|
||||
debugln(string(snapshotBytes))
|
||||
|
||||
err = json.Unmarshal(snapshotBytes, &s.lastSnapshot)
|
||||
checksum, err := strconv.ParseUint(string(checksumByte), 10, 32)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
snapshotBytes, _ = ioutil.ReadAll(file)
|
||||
debugln(string(snapshotBytes))
|
||||
|
||||
// Generate checksum.
|
||||
byteChecksum := crc32.ChecksumIEEE(snapshotBytes)
|
||||
|
||||
if uint32(checksum) != byteChecksum {
|
||||
fmt.Println(checksum, " ", byteChecksum)
|
||||
return errors.New("bad snapshot file")
|
||||
}
|
||||
|
||||
err = json.Unmarshal(snapshotBytes, &s.lastSnapshot)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("unmarshal error")
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.stateMachine.Recovery(s.lastSnapshot.State)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("recovery error")
|
||||
}
|
||||
|
||||
for _, peerName := range s.lastSnapshot.Peers {
|
||||
s.AddPeer(peerName)
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ func (ss *Snapshot) Save() error {
|
|||
checksum := crc32.ChecksumIEEE(b)
|
||||
|
||||
// Write snapshot with checksum.
|
||||
if _, err = fmt.Fprintf(file, "%08x\n", checksum); err != nil {
|
||||
if _, err = fmt.Fprintf(file, "%v\n", checksum); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue