influxdb/telemetry/push_test.go

151 lines
3.5 KiB
Go

package telemetry
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/matttproud/golang_protobuf_extensions/pbutil"
"github.com/prometheus/client_golang/prometheus"
dto "github.com/prometheus/client_model/go"
"github.com/prometheus/common/expfmt"
"google.golang.org/protobuf/proto"
)
// TestPusher_Push drives Pusher.Push against an httptest server using a table
// of gatherers, server status codes and context timeouts. For each case it
// records the HTTP method and body the server received (if any), compares
// them with the expected check, and asserts whether Push reported an error.
func TestPusher_Push(t *testing.T) {
	// check captures what the test server observed for a single push.
	type check struct {
		Method string
		Body   []byte
	}

	// release unblocks gatherers that simulate a hung Gather call. Closing it
	// when the test returns lets those goroutines exit promptly instead of
	// sleeping long after Push has stopped waiting on them.
	release := make(chan struct{})
	defer close(release)

	tests := []struct {
		name    string
		gather  prometheus.Gatherer
		timeout time.Duration
		status  int
		want    check
		wantErr bool
	}{
		{
			name: "no metrics no push",
			gather: prometheus.GathererFunc(func() ([]*dto.MetricFamily, error) {
				return nil, nil
			}),
		},
		{
			name: "timeout while gathering data returns error",
			gather: prometheus.GathererFunc(func() ([]*dto.MetricFamily, error) {
				// Block past the context deadline; Push is expected to give up
				// and report an error rather than wait for this to return.
				<-release
				return nil, nil
			}),
			timeout: time.Millisecond,
			wantErr: true,
		},
		{
			name: "timeout server timeout data returns error",
			gather: prometheus.GathererFunc(func() ([]*dto.MetricFamily, error) {
				mf := &dto.MetricFamily{}
				return []*dto.MetricFamily{mf}, nil
			}),
			timeout: time.Millisecond,
			wantErr: true,
		},
		{
			name: "error gathering metrics returns error",
			gather: prometheus.GathererFunc(func() ([]*dto.MetricFamily, error) {
				return nil, fmt.Errorf("e1")
			}),
			wantErr: true,
		},
		{
			name: "status code that is not Accepted (202) is an error",
			gather: prometheus.GathererFunc(func() ([]*dto.MetricFamily, error) {
				mf := &dto.MetricFamily{}
				return []*dto.MetricFamily{mf}, nil
			}),
			status: http.StatusInternalServerError,
			want: check{
				Method: http.MethodPost,
				// An empty MetricFamily encodes to a zero-length delimited message.
				Body: []byte{0x00},
			},
			wantErr: true,
		},
		{
			name: "sending metric are marshalled into delimited protobufs",
			gather: prometheus.GathererFunc(func() ([]*dto.MetricFamily, error) {
				mf := &dto.MetricFamily{
					Name: proto.String("n1"),
					Help: proto.String("h1"),
				}
				return []*dto.MetricFamily{mf}, nil
			}),
			status: http.StatusAccepted,
			want: check{
				Method: http.MethodPost,
				Body: MustMarshal(&dto.MetricFamily{
					Name: proto.String("n1"),
					Help: proto.String("h1"),
				}),
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := context.Background()
			if tt.timeout > 0 {
				var cancel context.CancelFunc
				ctx, cancel = context.WithTimeout(ctx, tt.timeout)
				defer cancel()
			}

			var got check
			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if tt.timeout > 0 {
					// Simulate a server that never answers: hold the request
					// until the test deadline passes or the client disconnects.
					select {
					case <-ctx.Done():
					case <-r.Context().Done():
					}
					return
				}
				body, err := io.ReadAll(r.Body)
				if err != nil {
					// Errorf (not Fatalf) is safe to call from the handler goroutine.
					t.Errorf("reading request body: %v", err)
				}
				got.Method = r.Method
				got.Body = body
				w.WriteHeader(tt.status)
			}))
			defer srv.Close()

			p := &Pusher{
				URL:        srv.URL,
				Gather:     tt.gather,
				Client:     srv.Client(),
				PushFormat: expfmt.FmtProtoDelim,
			}
			if err := p.Push(ctx); (err != nil) != tt.wantErr {
				t.Errorf("Pusher.Push() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if diff := cmp.Diff(got, tt.want); diff != "" {
				t.Errorf("Pusher.Push() request mismatch (-got +want):\n%s", diff)
			}
		})
	}
}
// MustMarshal encodes mf as one length-delimited protobuf message, using the
// same framing as pbutil.WriteDelimited, and returns the encoded bytes. It
// panics if encoding fails, so it belongs only in tests that build expected
// request bodies.
func MustMarshal(mf *dto.MetricFamily) []byte {
	var out bytes.Buffer
	if _, err := pbutil.WriteDelimited(&out, mf); err != nil {
		panic(err)
	}
	return out.Bytes()
}