196 lines
4.2 KiB
Go
196 lines
4.2 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/bouk/httprouter"
|
|
"github.com/influxdata/influxdb/chronograf"
|
|
"github.com/influxdata/influxdb/chronograf/mocks"
|
|
"github.com/influxdata/influxdb/chronograf/oauth2"
|
|
)
|
|
|
|
func TestRouteMatchesPrincipal(t *testing.T) {
|
|
type fields struct {
|
|
OrganizationsStore chronograf.OrganizationsStore
|
|
Logger chronograf.Logger
|
|
}
|
|
type args struct {
|
|
useAuth bool
|
|
principal *oauth2.Principal
|
|
routerParams *httprouter.Params
|
|
}
|
|
type wants struct {
|
|
matches bool
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
wants wants
|
|
}{
|
|
{
|
|
name: "route matches request params",
|
|
fields: fields{
|
|
Logger: &chronograf.NoopLogger{},
|
|
OrganizationsStore: &mocks.OrganizationsStore{
|
|
DefaultOrganizationF: func(ctx context.Context) (*chronograf.Organization, error) {
|
|
return &chronograf.Organization{
|
|
ID: "default",
|
|
}, nil
|
|
},
|
|
},
|
|
},
|
|
args: args{
|
|
useAuth: true,
|
|
principal: &oauth2.Principal{
|
|
Subject: "user",
|
|
Issuer: "github",
|
|
Organization: "default",
|
|
},
|
|
routerParams: &httprouter.Params{
|
|
{
|
|
Key: "oid",
|
|
Value: "default",
|
|
},
|
|
},
|
|
},
|
|
wants: wants{
|
|
matches: true,
|
|
},
|
|
},
|
|
{
|
|
name: "route does not match request params",
|
|
fields: fields{
|
|
Logger: &chronograf.NoopLogger{},
|
|
OrganizationsStore: &mocks.OrganizationsStore{
|
|
DefaultOrganizationF: func(ctx context.Context) (*chronograf.Organization, error) {
|
|
return &chronograf.Organization{
|
|
ID: "default",
|
|
}, nil
|
|
},
|
|
},
|
|
},
|
|
args: args{
|
|
useAuth: true,
|
|
principal: &oauth2.Principal{
|
|
Subject: "user",
|
|
Issuer: "github",
|
|
Organization: "default",
|
|
},
|
|
routerParams: &httprouter.Params{
|
|
{
|
|
Key: "oid",
|
|
Value: "other",
|
|
},
|
|
},
|
|
},
|
|
wants: wants{
|
|
matches: false,
|
|
},
|
|
},
|
|
{
|
|
name: "missing principal",
|
|
fields: fields{
|
|
Logger: &chronograf.NoopLogger{},
|
|
OrganizationsStore: &mocks.OrganizationsStore{
|
|
DefaultOrganizationF: func(ctx context.Context) (*chronograf.Organization, error) {
|
|
return &chronograf.Organization{
|
|
ID: "default",
|
|
}, nil
|
|
},
|
|
},
|
|
},
|
|
args: args{
|
|
useAuth: true,
|
|
principal: nil,
|
|
routerParams: &httprouter.Params{
|
|
{
|
|
Key: "oid",
|
|
Value: "other",
|
|
},
|
|
},
|
|
},
|
|
wants: wants{
|
|
matches: false,
|
|
},
|
|
},
|
|
{
|
|
name: "not using auth",
|
|
fields: fields{
|
|
Logger: &chronograf.NoopLogger{},
|
|
OrganizationsStore: &mocks.OrganizationsStore{
|
|
DefaultOrganizationF: func(ctx context.Context) (*chronograf.Organization, error) {
|
|
return &chronograf.Organization{
|
|
ID: "default",
|
|
}, nil
|
|
},
|
|
},
|
|
},
|
|
args: args{
|
|
useAuth: false,
|
|
principal: &oauth2.Principal{
|
|
Subject: "user",
|
|
Issuer: "github",
|
|
Organization: "default",
|
|
},
|
|
routerParams: &httprouter.Params{
|
|
{
|
|
Key: "oid",
|
|
Value: "other",
|
|
},
|
|
},
|
|
},
|
|
wants: wants{
|
|
matches: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
store := &mocks.Store{
|
|
OrganizationsStore: tt.fields.OrganizationsStore,
|
|
}
|
|
var matches bool
|
|
next := func(w http.ResponseWriter, r *http.Request) {
|
|
matches = true
|
|
}
|
|
fn := RouteMatchesPrincipal(
|
|
store,
|
|
tt.args.useAuth,
|
|
tt.fields.Logger,
|
|
next,
|
|
)
|
|
|
|
w := httptest.NewRecorder()
|
|
url := "http://any.url"
|
|
r := httptest.NewRequest(
|
|
"GET",
|
|
url,
|
|
nil,
|
|
)
|
|
if tt.args.routerParams != nil {
|
|
r = r.WithContext(httprouter.WithParams(r.Context(), *tt.args.routerParams))
|
|
}
|
|
if tt.args.principal == nil {
|
|
r = r.WithContext(context.WithValue(r.Context(), oauth2.PrincipalKey, nil))
|
|
} else {
|
|
r = r.WithContext(context.WithValue(r.Context(), oauth2.PrincipalKey, *tt.args.principal))
|
|
}
|
|
fn(w, r)
|
|
|
|
if matches != tt.wants.matches {
|
|
t.Errorf("%q. RouteMatchesPrincipal() = %v, expected %v", tt.name, matches, tt.wants.matches)
|
|
}
|
|
|
|
if !matches && w.Code != http.StatusForbidden {
|
|
t.Errorf("%q. RouteMatchesPrincipal() Status Code = %v, expected %v", tt.name, w.Code, http.StatusForbidden)
|
|
}
|
|
|
|
})
|
|
}
|
|
}
|