chronograf/oauth2/generic.go

267 lines
6.8 KiB
Go

package oauth2
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"strings"
gojwt "github.com/golang-jwt/jwt/v4"
"github.com/influxdata/chronograf"
"golang.org/x/oauth2"
)
// ExtendedProvider extendts the base Provider interface with optional methods
type ExtendedProvider interface {
Provider
// get PrincipalID from id_token
PrincipalIDFromClaims(claims gojwt.MapClaims) (string, error)
GroupFromClaims(claims gojwt.MapClaims) (string, error)
}
var _ ExtendedProvider = &Generic{}
// Generic provides OAuth Login and Callback server and is modeled
// after the Github OAuth2 provider. Callback will set an authentication
// cookie. This cookie's value is a JWT containing the user's primary
// email address.
type Generic struct {
PageName string // Name displayed on the login page
ClientID string
ClientSecret string
RequiredScopes []string
Domains []string // Optional email domain checking
RedirectURL string
AuthURL string
TokenURL string
APIURL string // APIURL returns OpenID Userinfo
APIKey string // APIKey is the JSON key to lookup email address in APIURL response
Logger chronograf.Logger
}
// Name is the name of the provider
func (g *Generic) Name() string {
if g.PageName == "" {
return "generic"
}
return g.PageName
}
// ID returns the generic application client id
func (g *Generic) ID() string {
return g.ClientID
}
// Secret returns the generic application client secret
func (g *Generic) Secret() string {
return g.ClientSecret
}
// Scopes for generic provider required of the client.
func (g *Generic) Scopes() []string {
return g.RequiredScopes
}
// Config is the Generic OAuth2 exchange information and endpoints
func (g *Generic) Config() *oauth2.Config {
return &oauth2.Config{
ClientID: g.ID(),
ClientSecret: g.Secret(),
Scopes: g.Scopes(),
RedirectURL: g.RedirectURL,
Endpoint: oauth2.Endpoint{
AuthURL: g.AuthURL,
TokenURL: g.TokenURL,
},
}
}
// PrincipalID returns the email address of the user.
func (g *Generic) PrincipalID(provider *http.Client) (string, error) {
res := map[string]interface{}{}
r, err := provider.Get(g.APIURL)
if err != nil {
return "", err
}
defer r.Body.Close()
if err = json.NewDecoder(r.Body).Decode(&res); err != nil {
return "", err
}
email := ""
value := res[g.APIKey]
if e, ok := value.(string); ok {
email = e
}
// If we did not receive an email address, try to lookup the email
// in a similar way as github
if email == "" {
email, err = g.getPrimaryEmail(provider)
if err != nil {
return "", err
}
}
// If we need to restrict to a set of domains, we first get the org
// and filter.
if len(g.Domains) > 0 {
// If not in the domain deny permission
if ok := ofDomain(g.Domains, email); !ok {
msg := "Not a member of required domain"
g.Logger.Error(msg)
return "", fmt.Errorf(msg)
}
}
return email, nil
}
// Group returns the domain that a user belongs to in the
// the generic OAuth.
func (g *Generic) Group(provider *http.Client) (string, error) {
res := map[string]interface{}{}
r, err := provider.Get(g.APIURL)
if err != nil {
return "", err
}
defer r.Body.Close()
if err = json.NewDecoder(r.Body).Decode(&res); err != nil {
return "", err
}
email := ""
value := res[g.APIKey]
if e, ok := value.(string); ok {
email = e
}
// If we did not receive an email address, try to lookup the email
// in a similar way as github
if email == "" {
email, err = g.getPrimaryEmail(provider)
if err != nil {
return "", err
}
}
domain := strings.Split(email, "@")
if len(domain) != 2 {
return "", fmt.Errorf("malformed email address, expected %q to contain @ symbol", email)
}
return domain[1], nil
}
// UserEmail represents user's email address
type UserEmail struct {
Email *string `json:"email,omitempty"`
Primary *bool `json:"primary,omitempty"`
Verified *bool `json:"verified,omitempty"`
// support also indicators sent by bitbucket
IsPrimary *bool `json:"is_primary,omitempty"`
IsConfirmed *bool `json:"is_confirmed,omitempty"`
}
// WrappedUserEmails represents (bitbucket's) structure that wraps email addresses in a values field
type WrappedUserEmails struct {
Emails []*UserEmail `json:"values,omitempty"`
}
// getPrimaryEmail gets the private email account for the authenticated user.
func (g *Generic) getPrimaryEmail(client *http.Client) (string, error) {
emailsEndpoint := g.APIURL + "/emails"
r, err := client.Get(emailsEndpoint)
if err != nil {
return "", err
}
defer r.Body.Close()
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return "", err
}
if len(body) == 0 {
return "", errors.New("No response body from /emails")
}
emails := []*UserEmail{}
if body[0] == '[' {
// array of UserEmail
if err = json.NewDecoder(bytes.NewReader(body)).Decode(&emails); err != nil {
return "", err
}
} else if body[0] == '{' {
// a struct with values that contain []*UserEmail{}
wrapped := WrappedUserEmails{}
if err = json.NewDecoder(bytes.NewReader(body)).Decode(&wrapped); err != nil {
return "", err
}
emails = wrapped.Emails
}
email, err := g.primaryEmail(emails)
if err != nil {
g.Logger.Error("Unable to retrieve primary email ", err.Error())
return "", err
}
return email, nil
}
func (g *Generic) primaryEmail(emails []*UserEmail) (string, error) {
var email string
for _, m := range emails {
if m != nil && m.Email != nil && ((m.Verified != nil) || (m.IsConfirmed != nil)) {
if email != "" {
email = *m.Email
}
if (m.Primary != nil && *m.Primary) || (m.IsPrimary != nil && *m.IsPrimary) {
return *m.Email, nil
}
}
}
if email == "" {
return "", errors.New("No primary email address")
}
return email, nil
}
// ofDomain makes sure that the email is in one of the required domains
func ofDomain(requiredDomains []string, email string) bool {
for _, domain := range requiredDomains {
emailDomain := fmt.Sprintf("@%s", domain)
if strings.HasSuffix(email, emailDomain) {
return true
}
}
return false
}
// PrincipalIDFromClaims verifies an optional id_token and extracts email address of the user
func (g *Generic) PrincipalIDFromClaims(claims gojwt.MapClaims) (string, error) {
if id, ok := claims[g.APIKey].(string); ok {
return id, nil
}
return "", fmt.Errorf("no claim for %s", g.APIKey)
}
// GroupFromClaims verifies an optional id_token, extracts the email address of the user and splits off the domain part
func (g *Generic) GroupFromClaims(claims gojwt.MapClaims) (string, error) {
if id, ok := claims[g.APIKey].(string); ok {
email := strings.Split(id, "@")
if len(email) != 2 {
g.Logger.Error("malformed email address, expected %q to contain @ symbol", id)
return "DEFAULT", nil
}
return email[1], nil
}
return "", fmt.Errorf("no claim for %s", g.APIKey)
}