mirror of https://github.com/nucypher/nucypher.git
Validate public key PEM format in JWT conditions
Co-authored-by: James Campbell <james.campbell@tanti.org.uk>pull/3586/head
parent
05998b0a9a
commit
9c8db5e688
|
@ -1,6 +1,9 @@
|
|||
from typing import Any, Optional, Tuple
|
||||
|
||||
import jwt
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric import ec, rsa
|
||||
from cryptography.hazmat.primitives.serialization import load_pem_public_key
|
||||
from marshmallow import ValidationError, fields, post_load, validate, validates
|
||||
|
||||
from nucypher.policy.conditions.base import ExecutionCall
|
||||
|
@ -23,6 +26,8 @@ class JWTVerificationCall(ExecutionCall):
|
|||
"RS256",
|
||||
) # https://datatracker.ietf.org/doc/html/rfc7518#section-3.1
|
||||
|
||||
SECP_CURVE_FOR_ES256 = "secp256r1"
|
||||
|
||||
class Schema(ExecutionCall.Schema):
|
||||
jwt_token = fields.Str(required=True)
|
||||
# TODO: See #3572 for a discussion about deprecating this in favour of the expected issuer
|
||||
|
@ -43,6 +48,25 @@ class JWTVerificationCall(ExecutionCall):
|
|||
f"Invalid value for JWT token; expected a context variable, but got '{value}'"
|
||||
)
|
||||
|
||||
@validates("public_key")
|
||||
def validate_public_key(self, value):
|
||||
try:
|
||||
public_key = load_pem_public_key(
|
||||
value.encode(), backend=default_backend()
|
||||
)
|
||||
if isinstance(public_key, rsa.RSAPublicKey):
|
||||
return value
|
||||
elif isinstance(public_key, ec.EllipticCurvePublicKey):
|
||||
curve = public_key.curve
|
||||
if curve.name != JWTVerificationCall.SECP_CURVE_FOR_ES256:
|
||||
raise ValidationError(
|
||||
f"Invalid EC public key curve: {curve.name}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValidationError(f"Invalid public key format: {str(e)}")
|
||||
|
||||
return value
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
jwt_token: Optional[str] = None,
|
||||
|
|
|
@ -2,9 +2,12 @@ from datetime import datetime, timezone
|
|||
|
||||
import jwt
|
||||
import pytest
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from marshmallow import validates
|
||||
|
||||
from nucypher.policy.conditions.base import ExecutionCall
|
||||
from nucypher.policy.conditions.exceptions import InvalidCondition
|
||||
from nucypher.policy.conditions.jwt import JWTCondition, JWTVerificationCall
|
||||
|
||||
TEST_ECDSA_PRIVATE_KEY_RAW_B64 = (
|
||||
|
@ -33,6 +36,29 @@ TEST_JWT_TOKEN = jwt.encode(
|
|||
)
|
||||
|
||||
|
||||
def generate_pem_keypair(elliptic_curve):
|
||||
# Generate an EC private key
|
||||
private_key = ec.generate_private_key(elliptic_curve)
|
||||
|
||||
# Get the corresponding public key
|
||||
public_key = private_key.public_key()
|
||||
|
||||
# Serialize the private key to PEM format
|
||||
pem_private_key = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
).decode("utf-8")
|
||||
|
||||
# Serialize the public key to PEM format
|
||||
pem_public_key = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
).decode("utf-8")
|
||||
|
||||
return pem_public_key, pem_private_key
|
||||
|
||||
|
||||
def jwt_token(with_iat: bool = True, claims: dict = None):
|
||||
claims = claims or dict()
|
||||
if with_iat:
|
||||
|
@ -72,13 +98,42 @@ def test_jwt_verification_call_valid():
|
|||
assert call.execute()
|
||||
|
||||
|
||||
def test_jwt_verification_call_invalid_issuer():
|
||||
token = jwt_token(with_iat=False, claims={"iss": "Isabel"})
|
||||
call = TestJWTVerificationCall(
|
||||
jwt_token=token, public_key=TEST_ECDSA_PUBLIC_KEY, expected_issuer="Isabel"
|
||||
)
|
||||
payload = call.execute()
|
||||
assert payload == {"iss": "Isabel"}
|
||||
def test_jwt_condition_missing_jwt_token():
|
||||
with pytest.raises(
|
||||
InvalidCondition, match="'jwt_token' field - Field may not be null."
|
||||
):
|
||||
_ = JWTCondition()
|
||||
|
||||
|
||||
def test_jwt_condition_missing_public_key():
|
||||
with pytest.raises(
|
||||
InvalidCondition, match="'public_key' field - Field may not be null."
|
||||
):
|
||||
_ = JWTCondition(jwt_token=":ok_ok_this_is_a_variable_for_a_jwt")
|
||||
|
||||
|
||||
def test_jwt_condition_invalid_public_key():
|
||||
with pytest.raises(
|
||||
InvalidCondition,
|
||||
match="'public_key' field - Invalid public key format: Unable to load PEM",
|
||||
):
|
||||
_ = JWTCondition(
|
||||
jwt_token=":ok_ok_this_is_a_variable_for_a_jwt",
|
||||
public_key="-----BEGIN PUBLIC KEY----- haha, gotcha! 👌 -----END PUBLIC KEY-----",
|
||||
)
|
||||
|
||||
|
||||
def test_jwt_condition_but_unsupported_public_key():
|
||||
pem_secp521_public_key, _ = generate_pem_keypair(ec.SECP521R1())
|
||||
|
||||
with pytest.raises(
|
||||
InvalidCondition,
|
||||
match="'public_key' field - Invalid public key format: Invalid EC public key curve",
|
||||
):
|
||||
_ = JWTCondition(
|
||||
jwt_token=":ok_ok_this_is_a_variable_for_a_jwt",
|
||||
public_key=pem_secp521_public_key,
|
||||
)
|
||||
|
||||
|
||||
def test_jwt_condition_initialization():
|
||||
|
|
Loading…
Reference in New Issue