check checksum before load snapshot

pull/820/head
Xiang Li 2013-06-30 17:55:54 -07:00
parent b0eaf972e6
commit 4f24fb775f
2 changed files with 28 additions and 7 deletions

View File

@ -4,10 +4,12 @@ import (
"encoding/json"
"errors"
"fmt"
"hash/crc32"
"io/ioutil"
"os"
"path"
"sort"
"strconv"
"sync"
"time"
)
@ -1048,9 +1050,9 @@ func (s *Server) LoadSnapshot() error {
// TODO check checksum first
var snapshotBytes []byte
var checksum []byte
var checksumByte []byte
n, err := fmt.Fscanf(file, "%08x\n", &checksum)
n, err := fmt.Fscanf(file, "%v\n", &checksumByte)
if err != nil {
return err
@ -1060,17 +1062,36 @@ func (s *Server) LoadSnapshot() error {
return errors.New("Bad snapshot file")
}
snapshotBytes, _ = ioutil.ReadAll(file)
debugln(string(snapshotBytes))
err = json.Unmarshal(snapshotBytes, &s.lastSnapshot)
checksum, err := strconv.ParseUint(string(checksumByte), 10, 32)
if err != nil {
return err
}
snapshotBytes, _ = ioutil.ReadAll(file)
debugln(string(snapshotBytes))
// Generate checksum.
byteChecksum := crc32.ChecksumIEEE(snapshotBytes)
if uint32(checksum) != byteChecksum {
fmt.Println(checksum, " ", byteChecksum)
return errors.New("bad snapshot file")
}
err = json.Unmarshal(snapshotBytes, &s.lastSnapshot)
if err != nil {
fmt.Println("unmarshal error")
return err
}
err = s.stateMachine.Recovery(s.lastSnapshot.State)
if err != nil {
fmt.Println("recovery error")
}
for _, peerName := range s.lastSnapshot.Peers {
s.AddPeer(peerName)
}

View File

@ -45,7 +45,7 @@ func (ss *Snapshot) Save() error {
checksum := crc32.ChecksumIEEE(b)
// Write snapshot with checksum.
if _, err = fmt.Fprintf(file, "%08x\n", checksum); err != nil {
if _, err = fmt.Fprintf(file, "%v\n", checksum); err != nil {
return err
}