chore: kit/io: improve LimitedReadCloser

A fairly minor change, but this saves two allocations every time
points are written to the API (one allocation for the embedded io.LimitReader,
and one allocation to create the `close` closure).

Also fix the code so that it actually limits to the exact requested number of bytes
rather than one more. We don't really need to layer on top of io.LimitReader,
as that code is fairly minimal.
pull/19654/head
Roger Peppe 2020-09-28 18:28:51 +01:00
parent bc4bae3738
commit 239331c1ae
2 changed files with 86 additions and 25 deletions

View File

@ -11,38 +11,49 @@ var ErrReadLimitExceeded = errors.New("read limit exceeded")
// io.LimitedReader. It allows us to obtain the limit error at the time of close
// instead of just when writing.
type LimitedReadCloser struct {
*io.LimitedReader
err error
close func() error
R io.ReadCloser // underlying reader
N int64 // max bytes remaining
err error
closed bool
limitExceeded bool
}
// NewLimitedReadCloser returns a new LimitedReadCloser.
func NewLimitedReadCloser(r io.ReadCloser, n int64) *LimitedReadCloser {
// read up to max + 1 as limited reader just returns EOF when the limit is reached
// or when there is nothing left to read. If we exceed the max batch size by one
// then we know the limit has been passed.
return &LimitedReadCloser{
LimitedReader: &io.LimitedReader{R: r, N: n + 1},
close: r.Close,
R: r,
N: n,
}
}
func (l *LimitedReadCloser) Read(p []byte) (n int, err error) {
if l.N <= 0 {
l.limitExceeded = true
return 0, io.EOF
}
if int64(len(p)) > l.N {
p = p[0:l.N]
}
n, err = l.R.Read(p)
l.N -= int64(n)
return
}
// Close returns an ErrReadLimitExceeded when the wrapped reader exceeds the set
// limit for number of bytes. This is safe to call more than once but not
// concurrently.
func (l *LimitedReadCloser) Close() (err error) {
defer func() {
if cerr := l.close(); cerr != nil && err == nil {
err = cerr
}
// only call close once
l.close = func() error { return nil }
}()
if l.N < 1 {
if l.limitExceeded {
l.err = ErrReadLimitExceeded
}
if l.closed {
// Close has already been called.
return l.err
}
if err := l.R.Close(); err != nil && l.err == nil {
l.err = err
}
// Prevent l.closer.Close from being called again.
l.closed = true
return l.err
}

View File

@ -5,14 +5,15 @@ import (
"io"
"io/ioutil"
"testing"
"errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLimitedReadCloser_Exceeded(t *testing.T) {
b := closer{bytes.NewBufferString("howdy")}
rc := NewLimitedReadCloser(b, 2)
b := &closer{Reader: bytes.NewBufferString("howdy")}
rc := NewLimitedReadCloser(b, 3)
out, err := ioutil.ReadAll(rc)
require.NoError(t, err)
@ -21,7 +22,7 @@ func TestLimitedReadCloser_Exceeded(t *testing.T) {
}
func TestLimitedReadCloser_Happy(t *testing.T) {
b := closer{bytes.NewBufferString("ho")}
b := &closer{Reader: bytes.NewBufferString("ho")}
rc := NewLimitedReadCloser(b, 2)
out, err := ioutil.ReadAll(rc)
@ -30,8 +31,57 @@ func TestLimitedReadCloser_Happy(t *testing.T) {
assert.Nil(t, err)
}
type closer struct {
io.Reader
func TestLimitedReadCloseWithErrorAndLimitExceeded(t *testing.T) {
b := &closer{
Reader: bytes.NewBufferString("howdy"),
err: errors.New("some error"),
}
rc := NewLimitedReadCloser(b, 3)
out, err := ioutil.ReadAll(rc)
require.NoError(t, err)
assert.Equal(t, []byte("how"), out)
// LimitExceeded error trumps the close error.
assert.Equal(t, ErrReadLimitExceeded, rc.Close())
}
func (c closer) Close() error { return nil }
func TestLimitedReadCloseWithError(t *testing.T) {
closeErr := errors.New("some error")
b := &closer{
Reader: bytes.NewBufferString("howdy"),
err: closeErr,
}
rc := NewLimitedReadCloser(b, 10)
out, err := ioutil.ReadAll(rc)
require.NoError(t, err)
assert.Equal(t, []byte("howdy"), out)
assert.Equal(t, closeErr, rc.Close())
}
func TestMultipleCloseOnlyClosesOnce(t *testing.T) {
closeErr := errors.New("some error")
b := &closer{
Reader: bytes.NewBufferString("howdy"),
err: closeErr,
}
rc := NewLimitedReadCloser(b, 10)
out, err := ioutil.ReadAll(rc)
require.NoError(t, err)
assert.Equal(t, []byte("howdy"), out)
assert.Equal(t, closeErr, rc.Close())
assert.Equal(t, closeErr, rc.Close())
assert.Equal(t, 1, b.closeCount)
}
type closer struct {
io.Reader
err error
closeCount int
}
func (c *closer) Close() error {
c.closeCount++
return c.err
}