feat(oauth): add stateless PKCE CodeExchange impl

pull/5710/head
Pavel Zavora 2021-03-22 22:00:50 +01:00
parent 0e8a36b669
commit e6e6b78c91
1 changed files with 156 additions and 10 deletions

View File

@ -2,8 +2,12 @@ package oauth2
import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
@ -23,7 +27,10 @@ type CodeExchange interface {
// default implementation
var simpleTokenExchange CodeExchange = &CodeExchangeCSRF{}
func NewCodeExchange(withPKCE bool) CodeExchange {
func NewCodeExchange(withPKCE bool, secret string) CodeExchange {
if withPKCE {
return &CodeExchangePKCE{Secret: secret}
}
return simpleTokenExchange
}
@ -42,13 +49,17 @@ func (p *CodeExchangeCSRF) AuthCodeURL(ctx context.Context, j *AuthMux) (string,
// We'll give our users 10 minutes from this point to type in their
// oauth2 provider's password.
// If the callback is not received within 10 minutes, then authorization will fail.
csrf := randomString(32) // 32 is not important... just long
csrf := make([]byte, 32) // 32 is not important... just long
if _, err := io.ReadFull(rand.Reader, csrf); err != nil {
return "", err
}
now := j.Now()
// This token will be valid for 10 minutes. Any chronograf server will
// be able to validate this token.
stateData := Principal{
Subject: csrf,
Subject: base64.RawURLEncoding.EncodeToString(csrf),
IssuedAt: now,
ExpiresAt: now.Add(TenMinutes),
}
@ -85,11 +96,146 @@ func (p *CodeExchangeCSRF) ExchangeCodeForToken(ctx context.Context, state, code
return conf.Exchange(ctx, code)
}
func randomString(length int) string {
k := make([]byte, length)
if _, err := io.ReadFull(rand.Reader, k); err != nil {
return ""
}
// use the simplest encoding that is also required by PKCE
return base64.RawURLEncoding.EncodeToString(k)
// CodeExchangePKCE extends CodeExchangeCSRF and adds OAuth2 PKCE
// to protect against interception attacks. See RFC7636 for details.
type CodeExchangePKCE struct {
// Secret is used encrypt PKCE code verifier in the state data. The state
// data are chosen to avoid the need for the server to remember the sate.
Secret string
}
// Encrypt encrypts the codeVerifier supplied
func (c *CodeExchangePKCE) Encrypt(codeVerifier []byte) (string, error) {
// create a AES256 key out of secret
key := sha256.Sum256([]byte(c.Secret))
// create a cipher
block, err := aes.NewCipher(key[:32])
if err != nil {
return "", err
}
aesGCM, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
// create a nonce for the cipher
nonce := make([]byte, aesGCM.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
return "", err
}
// encrypt the data
cipherText := aesGCM.Seal(nonce, nonce, codeVerifier, nil)
return base64.RawURLEncoding.EncodeToString(cipherText), nil
}
// Decrypt decrypts the supplied encrypted string
func (c *CodeExchangePKCE) Decrypt(encrypted string) ([]byte, error) {
// create a AES256 key out of secret
key := sha256.Sum256([]byte(c.Secret))
enc, err := base64.RawURLEncoding.Strict().DecodeString(encrypted)
if err != nil {
return []byte{}, err
}
// create a new cipher
block, err := aes.NewCipher(key[:32])
if err != nil {
return []byte{}, err
}
aesGCM, err := cipher.NewGCM(block)
if err != nil {
return []byte{}, err
}
// separate nonce and cipherText
nonceSize := aesGCM.NonceSize()
if len(enc) <= nonceSize {
return []byte{}, errors.New("malformed encrypted data")
}
nonce, cipherText := enc[:nonceSize], enc[nonceSize:]
// decrypt
plain, err := aesGCM.Open(nil, nonce, cipherText, nil)
if err != nil {
return []byte{}, err
}
return plain, nil
}
// AuthCodeURL generates authorization URL with PKCE
// challenge parameters and a state that prevents CSRF and.
func (p *CodeExchangePKCE) AuthCodeURL(ctx context.Context, j *AuthMux) (string, error) {
// generate code verifier
codeVerifier := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, codeVerifier); err != nil {
return "", err
}
// encrypt code verifier so it can be added to state,
// we don't need to remember it on the server side then
encryptedCodeVerifier, err := p.Encrypt(codeVerifier)
if err != nil {
return "", err
}
now := j.Now()
// This token will be valid for 10 minutes. Any chronograf server will
// be able to validate this token.
stateData := Principal{
Subject: encryptedCodeVerifier,
IssuedAt: now,
ExpiresAt: now.Add(TenMinutes),
}
token, err := j.Tokens.Create(ctx, stateData)
if err != nil {
return "", err
}
// setup URL options including PKCE code challenge
codeVerifierDigest := sha256.Sum256([]byte(
base64.RawURLEncoding.EncodeToString(codeVerifier),
))
codeChallenge := base64.RawURLEncoding.EncodeToString(
codeVerifierDigest[:],
)
urlOpts := make([]oauth2.AuthCodeOption, 0, 4)
urlOpts = append(urlOpts,
oauth2.AccessTypeOnline,
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
)
if j.LoginHint != "" {
urlOpts = append(urlOpts, oauth2.SetAuthURLParam("login_hint", j.LoginHint))
}
url := j.Provider.Config().AuthCodeURL(string(token), urlOpts...)
return url, nil
}
func (p *CodeExchangePKCE) ExchangeCodeForToken(ctx context.Context, state, code string, j *AuthMux) (*oauth2.Token, error) {
// Check if the OAuth state token is valid.
stateData, err := j.Tokens.ValidPrincipal(ctx, Token(state), TenMinutes)
if err != nil {
return nil, fmt.Errorf("invalid OAuth state received: %v", err.Error())
}
codeVerifier, err := p.Decrypt(stateData.Subject)
if err != nil {
return nil, fmt.Errorf("invalid OAuth state received: %v", err.Error())
}
// Exchange the code back with the provider to the the token
conf := j.Provider.Config()
// Use http client with transport options.
if j.client != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, j.client)
}
// exhange code for token with PKCE code_verifier supplied
return conf.Exchange(ctx,
code,
oauth2.SetAuthURLParam("code_verifier", base64.RawURLEncoding.EncodeToString(codeVerifier)),
)
}