Validate public key PEM format in JWT conditions

Co-authored-by: James Campbell <james.campbell@tanti.org.uk>
pull/3586/head
David Núñez 2025-01-07 16:37:58 +01:00
parent 05998b0a9a
commit 9c8db5e688
2 changed files with 86 additions and 7 deletions

View File

@ -1,6 +1,9 @@
from typing import Any, Optional, Tuple
import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec, rsa
from cryptography.hazmat.primitives.serialization import load_pem_public_key
from marshmallow import ValidationError, fields, post_load, validate, validates
from nucypher.policy.conditions.base import ExecutionCall
@ -23,6 +26,8 @@ class JWTVerificationCall(ExecutionCall):
"RS256",
) # https://datatracker.ietf.org/doc/html/rfc7518#section-3.1
SECP_CURVE_FOR_ES256 = "secp256r1"
class Schema(ExecutionCall.Schema):
jwt_token = fields.Str(required=True)
# TODO: See #3572 for a discussion about deprecating this in favour of the expected issuer
@ -43,6 +48,25 @@ class JWTVerificationCall(ExecutionCall):
f"Invalid value for JWT token; expected a context variable, but got '{value}'"
)
@validates("public_key")
def validate_public_key(self, value):
try:
public_key = load_pem_public_key(
value.encode(), backend=default_backend()
)
if isinstance(public_key, rsa.RSAPublicKey):
return value
elif isinstance(public_key, ec.EllipticCurvePublicKey):
curve = public_key.curve
if curve.name != JWTVerificationCall.SECP_CURVE_FOR_ES256:
raise ValidationError(
f"Invalid EC public key curve: {curve.name}"
)
except Exception as e:
raise ValidationError(f"Invalid public key format: {str(e)}")
return value
def __init__(
self,
jwt_token: Optional[str] = None,

View File

@ -2,9 +2,12 @@ from datetime import datetime, timezone
import jwt
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec
from marshmallow import validates
from nucypher.policy.conditions.base import ExecutionCall
from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.jwt import JWTCondition, JWTVerificationCall
TEST_ECDSA_PRIVATE_KEY_RAW_B64 = (
@ -33,6 +36,29 @@ TEST_JWT_TOKEN = jwt.encode(
)
def generate_pem_keypair(elliptic_curve):
# Generate an EC private key
private_key = ec.generate_private_key(elliptic_curve)
# Get the corresponding public key
public_key = private_key.public_key()
# Serialize the private key to PEM format
pem_private_key = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
).decode("utf-8")
# Serialize the public key to PEM format
pem_public_key = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
).decode("utf-8")
return pem_public_key, pem_private_key
def jwt_token(with_iat: bool = True, claims: dict = None):
claims = claims or dict()
if with_iat:
@ -72,13 +98,42 @@ def test_jwt_verification_call_valid():
assert call.execute()
def test_jwt_verification_call_invalid_issuer():
token = jwt_token(with_iat=False, claims={"iss": "Isabel"})
call = TestJWTVerificationCall(
jwt_token=token, public_key=TEST_ECDSA_PUBLIC_KEY, expected_issuer="Isabel"
)
payload = call.execute()
assert payload == {"iss": "Isabel"}
def test_jwt_condition_missing_jwt_token():
with pytest.raises(
InvalidCondition, match="'jwt_token' field - Field may not be null."
):
_ = JWTCondition()
def test_jwt_condition_missing_public_key():
with pytest.raises(
InvalidCondition, match="'public_key' field - Field may not be null."
):
_ = JWTCondition(jwt_token=":ok_ok_this_is_a_variable_for_a_jwt")
def test_jwt_condition_invalid_public_key():
with pytest.raises(
InvalidCondition,
match="'public_key' field - Invalid public key format: Unable to load PEM",
):
_ = JWTCondition(
jwt_token=":ok_ok_this_is_a_variable_for_a_jwt",
public_key="-----BEGIN PUBLIC KEY----- haha, gotcha! 👌 -----END PUBLIC KEY-----",
)
def test_jwt_condition_but_unsupported_public_key():
pem_secp521_public_key, _ = generate_pem_keypair(ec.SECP521R1())
with pytest.raises(
InvalidCondition,
match="'public_key' field - Invalid public key format: Invalid EC public key curve",
):
_ = JWTCondition(
jwt_token=":ok_ok_this_is_a_variable_for_a_jwt",
public_key=pem_secp521_public_key,
)
def test_jwt_condition_initialization():