Merge pull request #5710 from influxdata/5394/oauth_pkce
feat(oauth): add PKCE to OAuth integrationspull/5713/head
commit
600a21ee70
|
@ -24,6 +24,7 @@
|
|||
1. [#5700](https://github.com/influxdata/chronograf/pull/5700): Remove HipChat alerts.
|
||||
1. [#5704](https://github.com/influxdata/chronograf/pull/5704): Allow to filter fields in Query Builder UI.
|
||||
1. [#5712](https://github.com/influxdata/chronograf/pull/5712): Allow to change write precission.
|
||||
1. [#5710](https://github.com/influxdata/chronograf/pull/5710): Add PKCE to OAuth integrations.
|
||||
|
||||
### Other
|
||||
|
||||
|
@ -33,6 +34,14 @@
|
|||
1. [#5701](https://github.com/influxdata/chronograf/pull/5690): Fix unsafe React lifecycle functions.
|
||||
1. [#5706](https://github.com/influxdata/chronograf/pull/5706): Improve communication with InfluxDB Enterprise.
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
1. [#5710](https://github.com/influxdata/chronograf/pull/5710): OAuth integrations newly use OAuth PKCE (RFC7636)
|
||||
to provide a more secure OAuth token exchange. Google, Azure, Octa, Auth0, Gitlab (and more) integrations already
|
||||
support OAuth PKCE. PKCE enablement should have no effect on the communication with authorization servers that
|
||||
don't support it yet (such as Github, Bitbucket). PKCE can be eventually turned off with `OAUTH_NO_PKCE=true`
|
||||
environment variable.
|
||||
|
||||
## v1.8.10 [2021-02-08]
|
||||
|
||||
### Bug Fixes
|
||||
|
|
|
@ -0,0 +1,241 @@
|
|||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// CodeExchange helps to ensure secure exchange
|
||||
// of the authorization code for token.
|
||||
type CodeExchange interface {
|
||||
// AuthCodeURL generates authorization URL with a state that prevents CSRF attacks.
|
||||
// It can also use OAuth2 PKCE.
|
||||
AuthCodeURL(ctx context.Context, j *AuthMux) (string, error)
|
||||
// ExchangeCodeForToken receives authorization token having a verified state and code
|
||||
ExchangeCodeForToken(ctx context.Context, state, code string, j *AuthMux) (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
// default implementation
|
||||
var simpleTokenExchange CodeExchange = &CodeExchangeCSRF{}
|
||||
|
||||
func NewCodeExchange(withPKCE bool, secret string) CodeExchange {
|
||||
if withPKCE {
|
||||
return &CodeExchangePKCE{Secret: secret}
|
||||
}
|
||||
return simpleTokenExchange
|
||||
}
|
||||
|
||||
// CodeExchangeCSRF prevents CSRF attacks during retrieval of OAuth token
|
||||
// by using a signed random state in the exchange with authorization server.
|
||||
// It uses a random string as the state validation method. The state is a JWT. It is
|
||||
// a good choice here for encoding because they can be validated without
|
||||
// storing state.
|
||||
type CodeExchangeCSRF struct {
|
||||
}
|
||||
|
||||
// AuthCodeURL generates authorization URL with a state that prevents CSRF attacks.
|
||||
func (p *CodeExchangeCSRF) AuthCodeURL(ctx context.Context, j *AuthMux) (string, error) {
|
||||
// We are creating a token with an encoded random string to prevent CSRF attacks
|
||||
// This token will be validated during the OAuth callback.
|
||||
// We'll give our users 10 minutes from this point to type in their
|
||||
// oauth2 provider's password.
|
||||
// If the callback is not received within 10 minutes, then authorization will fail.
|
||||
csrf := make([]byte, 32) // 32 is not important... just long
|
||||
if _, err := io.ReadFull(rand.Reader, csrf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
now := j.Now()
|
||||
|
||||
// This token will be valid for 10 minutes. Any chronograf server will
|
||||
// be able to validate this token.
|
||||
stateData := Principal{
|
||||
Subject: base64.RawURLEncoding.EncodeToString(csrf),
|
||||
IssuedAt: now,
|
||||
ExpiresAt: now.Add(TenMinutes),
|
||||
}
|
||||
token, err := j.Tokens.Create(ctx, stateData)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
urlOpts := []oauth2.AuthCodeOption{oauth2.AccessTypeOnline}
|
||||
if j.LoginHint != "" {
|
||||
urlOpts = append(urlOpts, oauth2.SetAuthURLParam("login_hint", j.LoginHint))
|
||||
}
|
||||
url := j.Provider.Config().AuthCodeURL(string(token), urlOpts...)
|
||||
return url, nil
|
||||
}
|
||||
|
||||
func (p *CodeExchangeCSRF) ExchangeCodeForToken(ctx context.Context, state, code string, j *AuthMux) (*oauth2.Token, error) {
|
||||
// Check if the OAuth state token is valid to prevent CSRF
|
||||
// The state variable we set is actually a token. We'll check
|
||||
// if the token is valid. We don't need to know anything
|
||||
// about the contents of the principal only that it hasn't expired.
|
||||
if _, err := j.Tokens.ValidPrincipal(ctx, Token(state), TenMinutes); err != nil {
|
||||
return nil, fmt.Errorf("invalid OAuth state received: %v", err.Error())
|
||||
}
|
||||
|
||||
// Exchange the code back with the provider to the the token
|
||||
conf := j.Provider.Config()
|
||||
|
||||
// Use http client with transport options.
|
||||
if j.client != nil {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, j.client)
|
||||
}
|
||||
|
||||
return conf.Exchange(ctx, code)
|
||||
}
|
||||
|
||||
// CodeExchangePKCE extends CodeExchangeCSRF and adds OAuth2 PKCE
|
||||
// to protect against interception attacks. See RFC7636 for details.
|
||||
type CodeExchangePKCE struct {
|
||||
// Secret is used encrypt PKCE code verifier in the state data. The state
|
||||
// data are chosen to avoid the need for the server to remember the sate.
|
||||
Secret string
|
||||
}
|
||||
|
||||
// Encrypt encrypts the codeVerifier supplied
|
||||
func (c *CodeExchangePKCE) Encrypt(codeVerifier []byte) (string, error) {
|
||||
// create a AES256 key out of secret
|
||||
key := sha256.Sum256([]byte(c.Secret))
|
||||
|
||||
// create a cipher
|
||||
block, err := aes.NewCipher(key[:32])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// create a nonce for the cipher
|
||||
nonce := make([]byte, aesGCM.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// encrypt the data
|
||||
cipherText := aesGCM.Seal(nonce, nonce, codeVerifier, nil)
|
||||
return base64.RawURLEncoding.EncodeToString(cipherText), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts the supplied encrypted string
|
||||
func (c *CodeExchangePKCE) Decrypt(encrypted string) ([]byte, error) {
|
||||
// create a AES256 key out of secret
|
||||
key := sha256.Sum256([]byte(c.Secret))
|
||||
enc, err := base64.RawURLEncoding.Strict().DecodeString(encrypted)
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
// create a new cipher
|
||||
block, err := aes.NewCipher(key[:32])
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
// separate nonce and cipherText
|
||||
nonceSize := aesGCM.NonceSize()
|
||||
if len(enc) <= nonceSize {
|
||||
return []byte{}, errors.New("malformed encrypted data")
|
||||
}
|
||||
nonce, cipherText := enc[:nonceSize], enc[nonceSize:]
|
||||
|
||||
// decrypt
|
||||
plain, err := aesGCM.Open(nil, nonce, cipherText, nil)
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
return plain, nil
|
||||
|
||||
}
|
||||
|
||||
// AuthCodeURL generates authorization URL with PKCE
|
||||
// challenge parameters and a state that prevents CSRF and.
|
||||
func (p *CodeExchangePKCE) AuthCodeURL(ctx context.Context, j *AuthMux) (string, error) {
|
||||
// generate code verifier
|
||||
codeVerifier := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, codeVerifier); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// encrypt code verifier so it can be added to state,
|
||||
// we don't need to remember it on the server side then
|
||||
encryptedCodeVerifier, err := p.Encrypt(codeVerifier)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
now := j.Now()
|
||||
// This token will be valid for 10 minutes. Any chronograf server will
|
||||
// be able to validate this token.
|
||||
stateData := Principal{
|
||||
Subject: encryptedCodeVerifier,
|
||||
IssuedAt: now,
|
||||
ExpiresAt: now.Add(TenMinutes),
|
||||
}
|
||||
token, err := j.Tokens.Create(ctx, stateData)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// setup URL options including PKCE code challenge
|
||||
codeVerifierDigest := sha256.Sum256([]byte(
|
||||
base64.RawURLEncoding.EncodeToString(codeVerifier),
|
||||
))
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(
|
||||
codeVerifierDigest[:],
|
||||
)
|
||||
urlOpts := make([]oauth2.AuthCodeOption, 0, 4)
|
||||
urlOpts = append(urlOpts,
|
||||
oauth2.AccessTypeOnline,
|
||||
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
)
|
||||
if j.LoginHint != "" {
|
||||
urlOpts = append(urlOpts, oauth2.SetAuthURLParam("login_hint", j.LoginHint))
|
||||
}
|
||||
url := j.Provider.Config().AuthCodeURL(string(token), urlOpts...)
|
||||
return url, nil
|
||||
}
|
||||
|
||||
func (p *CodeExchangePKCE) ExchangeCodeForToken(ctx context.Context, state, code string, j *AuthMux) (*oauth2.Token, error) {
|
||||
// Check if the OAuth state token is valid.
|
||||
stateData, err := j.Tokens.ValidPrincipal(ctx, Token(state), TenMinutes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid OAuth state received: %v", err.Error())
|
||||
}
|
||||
codeVerifier, err := p.Decrypt(stateData.Subject)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid OAuth state received: %v", err.Error())
|
||||
}
|
||||
|
||||
// Exchange the code back with the provider to the the token
|
||||
conf := j.Provider.Config()
|
||||
|
||||
// Use http client with transport options.
|
||||
if j.client != nil {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, j.client)
|
||||
}
|
||||
|
||||
// exhange code for token with PKCE code_verifier supplied
|
||||
return conf.Exchange(ctx,
|
||||
code,
|
||||
oauth2.SetAuthURLParam("code_verifier", base64.RawURLEncoding.EncodeToString(codeVerifier)),
|
||||
)
|
||||
}
|
|
@ -0,0 +1,252 @@
|
|||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
clog "github.com/influxdata/chronograf/log"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_CodeExchangeCSRF_AuthCodeURL(t *testing.T) {
|
||||
// setup auth mux
|
||||
mt := &YesManTokenizer{}
|
||||
auth := &cookie{
|
||||
Name: DefaultCookieName,
|
||||
Lifespan: 1 * time.Hour,
|
||||
Inactivity: defaultInactivityDuration,
|
||||
Now: func() time.Time {
|
||||
return testTime
|
||||
},
|
||||
Tokens: mt,
|
||||
}
|
||||
useidtoken := false
|
||||
mp := &MockProvider{
|
||||
Email: "biff@example.com",
|
||||
ProviderURL: "http://localhost:1234",
|
||||
Orgs: "",
|
||||
}
|
||||
authMux := NewAuthMux(mp, auth, mt, "", clog.New(clog.ParseLevel("debug")), useidtoken, "hello", nil, nil)
|
||||
|
||||
// create AuthCodeURL with code exchange without PKCE
|
||||
codeExchange := NewCodeExchange(false, "")
|
||||
authCodeURLString, err := codeExchange.AuthCodeURL(context.Background(), authMux)
|
||||
require.Nil(t, err, "Unable to generate AuthCodeURL")
|
||||
authCodeURL, err := url.ParseRequestURI(authCodeURLString)
|
||||
require.Nil(t, err, "AuthCodeURL format invalid")
|
||||
|
||||
expectedParams := map[string]string{
|
||||
"access_type": "online",
|
||||
"client_id": "",
|
||||
"state": "",
|
||||
"response_type": "code",
|
||||
"redirect_uri": "",
|
||||
"login_hint": "hello",
|
||||
}
|
||||
queryParams := authCodeURL.Query()
|
||||
for key, val := range expectedParams {
|
||||
foundVal := queryParams.Get(key)
|
||||
if foundVal == "" {
|
||||
t.Errorf("Generated AuthCodeURL does not contain '%v' param", key)
|
||||
continue
|
||||
}
|
||||
if val != "" && val != foundVal {
|
||||
t.Errorf("Generated AuthCodeURL contains different '%v' param; expected: %s, got: %s", key, val, foundVal)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(expectedParams) != len(queryParams) {
|
||||
t.Errorf("Generated AuthCodeURL contains %d params; expected: %d, url: %s", len(queryParams), len(expectedParams), authCodeURL)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CodeExchangeCSRF_ExchangeCodeForToken(t *testing.T) {
|
||||
// mock authorization provider
|
||||
const testToken = "fake.token"
|
||||
exchangeUrlValues := url.Values{}
|
||||
authServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Set("content-type", "application/json")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
r.ParseForm()
|
||||
exchangeUrlValues = r.Form
|
||||
|
||||
body, _ := json.Marshal(mockCallbackResponse{AccessToken: testToken})
|
||||
|
||||
rw.Write(body)
|
||||
}))
|
||||
defer authServer.Close()
|
||||
|
||||
// setup auth mux
|
||||
auth := &cookie{
|
||||
Name: DefaultCookieName,
|
||||
Lifespan: 1 * time.Hour,
|
||||
Inactivity: defaultInactivityDuration,
|
||||
Tokens: &YesManTokenizer{},
|
||||
}
|
||||
useidtoken := false
|
||||
mp := &MockProvider{
|
||||
Email: "biff@example.com",
|
||||
ProviderURL: authServer.URL,
|
||||
Orgs: "",
|
||||
}
|
||||
authMux := NewAuthMux(mp, auth, auth.Tokens, "", clog.New(clog.ParseLevel("debug")), useidtoken, "hi", nil, nil)
|
||||
|
||||
// create AuthCodeURL using CodeExchange with PKCE
|
||||
codeExchange := simpleTokenExchange
|
||||
authCodeURLString, err := codeExchange.AuthCodeURL(context.Background(), authMux)
|
||||
require.Nil(t, err, "Unable to generate AuthCodeURL")
|
||||
authCodeURL, err := url.ParseRequestURI(authCodeURLString)
|
||||
require.Nil(t, err, "AuthCodeURL format invalid")
|
||||
state := authCodeURL.Query().Get("state")
|
||||
token, err := codeExchange.ExchangeCodeForToken(context.Background(), state, "any code", authMux)
|
||||
require.Nil(t, err, "ExchangeCodeForToken ends with error")
|
||||
require.NotNil(t, token, "No token received")
|
||||
require.Equal(t, testToken, token.AccessToken)
|
||||
expectedParams := []string{"code"}
|
||||
for _, key := range expectedParams {
|
||||
foundVal := exchangeUrlValues.Get(key)
|
||||
if foundVal == "" {
|
||||
t.Errorf("Authorization server did not receive the required %s parameter; values=%v", key, exchangeUrlValues)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CodeExchangePKCE_AuthCodeURL(t *testing.T) {
|
||||
// setup auth mux
|
||||
mt := &YesManTokenizer{}
|
||||
auth := &cookie{
|
||||
Name: DefaultCookieName,
|
||||
Lifespan: 1 * time.Hour,
|
||||
Inactivity: defaultInactivityDuration,
|
||||
Now: func() time.Time {
|
||||
return testTime
|
||||
},
|
||||
Tokens: mt,
|
||||
}
|
||||
useidtoken := false
|
||||
mp := &MockProvider{
|
||||
Email: "biff@example.com",
|
||||
ProviderURL: "http://localhost:1234",
|
||||
Orgs: "",
|
||||
}
|
||||
authMux := NewAuthMux(mp, auth, mt, "", clog.New(clog.ParseLevel("debug")), useidtoken, "hi", nil, nil)
|
||||
|
||||
// create AuthCodeURL using CodeExchange with PKCE
|
||||
codeExchange := NewCodeExchange(true, "secret")
|
||||
authCodeURLString, err := codeExchange.AuthCodeURL(context.Background(), authMux)
|
||||
require.Nil(t, err, "Unable to generate AuthCodeURL")
|
||||
authCodeURL, err := url.ParseRequestURI(authCodeURLString)
|
||||
require.Nil(t, err, "Invalid AuthCodeURL format")
|
||||
|
||||
expectedParams := map[string]string{
|
||||
"access_type": "online",
|
||||
"client_id": "",
|
||||
"state": "",
|
||||
"response_type": "code",
|
||||
"redirect_uri": "",
|
||||
"code_challenge": "",
|
||||
"code_challenge_method": "",
|
||||
"login_hint": "hi",
|
||||
}
|
||||
queryParams := authCodeURL.Query()
|
||||
for key, val := range expectedParams {
|
||||
foundVal := queryParams.Get(key)
|
||||
if foundVal == "" {
|
||||
t.Errorf("Generated AuthCodeURL does not contain '%v' param", key)
|
||||
continue
|
||||
}
|
||||
if val != "" && val != foundVal {
|
||||
t.Errorf("Generated AuthCodeURL contains different '%v' param; expected: %s, got: %s", key, val, foundVal)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(expectedParams) != len(queryParams) {
|
||||
t.Errorf("Generated AuthCodeURL contains %d params; expected: %d, url: %s", len(queryParams), len(expectedParams), authCodeURL)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CodeExchangePKCE_EncryptDecrypt(t *testing.T) {
|
||||
codeExchange := &CodeExchangePKCE{Secret: "hardtoguess"}
|
||||
plain := []byte("this is a test")
|
||||
encrypted, err := codeExchange.Encrypt([]byte(plain))
|
||||
require.Nil(t, err, "Unable to encrypt plain text ")
|
||||
decrypted, err := codeExchange.Decrypt(encrypted)
|
||||
require.Nil(t, err, "Unable to decrypt plain text")
|
||||
require.Equal(t, plain, decrypted)
|
||||
}
|
||||
|
||||
func Test_CodeExchangePKCE_ExchangeCodeForToken(t *testing.T) {
|
||||
// mock authorization provider
|
||||
const testToken = "fake.token"
|
||||
exchangeUrlValues := url.Values{}
|
||||
authServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Set("content-type", "application/json")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
r.ParseForm()
|
||||
exchangeUrlValues = r.Form
|
||||
|
||||
body, _ := json.Marshal(mockCallbackResponse{AccessToken: testToken})
|
||||
|
||||
rw.Write(body)
|
||||
}))
|
||||
defer authServer.Close()
|
||||
|
||||
// setup auth mux
|
||||
secret := "this is a test secret"
|
||||
jwt := NewJWT(secret, "")
|
||||
auth := &cookie{
|
||||
Name: DefaultCookieName,
|
||||
Lifespan: 1 * time.Hour,
|
||||
Inactivity: defaultInactivityDuration,
|
||||
Tokens: jwt,
|
||||
}
|
||||
useidtoken := false
|
||||
mp := &MockProvider{
|
||||
Email: "biff@example.com",
|
||||
ProviderURL: authServer.URL,
|
||||
Orgs: "",
|
||||
}
|
||||
authMux := NewAuthMux(mp, auth, jwt, "", clog.New(clog.ParseLevel("debug")), useidtoken, "hi", nil, nil)
|
||||
|
||||
// create AuthCodeURL using CodeExchange with PKCE
|
||||
codeExchange := CodeExchangePKCE{Secret: secret}
|
||||
authCodeURLString, err := codeExchange.AuthCodeURL(context.Background(), authMux)
|
||||
require.Nil(t, err, "Unable to generate AuthCodeURL")
|
||||
authCodeURL, err := url.ParseRequestURI(authCodeURLString)
|
||||
require.Nil(t, err, "Invalid AuthCodeURL format")
|
||||
state := authCodeURL.Query().Get("state")
|
||||
token, err := codeExchange.ExchangeCodeForToken(context.Background(), state, "any code", authMux)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, token)
|
||||
require.Equal(t, testToken, token.AccessToken)
|
||||
stateData, err := jwt.ValidPrincipal(context.Background(), Token(state), TenMinutes)
|
||||
require.Nil(t, err, "invalid OAuth state received")
|
||||
codeVerifierBytes, err := codeExchange.Decrypt(stateData.Subject)
|
||||
require.Nil(t, err, "unable to decrypt code verifier")
|
||||
codeVerifier := base64.RawURLEncoding.EncodeToString(codeVerifierBytes)
|
||||
expectedParams := map[string]string{
|
||||
"code": "",
|
||||
"code_verifier": codeVerifier,
|
||||
}
|
||||
if len(expectedParams["code_verifier"]) < 43 {
|
||||
t.Errorf("Code verifier must be at least 43 characters long, but it is %d; code_verifier=%s", len(codeVerifier), codeVerifier)
|
||||
}
|
||||
for key, val := range expectedParams {
|
||||
foundVal := exchangeUrlValues.Get(key)
|
||||
if foundVal == "" {
|
||||
t.Errorf("Authorization server did not receive the required %s parameter; values=%v", key, exchangeUrlValues)
|
||||
continue
|
||||
}
|
||||
if val != "" && val != foundVal {
|
||||
t.Errorf("Authorization Server receveid different '%v' param; expected: %s, got: %s", key, val, foundVal)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
|
@ -2,10 +2,7 @@ package oauth2
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
|
@ -106,14 +103,6 @@ func (g *Github) Group(provider *http.Client) (string, error) {
|
|||
return strings.Join(groups, ","), nil
|
||||
}
|
||||
|
||||
func randomString(length int) string {
|
||||
k := make([]byte, length)
|
||||
if _, err := io.ReadFull(rand.Reader, k); err != nil {
|
||||
return ""
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(k)
|
||||
}
|
||||
|
||||
func logResponseError(log chronograf.Logger, resp *github.Response, err error) {
|
||||
switch resp.StatusCode {
|
||||
case http.StatusUnauthorized, http.StatusForbidden:
|
||||
|
|
104
oauth2/mux.go
104
oauth2/mux.go
|
@ -1,13 +1,11 @@
|
|||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/chronograf"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// Check to ensure AuthMux is an oauth2.Mux
|
||||
|
@ -17,17 +15,25 @@ var _ Mux = &AuthMux{}
|
|||
const TenMinutes = 10 * time.Minute
|
||||
|
||||
// NewAuthMux constructs a Mux handler that checks a cookie against the authenticator
|
||||
func NewAuthMux(p Provider, a Authenticator, t Tokenizer, basepath string, l chronograf.Logger, UseIDToken bool, LoginHint string, client *http.Client) *AuthMux {
|
||||
func NewAuthMux(p Provider, a Authenticator, t Tokenizer,
|
||||
basepath string, l chronograf.Logger,
|
||||
UseIDToken bool, LoginHint string,
|
||||
client *http.Client, codeExchange CodeExchange,
|
||||
) *AuthMux {
|
||||
if codeExchange == nil {
|
||||
codeExchange = simpleTokenExchange
|
||||
}
|
||||
mux := &AuthMux{
|
||||
Provider: p,
|
||||
Auth: a,
|
||||
Tokens: t,
|
||||
SuccessURL: path.Join(basepath, "/"),
|
||||
FailureURL: path.Join(basepath, "/login"),
|
||||
Now: DefaultNowTime,
|
||||
Logger: l,
|
||||
UseIDToken: UseIDToken,
|
||||
LoginHint: LoginHint,
|
||||
Provider: p,
|
||||
Auth: a,
|
||||
Tokens: t,
|
||||
SuccessURL: path.Join(basepath, "/"),
|
||||
FailureURL: path.Join(basepath, "/login"),
|
||||
Now: DefaultNowTime,
|
||||
Logger: l,
|
||||
UseIDToken: UseIDToken,
|
||||
LoginHint: LoginHint,
|
||||
CodeExchange: codeExchange,
|
||||
}
|
||||
|
||||
if client != nil {
|
||||
|
@ -43,42 +49,23 @@ func NewAuthMux(p Provider, a Authenticator, t Tokenizer, basepath string, l chr
|
|||
// Chronograf instance as long as the Authenticator has no external
|
||||
// dependencies (e.g. on a Database).
|
||||
type AuthMux struct {
|
||||
Provider Provider // Provider is the OAuth2 service
|
||||
Auth Authenticator // Auth is used to Authorize after successful OAuth2 callback and Expire on Logout
|
||||
Tokens Tokenizer // Tokens is used to create and validate OAuth2 "state"
|
||||
Logger chronograf.Logger // Logger is used to give some more information about the OAuth2 process
|
||||
SuccessURL string // SuccessURL is redirect location after successful authorization
|
||||
FailureURL string // FailureURL is redirect location after authorization failure
|
||||
Now func() time.Time // Now returns the current time (for testing)
|
||||
UseIDToken bool // UseIDToken enables OpenID id_token support
|
||||
LoginHint string // LoginHint will be included as a parameter during authentication if non-nil
|
||||
client *http.Client // client is the http client used in oauth exchange.
|
||||
Provider Provider // Provider is the OAuth2 service
|
||||
Auth Authenticator // Auth is used to Authorize after successful OAuth2 callback and Expire on Logout
|
||||
Tokens Tokenizer // Tokens is used to create and validate OAuth2 "state"
|
||||
Logger chronograf.Logger // Logger is used to give some more information about the OAuth2 process
|
||||
SuccessURL string // SuccessURL is redirect location after successful authorization
|
||||
FailureURL string // FailureURL is redirect location after authorization failure
|
||||
Now func() time.Time // Now returns the current time (for testing)
|
||||
UseIDToken bool // UseIDToken enables OpenID id_token support
|
||||
LoginHint string // LoginHint will be included as a parameter during authentication if non-nil
|
||||
client *http.Client // client is the http client used in oauth exchange.
|
||||
CodeExchange CodeExchange // helps with CSRF in exchange of token for authorization code
|
||||
}
|
||||
|
||||
// Login uses a Cookie with a random string as the state validation method. JWTs are
|
||||
// a good choice here for encoding because they can be validated without
|
||||
// storing state. Login returns a handler that redirects to the providers OAuth login.
|
||||
// Login returns a handler that redirects to the providers OAuth login.
|
||||
func (j *AuthMux) Login() http.Handler {
|
||||
conf := j.Provider.Config()
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// We are creating a token with an encoded random string to prevent CSRF attacks
|
||||
// This token will be validated during the OAuth callback.
|
||||
// We'll give our users 10 minutes from this point to type in their
|
||||
// oauth2 provider's password.
|
||||
// If the callback is not received within 10 minutes, then authorization will fail.
|
||||
csrf := randomString(32) // 32 is not important... just long
|
||||
now := j.Now()
|
||||
|
||||
// This token will be valid for 10 minutes. Any chronograf server will
|
||||
// be able to validate this token.
|
||||
p := Principal{
|
||||
Subject: csrf,
|
||||
IssuedAt: now,
|
||||
ExpiresAt: now.Add(TenMinutes),
|
||||
}
|
||||
token, err := j.Tokens.Create(r.Context(), p)
|
||||
|
||||
// This is likely an internal server error
|
||||
url, err := j.CodeExchange.AuthCodeURL(r.Context(), j)
|
||||
if err != nil {
|
||||
j.Logger.
|
||||
WithField("component", "auth").
|
||||
|
@ -89,13 +76,6 @@ func (j *AuthMux) Login() http.Handler {
|
|||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
urlOpts := []oauth2.AuthCodeOption{oauth2.AccessTypeOnline}
|
||||
if j.LoginHint != "" {
|
||||
urlOpts = append(urlOpts, oauth2.SetAuthURLParam("login_hint", j.LoginHint))
|
||||
}
|
||||
url := conf.AuthCodeURL(string(token), urlOpts...)
|
||||
|
||||
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
||||
})
|
||||
}
|
||||
|
@ -114,27 +94,9 @@ func (j *AuthMux) Callback() http.Handler {
|
|||
WithField("url", r.URL)
|
||||
|
||||
state := r.FormValue("state")
|
||||
// Check if the OAuth state token is valid to prevent CSRF
|
||||
// The state variable we set is actually a token. We'll check
|
||||
// if the token is valid. We don't need to know anything
|
||||
// about the contents of the principal only that it hasn't expired.
|
||||
if _, err := j.Tokens.ValidPrincipal(r.Context(), Token(state), TenMinutes); err != nil {
|
||||
log.Error("Invalid OAuth state received: ", err.Error())
|
||||
http.Redirect(w, r, j.FailureURL, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange the code back with the provider to the the token
|
||||
conf := j.Provider.Config()
|
||||
code := r.FormValue("code")
|
||||
|
||||
// Use http client with transport options.
|
||||
ctx := r.Context()
|
||||
if j.client != nil {
|
||||
ctx = context.WithValue(r.Context(), oauth2.HTTPClient, j.client)
|
||||
}
|
||||
|
||||
token, err := conf.Exchange(ctx, code)
|
||||
token, err := j.CodeExchange.ExchangeCodeForToken(r.Context(), state, code, j)
|
||||
if err != nil {
|
||||
log.Error("Unable to exchange code for token ", err.Error())
|
||||
http.Redirect(w, r, j.FailureURL, http.StatusTemporaryRedirect)
|
||||
|
@ -182,7 +144,7 @@ func (j *AuthMux) Callback() http.Handler {
|
|||
}
|
||||
} else {
|
||||
// otherwise perform an additional lookup
|
||||
oauthClient := conf.Client(r.Context(), token)
|
||||
oauthClient := j.Provider.Config().Client(r.Context(), token)
|
||||
// Using the token get the principal identifier from the provider
|
||||
id, err = j.Provider.PrincipalID(oauthClient)
|
||||
if err != nil {
|
||||
|
|
|
@ -53,7 +53,7 @@ func setupMuxTest(response interface{}, selector func(*AuthMux) http.Handler) (*
|
|||
|
||||
useidtoken := false
|
||||
|
||||
jm := NewAuthMux(mp, auth, mt, "", clog.New(clog.ParseLevel("debug")), useidtoken, "", nil)
|
||||
jm := NewAuthMux(mp, auth, mt, "", clog.New(clog.ParseLevel("debug")), useidtoken, "", nil, nil)
|
||||
ts := httptest.NewServer(selector(jm))
|
||||
jar, _ := cookiejar.New(nil)
|
||||
hc := http.Client{
|
||||
|
|
|
@ -24,7 +24,7 @@ var (
|
|||
// ErrAuthentication means that oauth2 exchange failed
|
||||
ErrAuthentication = errors.New("user not authenticated")
|
||||
// ErrOrgMembership means that the user is not in the OAuth2 filtered group
|
||||
ErrOrgMembership = errors.New("Not a member of the required organization")
|
||||
ErrOrgMembership = errors.New("not a member of the required organization")
|
||||
)
|
||||
|
||||
/* Types */
|
||||
|
|
|
@ -110,6 +110,7 @@ type Server struct {
|
|||
GenericAPIKey string `long:"generic-api-key" description:"JSON lookup key into OpenID UserInfo. (Azure should be userPrincipalName)" default:"email" env:"GENERIC_API_KEY"`
|
||||
GenericInsecure bool `long:"generic-insecure" description:"Whether or not to verify auth-url's tls certificates." env:"GENERIC_INSECURE"`
|
||||
GenericRootCA flags.Filename `long:"generic-root-ca" description:"File location of root ca cert for generic oauth tls verification." env:"GENERIC_ROOT_CA"`
|
||||
OAuthNoPKCE bool `long:"oauth-no-pkce" description:"Disables OAuth PKCE." env:"OAUTH_NO_PKCE"`
|
||||
|
||||
Auth0Domain string `long:"auth0-domain" description:"Subdomain of auth0.com used for Auth0 OAuth2 authentication" env:"AUTH0_DOMAIN"`
|
||||
Auth0ClientID string `long:"auth0-client-id" description:"Auth0 Client ID for OAuth2 support" env:"AUTH0_CLIENT_ID"`
|
||||
|
@ -324,6 +325,9 @@ func processCerts(rootReader io.Reader) (*x509.CertPool, error) {
|
|||
|
||||
return certPool, nil
|
||||
}
|
||||
func (s *Server) createCodeExchange() oauth2.CodeExchange {
|
||||
return oauth2.NewCodeExchange(!s.OAuthNoPKCE, s.TokenSecret)
|
||||
}
|
||||
|
||||
func (s *Server) githubOAuth(logger chronograf.Logger, auth oauth2.Authenticator) (oauth2.Provider, oauth2.Mux, func() error) {
|
||||
gh := oauth2.Github{
|
||||
|
@ -333,7 +337,7 @@ func (s *Server) githubOAuth(logger chronograf.Logger, auth oauth2.Authenticator
|
|||
Logger: logger,
|
||||
}
|
||||
jwt := oauth2.NewJWT(s.TokenSecret, s.JwksURL)
|
||||
ghMux := oauth2.NewAuthMux(&gh, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient)
|
||||
ghMux := oauth2.NewAuthMux(&gh, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient, s.createCodeExchange())
|
||||
return &gh, ghMux, s.UseGithub
|
||||
}
|
||||
|
||||
|
@ -347,7 +351,7 @@ func (s *Server) googleOAuth(logger chronograf.Logger, auth oauth2.Authenticator
|
|||
Logger: logger,
|
||||
}
|
||||
jwt := oauth2.NewJWT(s.TokenSecret, s.JwksURL)
|
||||
goMux := oauth2.NewAuthMux(&google, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient)
|
||||
goMux := oauth2.NewAuthMux(&google, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient, s.createCodeExchange())
|
||||
return &google, goMux, s.UseGoogle
|
||||
}
|
||||
|
||||
|
@ -359,7 +363,7 @@ func (s *Server) herokuOAuth(logger chronograf.Logger, auth oauth2.Authenticator
|
|||
Logger: logger,
|
||||
}
|
||||
jwt := oauth2.NewJWT(s.TokenSecret, s.JwksURL)
|
||||
hMux := oauth2.NewAuthMux(&heroku, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient)
|
||||
hMux := oauth2.NewAuthMux(&heroku, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient, s.createCodeExchange())
|
||||
return &heroku, hMux, s.UseHeroku
|
||||
}
|
||||
|
||||
|
@ -378,7 +382,7 @@ func (s *Server) genericOAuth(logger chronograf.Logger, auth oauth2.Authenticato
|
|||
Logger: logger,
|
||||
}
|
||||
jwt := oauth2.NewJWT(s.TokenSecret, s.JwksURL)
|
||||
genMux := oauth2.NewAuthMux(&gen, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient)
|
||||
genMux := oauth2.NewAuthMux(&gen, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient, s.createCodeExchange())
|
||||
return &gen, genMux, s.UseGenericOAuth2
|
||||
}
|
||||
|
||||
|
@ -394,7 +398,7 @@ func (s *Server) auth0OAuth(logger chronograf.Logger, auth oauth2.Authenticator)
|
|||
auth0, err := oauth2.NewAuth0(s.Auth0Domain, s.Auth0ClientID, s.Auth0ClientSecret, redirectURL.String(), s.Auth0Organizations, logger)
|
||||
|
||||
jwt := oauth2.NewJWT(s.TokenSecret, s.JwksURL)
|
||||
genMux := oauth2.NewAuthMux(&auth0, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient)
|
||||
genMux := oauth2.NewAuthMux(&auth0, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient, s.createCodeExchange())
|
||||
|
||||
if err != nil {
|
||||
logger.Error("Error parsing Auth0 domain: err:", err)
|
||||
|
@ -658,7 +662,7 @@ func (s *Server) Serve(ctx context.Context) {
|
|||
}
|
||||
|
||||
if !validBasepath(s.Basepath) {
|
||||
err := fmt.Errorf("Invalid basepath, must follow format \"/mybasepath\"")
|
||||
err := fmt.Errorf("invalid basepath, must follow format \"/mybasepath\"")
|
||||
logger.
|
||||
WithField("component", "server").
|
||||
WithField("basepath", "invalid").
|
||||
|
@ -670,7 +674,7 @@ func (s *Server) Serve(ctx context.Context) {
|
|||
logger.
|
||||
WithField("component", "server").
|
||||
WithField("basepath", "invalid").
|
||||
Error(fmt.Errorf("Failed to validate Oauth settings: %s", err))
|
||||
Error(fmt.Errorf("failed to validate Oauth settings: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue