chore: kit/io: improve LimitedReadCloser
A fairly minor change, but this saves two allocations every time points are written to the API (one allocation for the embedded io.LimitReader, and one allocation to create the `close` closure). Also fix the code so that it actually limits to the exact requested number of bytes rather than one more. We don't really need to layer on top of io.LimitReader, as that code is fairly minimal.pull/19654/head
parent
bc4bae3738
commit
239331c1ae
|
@ -11,38 +11,49 @@ var ErrReadLimitExceeded = errors.New("read limit exceeded")
|
|||
// io.LimitedReader. It allows us to obtain the limit error at the time of close
|
||||
// instead of just when writing.
|
||||
type LimitedReadCloser struct {
|
||||
*io.LimitedReader
|
||||
err error
|
||||
close func() error
|
||||
R io.ReadCloser // underlying reader
|
||||
N int64 // max bytes remaining
|
||||
err error
|
||||
closed bool
|
||||
limitExceeded bool
|
||||
}
|
||||
|
||||
// NewLimitedReadCloser returns a new LimitedReadCloser.
|
||||
func NewLimitedReadCloser(r io.ReadCloser, n int64) *LimitedReadCloser {
|
||||
// read up to max + 1 as limited reader just returns EOF when the limit is reached
|
||||
// or when there is nothing left to read. If we exceed the max batch size by one
|
||||
// then we know the limit has been passed.
|
||||
return &LimitedReadCloser{
|
||||
LimitedReader: &io.LimitedReader{R: r, N: n + 1},
|
||||
close: r.Close,
|
||||
R: r,
|
||||
N: n,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *LimitedReadCloser) Read(p []byte) (n int, err error) {
|
||||
if l.N <= 0 {
|
||||
l.limitExceeded = true
|
||||
return 0, io.EOF
|
||||
}
|
||||
if int64(len(p)) > l.N {
|
||||
p = p[0:l.N]
|
||||
}
|
||||
n, err = l.R.Read(p)
|
||||
l.N -= int64(n)
|
||||
return
|
||||
}
|
||||
|
||||
// Close returns an ErrReadLimitExceeded when the wrapped reader exceeds the set
|
||||
// limit for number of bytes. This is safe to call more than once but not
|
||||
// concurrently.
|
||||
func (l *LimitedReadCloser) Close() (err error) {
|
||||
defer func() {
|
||||
if cerr := l.close(); cerr != nil && err == nil {
|
||||
err = cerr
|
||||
}
|
||||
|
||||
// only call close once
|
||||
l.close = func() error { return nil }
|
||||
}()
|
||||
|
||||
if l.N < 1 {
|
||||
if l.limitExceeded {
|
||||
l.err = ErrReadLimitExceeded
|
||||
}
|
||||
|
||||
if l.closed {
|
||||
// Close has already been called.
|
||||
return l.err
|
||||
}
|
||||
if err := l.R.Close(); err != nil && l.err == nil {
|
||||
l.err = err
|
||||
}
|
||||
// Prevent l.closer.Close from being called again.
|
||||
l.closed = true
|
||||
return l.err
|
||||
}
|
||||
|
|
|
@ -5,14 +5,15 @@ import (
|
|||
"io"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
"errors"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLimitedReadCloser_Exceeded(t *testing.T) {
|
||||
b := closer{bytes.NewBufferString("howdy")}
|
||||
rc := NewLimitedReadCloser(b, 2)
|
||||
b := &closer{Reader: bytes.NewBufferString("howdy")}
|
||||
rc := NewLimitedReadCloser(b, 3)
|
||||
|
||||
out, err := ioutil.ReadAll(rc)
|
||||
require.NoError(t, err)
|
||||
|
@ -21,7 +22,7 @@ func TestLimitedReadCloser_Exceeded(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLimitedReadCloser_Happy(t *testing.T) {
|
||||
b := closer{bytes.NewBufferString("ho")}
|
||||
b := &closer{Reader: bytes.NewBufferString("ho")}
|
||||
rc := NewLimitedReadCloser(b, 2)
|
||||
|
||||
out, err := ioutil.ReadAll(rc)
|
||||
|
@ -30,8 +31,57 @@ func TestLimitedReadCloser_Happy(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
type closer struct {
|
||||
io.Reader
|
||||
func TestLimitedReadCloseWithErrorAndLimitExceeded(t *testing.T) {
|
||||
b := &closer{
|
||||
Reader: bytes.NewBufferString("howdy"),
|
||||
err: errors.New("some error"),
|
||||
}
|
||||
rc := NewLimitedReadCloser(b, 3)
|
||||
|
||||
out, err := ioutil.ReadAll(rc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("how"), out)
|
||||
// LimitExceeded error trumps the close error.
|
||||
assert.Equal(t, ErrReadLimitExceeded, rc.Close())
|
||||
}
|
||||
|
||||
func (c closer) Close() error { return nil }
|
||||
func TestLimitedReadCloseWithError(t *testing.T) {
|
||||
closeErr := errors.New("some error")
|
||||
b := &closer{
|
||||
Reader: bytes.NewBufferString("howdy"),
|
||||
err: closeErr,
|
||||
}
|
||||
rc := NewLimitedReadCloser(b, 10)
|
||||
|
||||
out, err := ioutil.ReadAll(rc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("howdy"), out)
|
||||
assert.Equal(t, closeErr, rc.Close())
|
||||
}
|
||||
|
||||
func TestMultipleCloseOnlyClosesOnce(t *testing.T) {
|
||||
closeErr := errors.New("some error")
|
||||
b := &closer{
|
||||
Reader: bytes.NewBufferString("howdy"),
|
||||
err: closeErr,
|
||||
}
|
||||
rc := NewLimitedReadCloser(b, 10)
|
||||
|
||||
out, err := ioutil.ReadAll(rc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("howdy"), out)
|
||||
assert.Equal(t, closeErr, rc.Close())
|
||||
assert.Equal(t, closeErr, rc.Close())
|
||||
assert.Equal(t, 1, b.closeCount)
|
||||
}
|
||||
|
||||
type closer struct {
|
||||
io.Reader
|
||||
err error
|
||||
closeCount int
|
||||
}
|
||||
|
||||
func (c *closer) Close() error {
|
||||
c.closeCount++
|
||||
return c.err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue