influxdb/prometheus/auth_service_test.go

151 lines
5.3 KiB
Go

package prometheus_test
import (
"context"
"errors"
"testing"
platform "github.com/influxdata/influxdb/v2"
"github.com/influxdata/influxdb/v2/kit/prom"
"github.com/influxdata/influxdb/v2/kit/prom/promtest"
"github.com/influxdata/influxdb/v2/prometheus"
"go.uber.org/zap"
)
// authzSvc is a test helper that returns its Err from every method on the AuthorizationService interface.
type authzSvc struct {
Err error
}
var _ platform.AuthorizationService = (*authzSvc)(nil)
func (a *authzSvc) FindAuthorizationByID(context.Context, platform.ID) (*platform.Authorization, error) {
return nil, a.Err
}
func (a *authzSvc) FindAuthorizationByToken(context.Context, string) (*platform.Authorization, error) {
return nil, a.Err
}
func (a *authzSvc) FindAuthorizations(context.Context, platform.AuthorizationFilter, ...platform.FindOptions) ([]*platform.Authorization, int, error) {
return nil, 0, a.Err
}
func (a *authzSvc) CreateAuthorization(context.Context, *platform.Authorization) error {
return a.Err
}
func (a *authzSvc) DeleteAuthorization(context.Context, platform.ID) error {
return a.Err
}
func (a *authzSvc) UpdateAuthorization(context.Context, platform.ID, *platform.AuthorizationUpdate) (*platform.Authorization, error) {
return nil, a.Err
}
func TestAuthorizationService_Metrics(t *testing.T) {
a := new(authzSvc)
svc := prometheus.NewAuthorizationService()
svc.AuthorizationService = a
reg := prom.NewRegistry(zap.NewNop())
reg.MustRegister(svc.PrometheusCollectors()...)
ctx := context.Background()
id := platform.ID(1)
if _, err := svc.FindAuthorizationByID(ctx, id); err != nil {
t.Fatal(err)
}
mfs := promtest.MustGather(t, reg)
m := promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "FindAuthorizationByID", "error": "false"})
if got := m.GetCounter().GetValue(); got != 1 {
t.Fatalf("exp 1 request, got %v", got)
}
if _, err := svc.FindAuthorizationByToken(ctx, ""); err != nil {
t.Fatal(err)
}
mfs = promtest.MustGather(t, reg)
m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "FindAuthorizationByToken", "error": "false"})
if got := m.GetCounter().GetValue(); got != 1 {
t.Fatalf("exp 1 request, got %v", got)
}
if _, _, err := svc.FindAuthorizations(ctx, platform.AuthorizationFilter{}); err != nil {
t.Fatal(err)
}
mfs = promtest.MustGather(t, reg)
m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "FindAuthorizations", "error": "false"})
if got := m.GetCounter().GetValue(); got != 1 {
t.Fatalf("exp 1 request, got %v", got)
}
if err := svc.CreateAuthorization(ctx, nil); err != nil {
t.Fatal(err)
}
mfs = promtest.MustGather(t, reg)
m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "CreateAuthorization", "error": "false"})
if got := m.GetCounter().GetValue(); got != 1 {
t.Fatalf("exp 1 request, got %v", got)
}
var tempID platform.ID
if err := svc.DeleteAuthorization(ctx, tempID); err != nil {
t.Fatal(err)
}
mfs = promtest.MustGather(t, reg)
m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "DeleteAuthorization", "error": "false"})
if got := m.GetCounter().GetValue(); got != 1 {
t.Fatalf("exp 1 request, got %v", got)
}
forced := errors.New("forced error")
a.Err = forced
if _, err := svc.FindAuthorizationByID(ctx, id); err != forced {
t.Fatalf("expected forced error, got %v", err)
}
mfs = promtest.MustGather(t, reg)
m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "FindAuthorizationByID", "error": "true"})
if got := m.GetCounter().GetValue(); got != 1 {
t.Fatalf("exp 1 request, got %v", got)
}
if _, err := svc.FindAuthorizationByToken(ctx, ""); err != forced {
t.Fatalf("expected forced error, got %v", err)
}
mfs = promtest.MustGather(t, reg)
m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "FindAuthorizationByToken", "error": "true"})
if got := m.GetCounter().GetValue(); got != 1 {
t.Fatalf("exp 1 request, got %v", got)
}
if _, _, err := svc.FindAuthorizations(ctx, platform.AuthorizationFilter{}); err != forced {
t.Fatalf("expected forced error, got %v", err)
}
mfs = promtest.MustGather(t, reg)
m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "FindAuthorizations", "error": "true"})
if got := m.GetCounter().GetValue(); got != 1 {
t.Fatalf("exp 1 request, got %v", got)
}
if err := svc.CreateAuthorization(ctx, nil); err != forced {
t.Fatalf("expected forced error, got %v", err)
}
mfs = promtest.MustGather(t, reg)
m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "CreateAuthorization", "error": "true"})
if got := m.GetCounter().GetValue(); got != 1 {
t.Fatalf("exp 1 request, got %v", got)
}
if err := svc.DeleteAuthorization(ctx, tempID); err != forced {
t.Fatalf("expected forced error, got %v", err)
}
mfs = promtest.MustGather(t, reg)
m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "DeleteAuthorization", "error": "true"})
if got := m.GetCounter().GetValue(); got != 1 {
t.Fatalf("exp 1 request, got %v", got)
}
}