parent
f9b8ae32a5
commit
7f1d2be486
|
@ -3,6 +3,7 @@ package limiter_test
|
|||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -13,7 +14,7 @@ func TestWriter_Limited(t *testing.T) {
|
|||
r := bytes.NewReader(bytes.Repeat([]byte{0}, 1024*1024))
|
||||
|
||||
limit := 512 * 1024
|
||||
w := limiter.NewWriter(discardCloser{}, limit, 10*1024*1024)
|
||||
w := limiter.NewWriter(nopWriteCloser{ioutil.Discard}, limit, 10*1024*1024)
|
||||
|
||||
start := time.Now()
|
||||
n, err := io.Copy(w, r)
|
||||
|
@ -28,7 +29,26 @@ func TestWriter_Limited(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type discardCloser struct{}
|
||||
func TestWriter_Limiter_ExceedBurst(t *testing.T) {
|
||||
limit := 10
|
||||
burstLimit := 20
|
||||
|
||||
func (d discardCloser) Write(b []byte) (int, error) { return len(b), nil }
|
||||
func (d discardCloser) Close() error { return nil }
|
||||
twentyOneBytes := make([]byte, 21)
|
||||
|
||||
b := nopWriteCloser{bytes.NewBuffer(nil)}
|
||||
|
||||
w := limiter.NewWriter(b, limit, burstLimit)
|
||||
n, err := w.Write(twentyOneBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != len(twentyOneBytes) {
|
||||
t.Errorf("exected %d bytes written, but got %d", len(twentyOneBytes), n)
|
||||
}
|
||||
}
|
||||
|
||||
type nopWriteCloser struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (d nopWriteCloser) Close() error { return nil }
|
||||
|
|
|
@ -17,6 +17,7 @@ type Writer struct {
|
|||
|
||||
type Rate interface {
|
||||
WaitN(ctx context.Context, n int) error
|
||||
Burst() int
|
||||
}
|
||||
|
||||
func NewRate(bytesPerSec, burstLimit int) Rate {
|
||||
|
@ -53,15 +54,25 @@ func (s *Writer) Write(b []byte) (int, error) {
|
|||
return s.w.Write(b)
|
||||
}
|
||||
|
||||
n, err := s.w.Write(b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
var n int
|
||||
for n < len(b) {
|
||||
wantToWriteN := len(b[n:])
|
||||
if wantToWriteN > s.limiter.Burst() {
|
||||
wantToWriteN = s.limiter.Burst()
|
||||
}
|
||||
|
||||
wroteN, err := s.w.Write(b[n : n+wantToWriteN])
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
n += wroteN
|
||||
|
||||
if err := s.limiter.WaitN(s.ctx, wroteN); err != nil {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.limiter.WaitN(s.ctx, n); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n, err
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (s *Writer) Sync() error {
|
||||
|
|
Loading…
Reference in New Issue