516 lines
12 KiB
Go
516 lines
12 KiB
Go
package write
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
platform "github.com/influxdata/influxdb/v2"
|
|
"github.com/influxdata/influxdb/v2/mock"
|
|
"github.com/influxdata/influxdb/v2/pkg/testing/assert"
|
|
)
|
|
|
|
func TestScanLines(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
want []string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "3 lines produced including their newlines",
|
|
input: "m1,t1=v1 f1=1\nm2,t2=v2 f2=2\nm3,t3=v3 f3=3",
|
|
want: []string{"m1,t1=v1 f1=1\n", "m2,t2=v2 f2=2\n", "m3,t3=v3 f3=3"},
|
|
},
|
|
{
|
|
name: "single line without newline",
|
|
input: "m1,t1=v1 f1=1",
|
|
want: []string{"m1,t1=v1 f1=1"},
|
|
},
|
|
{
|
|
name: "single line with newline",
|
|
input: "m1,t1=v1 f1=1\n",
|
|
want: []string{"m1,t1=v1 f1=1\n"},
|
|
},
|
|
{
|
|
name: "no lines",
|
|
input: "",
|
|
want: []string{},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
scanner := bufio.NewScanner(strings.NewReader(tt.input))
|
|
scanner.Split(ScanLines)
|
|
got := []string{}
|
|
for scanner.Scan() {
|
|
got = append(got, scanner.Text())
|
|
}
|
|
err := scanner.Err()
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("ScanLines() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
|
|
if !cmp.Equal(got, tt.want) {
|
|
t.Errorf("%q. ScanLines() = -got/+want %s", tt.name, cmp.Diff(got, tt.want))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// errorReader is an io.Reader stub whose Read always fails.
type errorReader struct{}

// Read ignores the supplied buffer and reports zero bytes read together with
// a non-nil error whose message is "error".
func (*errorReader) Read(_ []byte) (int, error) {
	failure := fmt.Errorf("error")
	return 0, failure
}
|
|
|
|
func TestBatcher_read(t *testing.T) {
|
|
type args struct {
|
|
cancel bool
|
|
r io.Reader
|
|
max int
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want []string
|
|
expErr error
|
|
}{
|
|
{
|
|
name: "reading two lines produces 2 lines",
|
|
args: args{
|
|
r: strings.NewReader("m1,t1=v1 f1=1\nm2,t2=v2 f2=2"),
|
|
},
|
|
want: []string{"m1,t1=v1 f1=1\n", "m2,t2=v2 f2=2"},
|
|
},
|
|
{
|
|
name: "canceling returns no lines",
|
|
args: args{
|
|
cancel: true,
|
|
r: strings.NewReader("m1,t1=v1 f1=1"),
|
|
},
|
|
want: nil,
|
|
expErr: context.Canceled,
|
|
},
|
|
{
|
|
name: "error from reader returns error",
|
|
args: args{
|
|
r: &errorReader{},
|
|
},
|
|
want: nil,
|
|
expErr: fmt.Errorf("error"),
|
|
},
|
|
{
|
|
name: "error when input exceeds max line length",
|
|
args: args{
|
|
r: strings.NewReader("m1,t1=v1 f1=1"),
|
|
max: 5,
|
|
},
|
|
want: nil,
|
|
expErr: ErrLineTooLong,
|
|
},
|
|
{
|
|
name: "lines greater than MaxScanTokenSize are allowed",
|
|
args: args{
|
|
r: strings.NewReader(strings.Repeat("a", bufio.MaxScanTokenSize+1)),
|
|
max: bufio.MaxScanTokenSize + 2,
|
|
},
|
|
want: []string{strings.Repeat("a", bufio.MaxScanTokenSize+1)},
|
|
},
|
|
{
|
|
name: "lines greater than MaxScanTokenSize by default are not allowed",
|
|
args: args{
|
|
r: strings.NewReader(strings.Repeat("a", bufio.MaxScanTokenSize+1)),
|
|
},
|
|
want: nil,
|
|
expErr: ErrLineTooLong,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := context.Background()
|
|
var cancel context.CancelFunc
|
|
if tt.args.cancel {
|
|
ctx, cancel = context.WithCancel(ctx)
|
|
cancel()
|
|
}
|
|
|
|
b := &Batcher{MaxLineLength: tt.args.max}
|
|
var got []string
|
|
|
|
lines := make(chan []byte)
|
|
errC := make(chan error, 1)
|
|
|
|
go b.read(ctx, tt.args.r, lines, errC)
|
|
|
|
if cancel == nil {
|
|
for line := range lines {
|
|
got = append(got, string(line))
|
|
}
|
|
}
|
|
|
|
err := <-errC
|
|
assert.Equal(t, err, tt.expErr)
|
|
assert.Equal(t, got, tt.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBatcher_write(t *testing.T) {
|
|
type fields struct {
|
|
MaxFlushBytes int
|
|
MaxFlushInterval time.Duration
|
|
}
|
|
type args struct {
|
|
cancel bool
|
|
writeError bool
|
|
org platform.ID
|
|
bucket platform.ID
|
|
line string
|
|
lines chan []byte
|
|
errC chan error
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
want string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "sending a single line will send a line to the service",
|
|
fields: fields{
|
|
MaxFlushBytes: 1,
|
|
},
|
|
args: args{
|
|
org: platform.ID(1),
|
|
bucket: platform.ID(2),
|
|
line: "m1,t1=v1 f1=1",
|
|
lines: make(chan []byte),
|
|
errC: make(chan error),
|
|
},
|
|
want: "m1,t1=v1 f1=1",
|
|
},
|
|
{
|
|
name: "sending a single line and waiting for a timeout will send a line to the service",
|
|
fields: fields{
|
|
MaxFlushInterval: time.Millisecond,
|
|
},
|
|
args: args{
|
|
org: platform.ID(1),
|
|
bucket: platform.ID(2),
|
|
line: "m1,t1=v1 f1=1",
|
|
lines: make(chan []byte),
|
|
errC: make(chan error),
|
|
},
|
|
want: "m1,t1=v1 f1=1",
|
|
},
|
|
{
|
|
name: "write service returning error stops the write after timeout",
|
|
fields: fields{
|
|
MaxFlushInterval: time.Millisecond,
|
|
},
|
|
args: args{
|
|
writeError: true,
|
|
org: platform.ID(1),
|
|
bucket: platform.ID(2),
|
|
line: "m1,t1=v1 f1=1",
|
|
lines: make(chan []byte),
|
|
errC: make(chan error),
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "canceling will write no data to service",
|
|
fields: fields{
|
|
MaxFlushBytes: 1,
|
|
},
|
|
args: args{
|
|
cancel: true,
|
|
org: platform.ID(1),
|
|
bucket: platform.ID(2),
|
|
line: "m1,t1=v1 f1=1",
|
|
lines: make(chan []byte),
|
|
errC: make(chan error),
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "write service returning error stops the write",
|
|
fields: fields{
|
|
MaxFlushBytes: 1,
|
|
},
|
|
args: args{
|
|
writeError: true,
|
|
org: platform.ID(1),
|
|
bucket: platform.ID(2),
|
|
line: "m1,t1=v1 f1=1",
|
|
lines: make(chan []byte),
|
|
errC: make(chan error),
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := context.Background()
|
|
var cancel context.CancelFunc
|
|
if tt.args.cancel {
|
|
ctx, cancel = context.WithCancel(ctx)
|
|
}
|
|
|
|
// mocking the write service here to either return an error
|
|
// or get back all the bytes from the reader.
|
|
var got string
|
|
svc := &mock.WriteService{
|
|
WriteF: func(ctx context.Context, org, bucket platform.ID, r io.Reader) error {
|
|
if tt.args.writeError {
|
|
return fmt.Errorf("error")
|
|
}
|
|
b, err := ioutil.ReadAll(r)
|
|
got = string(b)
|
|
return err
|
|
},
|
|
WriteToF: func(ctx context.Context, filter platform.BucketFilter, r io.Reader) error {
|
|
if tt.args.writeError {
|
|
return fmt.Errorf("error")
|
|
}
|
|
b, err := ioutil.ReadAll(r)
|
|
got = string(b)
|
|
return err
|
|
},
|
|
}
|
|
|
|
b := &Batcher{
|
|
MaxFlushBytes: tt.fields.MaxFlushBytes,
|
|
MaxFlushInterval: tt.fields.MaxFlushInterval,
|
|
Service: svc,
|
|
}
|
|
writeFn := func(batch []byte) error {
|
|
return svc.Write(ctx, tt.args.org, tt.args.bucket, bytes.NewReader(batch))
|
|
}
|
|
|
|
go b.write(ctx, writeFn, tt.args.lines, tt.args.errC)
|
|
|
|
if cancel != nil {
|
|
cancel()
|
|
} else {
|
|
tt.args.lines <- []byte(tt.args.line)
|
|
// if the max flush interval is not zero, we are testing to see
|
|
// if the data is flushed via the timer rather than forced by
|
|
// closing the channel.
|
|
if tt.fields.MaxFlushInterval != 0 {
|
|
time.Sleep(tt.fields.MaxFlushInterval * 100)
|
|
}
|
|
close(tt.args.lines)
|
|
}
|
|
|
|
err := <-tt.args.errC
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("ScanLines.read() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
|
|
if !cmp.Equal(got, tt.want) {
|
|
t.Errorf("%q. Batcher.write() = -got/+want %s", tt.name, cmp.Diff(got, tt.want))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBatcher_Write(t *testing.T) {
|
|
createReader := func(data string) func() io.Reader {
|
|
if data == "error" {
|
|
return func() io.Reader {
|
|
return &errorReader{}
|
|
}
|
|
}
|
|
return func() io.Reader {
|
|
return strings.NewReader(data)
|
|
}
|
|
}
|
|
|
|
type fields struct {
|
|
MaxFlushBytes int
|
|
MaxFlushInterval time.Duration
|
|
}
|
|
type args struct {
|
|
writeError bool
|
|
org platform.ID
|
|
bucket platform.ID
|
|
r func() io.Reader
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
want string
|
|
wantFlushes int
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "a line of line protocol is sent to the service",
|
|
fields: fields{
|
|
MaxFlushBytes: 1,
|
|
},
|
|
args: args{
|
|
org: platform.ID(1),
|
|
bucket: platform.ID(2),
|
|
r: createReader("m1,t1=v1 f1=1"),
|
|
},
|
|
want: "m1,t1=v1 f1=1",
|
|
wantFlushes: 1,
|
|
},
|
|
{
|
|
name: "multiple lines cause multiple flushes",
|
|
fields: fields{
|
|
MaxFlushBytes: len([]byte("m1,t1=v1 f1=1\n")),
|
|
},
|
|
args: args{
|
|
org: platform.ID(1),
|
|
bucket: platform.ID(2),
|
|
r: createReader("m1,t1=v1 f1=1\nm2,t2=v2 f2=2\nm3,t3=v3 f3=3"),
|
|
},
|
|
want: "m3,t3=v3 f3=3",
|
|
wantFlushes: 3,
|
|
},
|
|
{
|
|
name: "errors during read return error",
|
|
fields: fields{},
|
|
args: args{
|
|
org: platform.ID(1),
|
|
bucket: platform.ID(2),
|
|
r: createReader("error"),
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// mocking the write service here to either return an error
|
|
// or get back all the bytes from the reader.
|
|
var (
|
|
got string
|
|
gotFlushes int
|
|
)
|
|
svc := &mock.WriteService{
|
|
WriteF: func(ctx context.Context, org, bucket platform.ID, r io.Reader) error {
|
|
if tt.args.writeError {
|
|
return fmt.Errorf("error")
|
|
}
|
|
b, err := ioutil.ReadAll(r)
|
|
got = string(b)
|
|
gotFlushes++
|
|
return err
|
|
},
|
|
}
|
|
|
|
b := &Batcher{
|
|
MaxFlushBytes: tt.fields.MaxFlushBytes,
|
|
MaxFlushInterval: tt.fields.MaxFlushInterval,
|
|
Service: svc,
|
|
}
|
|
|
|
ctx := context.Background()
|
|
if err := b.Write(ctx, tt.args.org, tt.args.bucket, tt.args.r()); (err != nil) != tt.wantErr {
|
|
t.Errorf("Batcher.Write() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
|
|
if gotFlushes != tt.wantFlushes {
|
|
t.Errorf("%q. Batcher.Write() flushes %d want %d", tt.name, gotFlushes, tt.wantFlushes)
|
|
}
|
|
if !cmp.Equal(got, tt.want) {
|
|
t.Errorf("%q. Batcher.Write() = -got/+want %s", tt.name, cmp.Diff(got, tt.want))
|
|
}
|
|
})
|
|
// test the same data, but now with WriteTo function
|
|
t.Run("WriteTo_"+tt.name, func(t *testing.T) {
|
|
// mocking the write service here to either return an error
|
|
// or get back all the bytes from the reader.
|
|
var (
|
|
got string
|
|
gotFlushes int
|
|
)
|
|
svc := &mock.WriteService{
|
|
WriteToF: func(ctx context.Context, filter platform.BucketFilter, r io.Reader) error {
|
|
if tt.args.writeError {
|
|
return fmt.Errorf("error")
|
|
}
|
|
b, err := ioutil.ReadAll(r)
|
|
got = string(b)
|
|
gotFlushes++
|
|
return err
|
|
},
|
|
}
|
|
|
|
b := &Batcher{
|
|
MaxFlushBytes: tt.fields.MaxFlushBytes,
|
|
MaxFlushInterval: tt.fields.MaxFlushInterval,
|
|
Service: svc,
|
|
}
|
|
|
|
ctx := context.Background()
|
|
bucketFilter := platform.BucketFilter{ID: &tt.args.bucket, OrganizationID: &tt.args.org}
|
|
if err := b.WriteTo(ctx, bucketFilter, tt.args.r()); (err != nil) != tt.wantErr {
|
|
t.Errorf("Batcher.Write() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
|
|
if gotFlushes != tt.wantFlushes {
|
|
t.Errorf("%q. Batcher.Write() flushes %d want %d", tt.name, gotFlushes, tt.wantFlushes)
|
|
}
|
|
if !cmp.Equal(got, tt.want) {
|
|
t.Errorf("%q. Batcher.Write() = -got/+want %s", tt.name, cmp.Diff(got, tt.want))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBatcher_WriteTimeout(t *testing.T) {
|
|
// mocking the write service here to either return an error
|
|
// or get back all the bytes from the reader.
|
|
var got string
|
|
svc := &mock.WriteService{
|
|
WriteF: func(ctx context.Context, org, bucket platform.ID, r io.Reader) error {
|
|
b, err := ioutil.ReadAll(r)
|
|
got = string(b)
|
|
return err
|
|
},
|
|
}
|
|
|
|
b := &Batcher{
|
|
Service: svc,
|
|
}
|
|
|
|
// this mimics a reader like stdin that may never return data.
|
|
r, _ := io.Pipe()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
|
defer cancel()
|
|
|
|
if err := b.Write(ctx, platform.ID(1), platform.ID(2), r); err != context.DeadlineExceeded {
|
|
t.Errorf("Batcher.Write() with timeout error = %v", err)
|
|
}
|
|
|
|
if got != "" {
|
|
t.Errorf(" Batcher.Write() with timeout got %s", got)
|
|
}
|
|
}
|
|
|
|
func TestBatcher_WriteWithoutService(t *testing.T) {
|
|
b := Batcher{}
|
|
err := b.Write(context.Background(), platform.ID(1), platform.ID(1), strings.NewReader("m1,t1=v1 f1=1"))
|
|
if err == nil || !strings.Contains(err.Error(), "write service required") {
|
|
t.Errorf(" Batcher.Write() error expected, but got %v", err)
|
|
}
|
|
}
|