influxdb/chronograf/oauth2/oauth2_test.go

129 lines
2.8 KiB
Go

package oauth2
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"time"
goauth "golang.org/x/oauth2"
gojwt "github.com/dgrijalva/jwt-go"
"github.com/influxdata/influxdb/chronograf"
)
// Compile-time check that *MockProvider satisfies the Provider interface.
var _ Provider = (*MockProvider)(nil)
// MockProvider is a fake OAuth2 Provider for tests. Its identity methods echo
// back the configured Email and Orgs, and its Config places the auth and token
// endpoints under ProviderURL.
type MockProvider struct {
	Email, Orgs, ProviderURL string
}
// Config returns an OAuth2 client configuration with fixed dummy credentials
// and a redirect URL of http://www.example.com. Its authorization and token
// endpoints are mp.ProviderURL + "/oauth/auth" and mp.ProviderURL + "/oauth/token".
func (mp *MockProvider) Config() *goauth.Config {
	endpoint := goauth.Endpoint{
		AuthURL:  mp.ProviderURL + "/oauth/auth",
		TokenURL: mp.ProviderURL + "/oauth/token",
	}

	cfg := new(goauth.Config)
	cfg.RedirectURL = "http://www.example.com"
	cfg.ClientID = "4815162342"
	cfg.ClientSecret = "8675309"
	cfg.Endpoint = endpoint
	return cfg
}
// ID returns the fixed identifier "8675309".
func (mp *MockProvider) ID() string {
	const id = "8675309"
	return id
}
// Name returns the mock provider's fixed name, "mockly".
func (mp *MockProvider) Name() string {
	name := "mockly"
	return name
}
// PrincipalID ignores the supplied HTTP client and reports mp.Email as the
// principal ID. The error result is always nil.
func (mp *MockProvider) PrincipalID(provider *http.Client) (string, error) {
	var err error
	email := mp.Email
	return email, err
}
// PrincipalIDFromClaims ignores the supplied claims and reports mp.Email as the
// principal ID. The error result is always nil.
func (mp *MockProvider) PrincipalIDFromClaims(claims gojwt.MapClaims) (string, error) {
	id := mp.Email
	return id, nil
}
// GroupFromClaims ignores the supplied claims and derives the group from
// mp.Email. When Email contains exactly one "@", it returns the text after the
// "@"; otherwise it returns "DEFAULT". The error result is always nil.
func (mp *MockProvider) GroupFromClaims(claims gojwt.MapClaims) (string, error) {
	if strings.Count(mp.Email, "@") != 1 {
		// No "@" or more than one: treat the address as malformed and use the
		// default group.
		return "DEFAULT", nil
	}
	domain := mp.Email[strings.IndexByte(mp.Email, '@')+1:]
	return domain, nil
}
// Group ignores the supplied HTTP client and returns the configured Orgs value
// unchanged. The error result is always nil.
func (mp *MockProvider) Group(provider *http.Client) (string, error) {
	orgs := mp.Orgs
	return orgs, nil
}
// Scopes returns an empty, non-nil slice because the mock requests no scopes.
func (mp *MockProvider) Scopes() []string {
	return make([]string, 0)
}
// Secret returns the fixed secret string "4815162342".
func (mp *MockProvider) Secret() string {
	secret := "4815162342"
	return secret
}
// Compile-time check that *YesManTokenizer satisfies the Tokenizer interface.
var _ Tokenizer = (*YesManTokenizer)(nil)

// YesManTokenizer is a stub Tokenizer for tests. It accepts every token and
// always hands back the same fixed values with nil errors.
type YesManTokenizer struct{}
// ValidPrincipal accepts any token. It ignores ctx, token and duration, and
// always returns the same Principal (only Subject and Issuer set) with a nil
// error.
func (y *YesManTokenizer) ValidPrincipal(ctx context.Context, token Token, duration time.Duration) (Principal, error) {
	var p Principal
	p.Subject = "biff@example.com"
	p.Issuer = "Biff Tannen's Pleasure Paradise"
	return p, nil
}
// Create ignores ctx and p and always returns the same fixed Token with a nil
// error.
func (y *YesManTokenizer) Create(ctx context.Context, p Principal) (Token, error) {
	const canned = "HELLO?!MCFLY?!ANYONEINTHERE?!"
	return Token(canned), nil
}
// ExtendedPrincipal ignores ctx and ext and hands p back unchanged with a nil
// error.
func (y *YesManTokenizer) ExtendedPrincipal(ctx context.Context, p Principal, ext time.Duration) (Principal, error) {
	var err error
	return p, err
}
// GetClaims ignores tokenString and returns an empty, non-nil claim set with a
// nil error.
func (y *YesManTokenizer) GetClaims(tokenString string) (gojwt.MapClaims, error) {
	claims := make(gojwt.MapClaims)
	return claims, nil
}
// NewTestTripper builds a TestTripper that logs each request to log and
// forwards it through rt to the scheme and host of the test server ts.
//
// It returns an error only if ts.URL cannot be parsed.
func NewTestTripper(log chronograf.Logger, ts *httptest.Server, rt http.RoundTripper) (*TestTripper, error) {
	// Named tsURL rather than url so the net/url package is not shadowed.
	tsURL, err := url.Parse(ts.URL)
	if err != nil {
		return nil, err
	}
	// Keyed fields keep this literal correct if TestTripper's fields are
	// reordered or extended.
	return &TestTripper{
		Log:   log,
		rt:    rt,
		tsURL: tsURL,
	}, nil
}
// TestTripper is an http.RoundTripper that logs each request it receives and
// then forwards it through rt, aimed at the scheme and host held in tsURL.
type TestTripper struct {
// Log receives one "Request" entry (with component, remote_addr, method and
// url fields) per round trip.
Log chronograf.Logger
// rt is the underlying RoundTripper that actually performs the request.
rt http.RoundTripper
// tsURL supplies the scheme and host requests are redirected to; the
// constructor fills it from the test server's URL.
tsURL *url.URL
}
// RoundTrip logs r and then sends it to the test server. The request is
// forwarded through the wrapped RoundTripper with its URL's scheme and host
// replaced by those of tsURL. Path, query and headers are kept as-is.
//
// The caller's request is left untouched: the http.RoundTripper contract
// forbids modifying the request. The rewrite is therefore applied to a shallow
// copy of r that carries its own copy of the URL.
func (tt *TestTripper) RoundTrip(r *http.Request) (*http.Response, error) {
	tt.Log.
		WithField("component", "test").
		WithField("remote_addr", r.RemoteAddr).
		WithField("method", r.Method).
		WithField("url", r.URL).
		Info("Request")

	redirected := *r.URL
	redirected.Host = tt.tsURL.Host
	redirected.Scheme = tt.tsURL.Scheme

	out := *r
	out.URL = &redirected
	return tt.rt.RoundTrip(&out)
}