diff --git a/services/snapshotter/service.go b/services/snapshotter/service.go index ce3b34ca51..54a243c5b4 100644 --- a/services/snapshotter/service.go +++ b/services/snapshotter/service.go @@ -374,40 +374,55 @@ func (s *Service) writeRetentionPolicyInfo(conn net.Conn, database, retentionPol return nil } -// readRequest unmarshals a request object from the conn. -func (s *Service) readRequest(conn net.Conn) (Request, []byte, error) { - var r Request - d := json.NewDecoder(conn) - - if err := d.Decode(&r); err != nil { - return r, nil, err +// readRequest unmarshals a request object and payload from an io.Reader. +// +// we check if UploadSize is less than or equal to zero because it is a signed +// int64. this prevents any kind of buffer overflow that might result from +// having a negative value. +// +// then we read the remainder of the payload up to r.UploadSize+1. +// +// we use r.UploadSize+1 because the end of the json message always contains a +// newline which r.UploadSize doesn't account for and we always discard. +// +// FIXME: there is a bug lurking here. the code assumes there is a newline +// after the JSON message. this might not always be the case though so far, it +// seems to be. +// +// we should probably check if the rest of the payload starts with a newline +// and re-slice the returned buffer accordingly. +// +// this would greatly complicate the logic after io.CopyN() below though. +// +// the payload is the result of copying payloadSize bytes of the remainder of +// the json buffer and the conn. +// +// we return that buffer sans the newline at the beginning. +// +func (s *Service) readRequest(r io.Reader) (*Request, []byte, error) { + var req Request + d := json.NewDecoder(r) + if err := d.Decode(&req); err != nil { + return nil, nil, err } - bits := make([]byte, r.UploadSize+1) - - if r.UploadSize > 0 { - - remainder := d.Buffered() - - n, err := remainder.Read(bits) - if err != nil && err != io.EOF { - return r, bits, err - } - - // it is a bit random but sometimes the Json decoder will consume all the bytes and sometimes - // it will leave a few behind. - if err != io.EOF && n < int(r.UploadSize+1) { - _, err = conn.Read(bits[n:]) - } - - if err != nil && err != io.EOF { - return r, bits, err - } - // the JSON encoder on the client side seems to write an extra byte, so trim that off the front. - return r, bits[1:], nil + if req.UploadSize <= 0 { + return &req, nil, nil } - return r, bits, nil + payloadSize := req.UploadSize + 1 + + buf := &bytes.Buffer{} + buf.Grow(int(payloadSize)) + + switch nbytes, err := io.CopyN(buf, io.MultiReader(d.Buffered(), r), payloadSize); { + case err == io.EOF: + return nil, nil, fmt.Errorf("read %d of expected %d bytes of request payload", nbytes, payloadSize) + case err != nil: + return nil, nil, err + } + + return &req, buf.Bytes()[1:], nil } // RequestType indicates the typeof snapshot request.