Prevent truncated WAL entries from panicing

pull/6483/head
Jason Wilder 2016-04-26 15:57:55 -06:00
parent 0de21ade40
commit 23bbfb2192
2 changed files with 136 additions and 2 deletions

View File

@ -56,7 +56,10 @@ const (
DeleteRangeWALEntryType WalEntryType = 0x03 DeleteRangeWALEntryType WalEntryType = 0x03
) )
var ErrWALClosed = fmt.Errorf("WAL closed") var (
ErrWALClosed = fmt.Errorf("WAL closed")
ErrWALCorrupt = fmt.Errorf("corrupted WAL entry")
)
// Statistics gathered by the WAL. // Statistics gathered by the WAL.
const ( const (
@ -578,11 +581,24 @@ func (w *WriteWALEntry) UnmarshalBinary(b []byte) error {
typ := b[i] typ := b[i]
i++ i++
if i+2 > len(b) {
return ErrWALCorrupt
}
length := int(binary.BigEndian.Uint16(b[i : i+2])) length := int(binary.BigEndian.Uint16(b[i : i+2]))
i += 2 i += 2
if i+length > len(b) {
return ErrWALCorrupt
}
k := string(b[i : i+length]) k := string(b[i : i+length])
i += length i += length
if i+4 > len(b) {
return ErrWALCorrupt
}
nvals := int(binary.BigEndian.Uint32(b[i : i+4])) nvals := int(binary.BigEndian.Uint32(b[i : i+4]))
i += 4 i += 4
@ -610,11 +626,19 @@ func (w *WriteWALEntry) UnmarshalBinary(b []byte) error {
} }
for j := 0; j < nvals; j++ { for j := 0; j < nvals; j++ {
if i+8 > len(b) {
return ErrWALCorrupt
}
un := int64(binary.BigEndian.Uint64(b[i : i+8])) un := int64(binary.BigEndian.Uint64(b[i : i+8]))
i += 8 i += 8
switch typ { switch typ {
case float64EntryType: case float64EntryType:
if i+8 > len(b) {
return ErrWALCorrupt
}
v := math.Float64frombits((binary.BigEndian.Uint64(b[i : i+8]))) v := math.Float64frombits((binary.BigEndian.Uint64(b[i : i+8])))
i += 8 i += 8
if fv, ok := values[j].(*FloatValue); ok { if fv, ok := values[j].(*FloatValue); ok {
@ -622,6 +646,10 @@ func (w *WriteWALEntry) UnmarshalBinary(b []byte) error {
fv.value = v fv.value = v
} }
case integerEntryType: case integerEntryType:
if i+8 > len(b) {
return ErrWALCorrupt
}
v := int64(binary.BigEndian.Uint64(b[i : i+8])) v := int64(binary.BigEndian.Uint64(b[i : i+8]))
i += 8 i += 8
if fv, ok := values[j].(*IntegerValue); ok { if fv, ok := values[j].(*IntegerValue); ok {
@ -629,6 +657,10 @@ func (w *WriteWALEntry) UnmarshalBinary(b []byte) error {
fv.value = v fv.value = v
} }
case booleanEntryType: case booleanEntryType:
if i >= len(b) {
return ErrWALCorrupt
}
v := b[i] v := b[i]
i += 1 i += 1
if fv, ok := values[j].(*BooleanValue); ok { if fv, ok := values[j].(*BooleanValue); ok {
@ -640,12 +672,21 @@ func (w *WriteWALEntry) UnmarshalBinary(b []byte) error {
} }
} }
case stringEntryType: case stringEntryType:
if i+4 > len(b) {
return ErrWALCorrupt
}
length := int(binary.BigEndian.Uint32(b[i : i+4])) length := int(binary.BigEndian.Uint32(b[i : i+4]))
if i+length > int(uint32(len(b))) { if i+length > int(uint32(len(b))) {
return fmt.Errorf("corrupted write wall entry") return ErrWALCorrupt
} }
i += 4 i += 4
if i+length > len(b) {
return ErrWALCorrupt
}
v := string(b[i : i+length]) v := string(b[i : i+length])
i += length i += length
if fv, ok := values[j].(*StringValue); ok { if fv, ok := values[j].(*StringValue); ok {
@ -713,13 +754,24 @@ func (w *DeleteRangeWALEntry) MarshalBinary() ([]byte, error) {
} }
func (w *DeleteRangeWALEntry) UnmarshalBinary(b []byte) error { func (w *DeleteRangeWALEntry) UnmarshalBinary(b []byte) error {
if len(b) < 16 {
return ErrWALCorrupt
}
w.Min = int64(binary.BigEndian.Uint64(b[:8])) w.Min = int64(binary.BigEndian.Uint64(b[:8]))
w.Max = int64(binary.BigEndian.Uint64(b[8:16])) w.Max = int64(binary.BigEndian.Uint64(b[8:16]))
i := 16 i := 16
for i < len(b) { for i < len(b) {
if i+4 > len(b) {
return ErrWALCorrupt
}
sz := int(binary.BigEndian.Uint32(b[i : i+4])) sz := int(binary.BigEndian.Uint32(b[i : i+4]))
i += 4 i += 4
if i+sz > len(b) {
return ErrWALCorrupt
}
w.Keys = append(w.Keys, string(b[i:i+sz])) w.Keys = append(w.Keys, string(b[i:i+sz]))
i += sz i += sz
} }

View File

@ -5,6 +5,7 @@ import (
"os" "os"
"testing" "testing"
"github.com/davecgh/go-spew/spew"
"github.com/influxdata/influxdb/tsdb/engine/tsm1" "github.com/influxdata/influxdb/tsdb/engine/tsm1"
"github.com/golang/snappy" "github.com/golang/snappy"
@ -566,6 +567,87 @@ func TestWALWriter_Corrupt(t *testing.T) {
} }
} }
func TestWriteWALSegment_UnmarshalBinary_WriteWALCorrupt(t *testing.T) {
p1 := tsm1.NewValue(1, 1.1)
p2 := tsm1.NewValue(1, int64(1))
p3 := tsm1.NewValue(1, true)
p4 := tsm1.NewValue(1, "string")
values := map[string][]tsm1.Value{
"cpu,host=A#!~#float": []tsm1.Value{p1, p1},
"cpu,host=A#!~#int": []tsm1.Value{p2, p2},
"cpu,host=A#!~#bool": []tsm1.Value{p3, p3},
"cpu,host=A#!~#string": []tsm1.Value{p4, p4},
}
w := &tsm1.WriteWALEntry{
Values: values,
}
b, err := w.MarshalBinary()
if err != nil {
t.Fatalf("unexpected error, got %v", err)
}
// Test every possible truncation of a write WAL entry
for i := 0; i < len(b); i++ {
// re-allocated to ensure capacity would be exceed if slicing
truncated := make([]byte, i)
copy(truncated, b[:i])
err := w.UnmarshalBinary(truncated)
if err != nil && err != tsm1.ErrWALCorrupt {
t.Fatalf("unexpected error: %v", err)
}
}
}
func TestWriteWALSegment_UnmarshalBinary_DeleteWALCorrupt(t *testing.T) {
w := &tsm1.DeleteWALEntry{
Keys: []string{"foo", "bar"},
}
b, err := w.MarshalBinary()
if err != nil {
t.Fatalf("unexpected error, got %v", err)
}
// Test every possible truncation of a write WAL entry
for i := 0; i < len(b); i++ {
// re-allocated to ensure capacity would be exceed if slicing
truncated := make([]byte, i)
copy(truncated, b[:i])
err := w.UnmarshalBinary(truncated)
if err != nil && err != tsm1.ErrWALCorrupt {
t.Fatalf("unexpected error: %v", err)
}
}
}
func TestWriteWALSegment_UnmarshalBinary_DeleteRangeWALCorrupt(t *testing.T) {
w := &tsm1.DeleteRangeWALEntry{
Keys: []string{"foo", "bar"},
Min: 1,
Max: 2,
}
b, err := w.MarshalBinary()
if err != nil {
t.Fatalf("unexpected error, got %v", err)
}
// Test every possible truncation of a write WAL entry
for i := 0; i < len(b); i++ {
// re-allocated to ensure capacity would be exceed if slicing
truncated := make([]byte, i)
copy(truncated, b[:i])
spew.Dump(truncated)
err := w.UnmarshalBinary(truncated)
if err != nil && err != tsm1.ErrWALCorrupt {
t.Fatalf("unexpected error: %v", err)
}
}
}
func BenchmarkWALSegmentWriter(b *testing.B) { func BenchmarkWALSegmentWriter(b *testing.B) {
points := map[string][]tsm1.Value{} points := map[string][]tsm1.Value{}
for i := 0; i < 5000; i++ { for i := 0; i < 5000; i++ {