242 lines
7.4 KiB
Go
242 lines
7.4 KiB
Go
package oauth2
|
|
|
|
import (
|
|
"context"
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
// CodeExchange helps to ensure secure exchange
|
|
// of the authorization code for token.
|
|
type CodeExchange interface {
|
|
// AuthCodeURL generates authorization URL with a state that prevents CSRF attacks.
|
|
// It can also use OAuth2 PKCE.
|
|
AuthCodeURL(ctx context.Context, j *AuthMux) (string, error)
|
|
// ExchangeCodeForToken receives authorization token having a verified state and code
|
|
ExchangeCodeForToken(ctx context.Context, state, code string, j *AuthMux) (*oauth2.Token, error)
|
|
}
|
|
|
|
// default implementation
|
|
var simpleTokenExchange CodeExchange = &CodeExchangeCSRF{}
|
|
|
|
func NewCodeExchange(withPKCE bool, secret string) CodeExchange {
|
|
if withPKCE {
|
|
return &CodeExchangePKCE{Secret: secret}
|
|
}
|
|
return simpleTokenExchange
|
|
}
|
|
|
|
// CodeExchangeCSRF guards the retrieval of an OAuth token against CSRF
// attacks by sending a signed random state to the authorization server and
// validating it when the callback arrives.
//
// The state is a JWT wrapping a random string. A JWT is a good fit for the
// encoding because it can be validated without the server having to store
// any state between the two requests.
type CodeExchangeCSRF struct{}
|
|
|
|
// AuthCodeURL generates authorization URL with a state that prevents CSRF attacks.
|
|
func (p *CodeExchangeCSRF) AuthCodeURL(ctx context.Context, j *AuthMux) (string, error) {
|
|
// We are creating a token with an encoded random string to prevent CSRF attacks
|
|
// This token will be validated during the OAuth callback.
|
|
// We'll give our users 10 minutes from this point to type in their
|
|
// oauth2 provider's password.
|
|
// If the callback is not received within 10 minutes, then authorization will fail.
|
|
csrf := make([]byte, 32) // 32 is not important... just long
|
|
if _, err := io.ReadFull(rand.Reader, csrf); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
now := j.Now()
|
|
|
|
// This token will be valid for 10 minutes. Any chronograf server will
|
|
// be able to validate this token.
|
|
stateData := Principal{
|
|
Subject: base64.RawURLEncoding.EncodeToString(csrf),
|
|
IssuedAt: now,
|
|
ExpiresAt: now.Add(TenMinutes),
|
|
}
|
|
token, err := j.Tokens.Create(ctx, stateData)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
urlOpts := []oauth2.AuthCodeOption{oauth2.AccessTypeOnline}
|
|
if j.LoginHint != "" {
|
|
urlOpts = append(urlOpts, oauth2.SetAuthURLParam("login_hint", j.LoginHint))
|
|
}
|
|
url := j.Provider.Config().AuthCodeURL(string(token), urlOpts...)
|
|
return url, nil
|
|
}
|
|
|
|
func (p *CodeExchangeCSRF) ExchangeCodeForToken(ctx context.Context, state, code string, j *AuthMux) (*oauth2.Token, error) {
|
|
// Check if the OAuth state token is valid to prevent CSRF
|
|
// The state variable we set is actually a token. We'll check
|
|
// if the token is valid. We don't need to know anything
|
|
// about the contents of the principal only that it hasn't expired.
|
|
if _, err := j.Tokens.ValidPrincipal(ctx, Token(state), TenMinutes); err != nil {
|
|
return nil, fmt.Errorf("invalid OAuth state received: %v", err.Error())
|
|
}
|
|
|
|
// Exchange the code back with the provider to the the token
|
|
conf := j.Provider.Config()
|
|
|
|
// Use http client with transport options.
|
|
if j.client != nil {
|
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, j.client)
|
|
}
|
|
|
|
return conf.Exchange(ctx, code)
|
|
}
|
|
|
|
// CodeExchangePKCE extends the CSRF state check with OAuth2 PKCE to protect
// against interception attacks. See RFC 7636 for details.
type CodeExchangePKCE struct {
	// Secret is used to encrypt the PKCE code verifier that travels inside
	// the state data. Carrying the verifier in the state avoids the need
	// for the server to remember it between requests.
	Secret string
}

// newGCM returns an AES-256-GCM AEAD keyed with the SHA-256 digest of
// c.Secret. Encrypt and Decrypt both go through it so they always agree on
// the key derivation and cipher mode.
func (c *CodeExchangePKCE) newGCM() (cipher.AEAD, error) {
	// SHA-256 yields exactly the 32 bytes an AES-256 key needs.
	key := sha256.Sum256([]byte(c.Secret))
	block, err := aes.NewCipher(key[:])
	if err != nil {
		return nil, err
	}
	return cipher.NewGCM(block)
}

// Encrypt encrypts the supplied codeVerifier with AES-256-GCM under a key
// derived from c.Secret.
//
// The result is the unpadded URL-safe base64 encoding of a fresh random
// nonce followed by the ciphertext and authentication tag. An error is
// returned when the cipher cannot be set up or random bytes cannot be read.
func (c *CodeExchangePKCE) Encrypt(codeVerifier []byte) (string, error) {
	aead, err := c.newGCM()
	if err != nil {
		return "", err
	}

	// A new random nonce for every message.
	nonce := make([]byte, aead.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return "", err
	}

	// Passing nonce as dst makes Seal append ciphertext and tag to it, so
	// the output is nonce || ciphertext || tag and Decrypt can split it.
	sealed := aead.Seal(nonce, nonce, codeVerifier, nil)
	return base64.RawURLEncoding.EncodeToString(sealed), nil
}

// Decrypt reverses Encrypt: it decodes the supplied unpadded URL-safe base64
// string, splits off the leading nonce and authenticates and decrypts the
// remainder with the key derived from c.Secret.
//
// On failure it returns an empty slice and an error: the base64 decoding
// error, "malformed encrypted data" when the input is too short to hold a
// nonce and an authentication tag, or the error reported by the cipher when
// authentication fails (wrong secret or tampered data).
func (c *CodeExchangePKCE) Decrypt(encrypted string) ([]byte, error) {
	enc, err := base64.RawURLEncoding.Strict().DecodeString(encrypted)
	if err != nil {
		return []byte{}, err
	}

	aead, err := c.newGCM()
	if err != nil {
		return []byte{}, err
	}

	// Anything Encrypt produced holds at least a nonce plus the tag;
	// shorter input cannot be valid, so reject it with a clear error.
	nonceSize := aead.NonceSize()
	if len(enc) < nonceSize+aead.Overhead() {
		return []byte{}, errors.New("malformed encrypted data")
	}
	nonce, cipherText := enc[:nonceSize], enc[nonceSize:]

	plain, err := aead.Open(nil, nonce, cipherText, nil)
	if err != nil {
		return []byte{}, err
	}
	return plain, nil
}
|
|
|
|
// AuthCodeURL generates authorization URL with PKCE
|
|
// challenge parameters and a state that prevents CSRF and.
|
|
func (p *CodeExchangePKCE) AuthCodeURL(ctx context.Context, j *AuthMux) (string, error) {
|
|
// generate code verifier
|
|
codeVerifier := make([]byte, 32)
|
|
if _, err := io.ReadFull(rand.Reader, codeVerifier); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// encrypt code verifier so it can be added to state,
|
|
// we don't need to remember it on the server side then
|
|
encryptedCodeVerifier, err := p.Encrypt(codeVerifier)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
now := j.Now()
|
|
// This token will be valid for 10 minutes. Any chronograf server will
|
|
// be able to validate this token.
|
|
stateData := Principal{
|
|
Subject: encryptedCodeVerifier,
|
|
IssuedAt: now,
|
|
ExpiresAt: now.Add(TenMinutes),
|
|
}
|
|
token, err := j.Tokens.Create(ctx, stateData)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// setup URL options including PKCE code challenge
|
|
codeVerifierDigest := sha256.Sum256([]byte(
|
|
base64.RawURLEncoding.EncodeToString(codeVerifier),
|
|
))
|
|
codeChallenge := base64.RawURLEncoding.EncodeToString(
|
|
codeVerifierDigest[:],
|
|
)
|
|
urlOpts := make([]oauth2.AuthCodeOption, 0, 4)
|
|
urlOpts = append(urlOpts,
|
|
oauth2.AccessTypeOnline,
|
|
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
|
)
|
|
if j.LoginHint != "" {
|
|
urlOpts = append(urlOpts, oauth2.SetAuthURLParam("login_hint", j.LoginHint))
|
|
}
|
|
url := j.Provider.Config().AuthCodeURL(string(token), urlOpts...)
|
|
return url, nil
|
|
}
|
|
|
|
func (p *CodeExchangePKCE) ExchangeCodeForToken(ctx context.Context, state, code string, j *AuthMux) (*oauth2.Token, error) {
|
|
// Check if the OAuth state token is valid.
|
|
stateData, err := j.Tokens.ValidPrincipal(ctx, Token(state), TenMinutes)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid OAuth state received: %v", err.Error())
|
|
}
|
|
codeVerifier, err := p.Decrypt(stateData.Subject)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid OAuth state received: %v", err.Error())
|
|
}
|
|
|
|
// Exchange the code back with the provider to the the token
|
|
conf := j.Provider.Config()
|
|
|
|
// Use http client with transport options.
|
|
if j.client != nil {
|
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, j.client)
|
|
}
|
|
|
|
// exhange code for token with PKCE code_verifier supplied
|
|
return conf.Exchange(ctx,
|
|
code,
|
|
oauth2.SetAuthURLParam("code_verifier", base64.RawURLEncoding.EncodeToString(codeVerifier)),
|
|
)
|
|
}
|