mirror of https://github.com/nucypher/nucypher.git
237 lines
7.4 KiB
Python
237 lines
7.4 KiB
Python
import calendar
|
|
from datetime import datetime, timezone
|
|
|
|
import jwt
|
|
import pytest
|
|
from cryptography.hazmat.primitives import serialization
|
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
from marshmallow import validates
|
|
|
|
from nucypher.policy.conditions.base import ExecutionCall
|
|
from nucypher.policy.conditions.exceptions import InvalidCondition, JWTException
|
|
from nucypher.policy.conditions.jwt import JWTCondition, JWTVerificationCall
|
|
|
|
# Base64 body of the test-only EC signing key, wrapped at 64 columns (no trailing
# newline). It is wrapped into PEM armor separately, below.
TEST_ECDSA_PRIVATE_KEY_RAW_B64 = "\n".join(
    (
        "MHcCAQEEIHAhM7P6HG3LgkDvgvfDeaMA6uELj+jEKWsSeOpS/SfYoAoGCCqGSM49",
        "AwEHoUQDQgAEXHVxB7s5SR7I9cWwry/JkECIRekaCwG3uOLCYbw5gVzn4dRmwMyY",
        "UJFcQWuFSfECRK+uQOOXD0YSEucBq0p5tA==",
    )
)
|
|
|
|
# PEM-armored form of TEST_ECDSA_PRIVATE_KEY_RAW_B64; used with algorithm="ES256"
# to sign every token minted in this module.
# NOTE: the armor header/footer text is deliberately assembled from fragments so the
# complete marker string never appears verbatim in this file (a pre-commit hook rejects
# files containing private-key markers). Keep it split when editing.
TEST_ECDSA_PRIVATE_KEY = (  # TODO: Workaround to bypass pre-commit hook that detects private keys in code
    "-----BEGIN EC"
    + " PRIVATE KEY"
    + f"-----\n{TEST_ECDSA_PRIVATE_KEY_RAW_B64}\n-----END EC"
    + " PRIVATE KEY-----"
)
|
|
|
|
# SubjectPublicKeyInfo PEM for the public half of TEST_ECDSA_PRIVATE_KEY; tokens
# signed by this module verify against it. One string literal per PEM line.
TEST_ECDSA_PUBLIC_KEY = (
    "-----BEGIN PUBLIC KEY-----\n"
    "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEXHVxB7s5SR7I9cWwry/JkECIReka\n"
    "CwG3uOLCYbw5gVzn4dRmwMyYUJFcQWuFSfECRK+uQOOXD0YSEucBq0p5tA==\n"
    "-----END PUBLIC KEY-----"
)
|
|
|
|
# Whole-second Unix timestamp (UTC), captured once when this module is imported.
# Tokens minted here derive "iat"/"exp" from it, so tests can assert exact claim
# values. NOTE: relative expirations are therefore measured from import time.
ISSUED_AT = calendar.timegm(datetime.now(timezone.utc).utctimetuple())
|
|
|
|
# A ready-made ES256 token whose only claim is "iat" = ISSUED_AT.
TEST_JWT_TOKEN = jwt.encode(
    payload={"iat": ISSUED_AT},
    key=TEST_ECDSA_PRIVATE_KEY,
    algorithm="ES256",
)
|
|
|
|
|
|
def generate_pem_keypair(elliptic_curve):
    """Create a fresh EC key pair on ``elliptic_curve`` and return it as PEM text.

    :param elliptic_curve: a ``cryptography`` curve instance, e.g. ``ec.SECP521R1()``.
    :return: ``(pem_public_key, pem_private_key)`` as UTF-8 strings. The public key is
        SubjectPublicKeyInfo PEM; the private key is unencrypted TraditionalOpenSSL PEM.
    """
    secret = ec.generate_private_key(elliptic_curve)

    private_pem_bytes = secret.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )
    public_pem_bytes = secret.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    return public_pem_bytes.decode("utf-8"), private_pem_bytes.decode("utf-8")
|
|
|
|
|
|
def jwt_token(
    with_iat: bool = True, claims: dict = None, expiration_offset: int = None
):
    """Mint an ES256-signed JWT using the module's test private key.

    :param with_iat: when True, set the ``iat`` claim to ``ISSUED_AT``.
    :param claims: extra claims to include. The caller's mapping is copied and is
        never modified.
    :param expiration_offset: when given, set ``exp`` to ``ISSUED_AT + expiration_offset``
        (a negative offset yields an already-expired token).
    :return: the encoded token, as produced by ``jwt.encode``.
    """
    # Copy first: the previous `claims or dict()` mutated the caller's dict by
    # inserting "iat"/"exp", leaking state between tests that reuse a claims dict.
    payload = dict(claims) if claims else {}
    if with_iat:
        payload["iat"] = ISSUED_AT
    if expiration_offset is not None:
        payload["exp"] = ISSUED_AT + expiration_offset
    return jwt.encode(payload, TEST_ECDSA_PRIVATE_KEY, algorithm="ES256")
|
|
|
|
|
|
class TestJWTVerificationCall(JWTVerificationCall):
    """``JWTVerificationCall`` variant whose schema accepts any ``jwt_token`` value.

    Lets tests pass a literal token directly. The stock ``JWTVerificationCall``
    rejects a literal token (see ``test_jwt_verification_call_invalid``).
    """

    # Without this, the "Test" prefix makes pytest try to collect this helper as a
    # test class. It then emits a PytestCollectionWarning because the base class
    # defines a constructor.
    __test__ = False

    class Schema(JWTVerificationCall.Schema):
        @validates("jwt_token")
        def validate_jwt_token(self, value, **kwargs):
            """No-op validator: accept any ``jwt_token``.

            Presumably shadows the parent schema's validator of the same name.
            ``**kwargs`` absorbs extra keywords (e.g. ``data_key``) that newer
            marshmallow releases pass to ``@validates`` methods.
            """
            pass
|
|
|
|
|
|
def test_raw_jwt_decode():
    """PyJWT verifies a token signed with the test key and rejects a corrupted one."""
    encoded = jwt_token()
    # Dropping the first character corrupts the token, so it should no longer decode.
    corrupted = encoded[1:]

    # Happy path: the signature verifies against the matching public key.
    jwt.decode(encoded, TEST_ECDSA_PUBLIC_KEY, algorithms=["ES256"])

    with pytest.raises(jwt.exceptions.InvalidTokenError):
        jwt.decode(corrupted, TEST_ECDSA_PUBLIC_KEY, algorithms=["ES256"])
|
|
|
|
|
|
def test_jwt_verification_call_invalid():
    """Passing a literal token instead of a context-variable reference is rejected."""
    literal_token = jwt_token()
    expected_error = r"Invalid value for JWT token; expected a context variable"

    with pytest.raises(ExecutionCall.InvalidExecutionCall, match=expected_error):
        JWTVerificationCall(
            jwt_token=literal_token,
            public_key=TEST_ECDSA_PUBLIC_KEY,
        )
|
|
|
|
|
|
def test_jwt_verification_call_valid():
    """With jwt_token validation disabled, a freshly signed token verifies."""
    verification = TestJWTVerificationCall(
        jwt_token=jwt_token(),
        public_key=TEST_ECDSA_PUBLIC_KEY,
    )
    assert verification.execute()
|
|
|
|
|
|
def test_jwt_condition_missing_jwt_token():
    """A null ``jwt_token`` is rejected when the condition is constructed."""
    # `match` is a regular expression (re.search); escape the trailing dot so it
    # matches a literal period rather than any character.
    with pytest.raises(
        InvalidCondition, match=r"'jwt_token' field - Field may not be null\."
    ):
        _ = JWTCondition(jwt_token=None, public_key=None)
|
|
|
|
|
|
def test_jwt_condition_missing_public_key():
    """A null ``public_key`` is rejected even when ``jwt_token`` is a valid variable."""
    # `match` is a regular expression (re.search); escape the trailing dot so it
    # matches a literal period rather than any character.
    with pytest.raises(
        InvalidCondition, match=r"'public_key' field - Field may not be null\."
    ):
        _ = JWTCondition(
            jwt_token=":ok_ok_this_is_a_variable_for_a_jwt", public_key=None
        )
|
|
|
|
|
|
def test_jwt_condition_invalid_public_key():
    """A PEM-looking string with a garbage body is rejected as an unloadable key."""
    bogus_pem = "-----BEGIN PUBLIC KEY----- haha, gotcha! 👌 -----END PUBLIC KEY-----"
    expected_error = (
        "'public_key' field - Invalid public key format: Unable to load PEM"
    )

    with pytest.raises(InvalidCondition, match=expected_error):
        _ = JWTCondition(
            jwt_token=":ok_ok_this_is_a_variable_for_a_jwt",
            public_key=bogus_pem,
        )
|
|
|
|
|
|
def test_jwt_condition_but_unsupported_public_key():
    """A well-formed public key on the SECP521R1 curve is rejected (wrong curve)."""
    public_pem, _unused_private_pem = generate_pem_keypair(ec.SECP521R1())
    expected_error = (
        "'public_key' field - Invalid public key format: Invalid EC public key curve"
    )

    with pytest.raises(InvalidCondition, match=expected_error):
        _ = JWTCondition(
            jwt_token=":ok_ok_this_is_a_variable_for_a_jwt",
            public_key=public_pem,
        )
|
|
|
|
|
|
def test_jwt_condition_initialization():
    """Constructor arguments are stored verbatim and the condition type is set."""
    context_variable = ":aContextVariableForJWTs"
    condition = JWTCondition(
        jwt_token=context_variable,
        public_key=TEST_ECDSA_PUBLIC_KEY,
    )

    assert condition.jwt_token == context_variable
    assert condition.public_key == TEST_ECDSA_PUBLIC_KEY
    assert condition.condition_type == JWTCondition.CONDITION_TYPE
|
|
|
|
|
|
def test_jwt_condition_verify():
    """A valid token carrying no claims verifies and yields an empty claim dict."""
    variable_name = ":anotherContextVariableForJWTs"
    claimless_token = jwt_token(with_iat=False)
    condition = JWTCondition(
        jwt_token=variable_name,
        public_key=TEST_ECDSA_PUBLIC_KEY,
    )

    success, result = condition.verify(**{variable_name: claimless_token})

    assert success
    assert result == {}
|
|
|
|
|
|
def test_jwt_condition_verify_of_jwt_with_custom_claims():
    """Custom claims in a valid token are returned unchanged by verify()."""
    variable_name = ":anotherContextVariableForJWTs"
    custom_token = jwt_token(with_iat=False, claims={"foo": "bar"})
    condition = JWTCondition(
        jwt_token=variable_name,
        public_key=TEST_ECDSA_PUBLIC_KEY,
    )

    success, result = condition.verify(**{variable_name: custom_token})

    assert success
    assert result == {"foo": "bar"}
|
|
|
|
|
|
def test_jwt_condition_verify_with_correct_issuer():
    """A token whose ``iss`` matches ``expected_issuer`` verifies successfully."""
    variable_name = ":anotherContextVariableForJWTs"
    issued_token = jwt_token(with_iat=False, claims={"iss": "Isabel"})
    condition = JWTCondition(
        jwt_token=variable_name,
        public_key=TEST_ECDSA_PUBLIC_KEY,
        expected_issuer="Isabel",
    )

    success, result = condition.verify(**{variable_name: issued_token})

    assert success
    assert result == {"iss": "Isabel"}
|
|
|
|
|
|
def test_jwt_condition_verify_with_invalid_issuer():
    """An ``iss`` claim that differs from ``expected_issuer`` raises JWTException."""
    variable_name = ":anotherContextVariableForJWTs"
    issued_token = jwt_token(with_iat=False, claims={"iss": "Isabel"})
    condition = JWTCondition(
        jwt_token=variable_name,
        public_key=TEST_ECDSA_PUBLIC_KEY,
        expected_issuer="Isobel",  # deliberately one letter off
    )

    with pytest.raises(JWTException, match="Invalid issuer"):
        _ = condition.verify(**{variable_name: issued_token})
|
|
|
|
|
|
def test_jwt_condition_verify_expired_token():
    """A token whose ``exp`` is 100 seconds before ISSUED_AT is rejected as expired."""
    # "exp" = ISSUED_AT - 100, i.e. the token had already expired when it was minted.
    stale_token = jwt_token(with_iat=True, expiration_offset=-100)
    condition = JWTCondition(
        jwt_token=":contextVar",
        public_key=TEST_ECDSA_PUBLIC_KEY,
    )

    with pytest.raises(JWTException, match="Signature has expired"):
        _ = condition.verify(**{":contextVar": stale_token})
|
|
|
|
|
|
def test_jwt_condition_verify_valid_token_with_expiration():
    """A token that expires 999 seconds after ISSUED_AT verifies and returns its ``exp``."""
    # Still valid: "exp" lies in the future (as long as the suite runs < 999 s after import).
    fresh_token = jwt_token(with_iat=False, expiration_offset=999)
    condition = JWTCondition(
        jwt_token=":contextVar",
        public_key=TEST_ECDSA_PUBLIC_KEY,
    )

    success, result = condition.verify(**{":contextVar": fresh_token})

    assert success
    assert result == {"exp": ISSUED_AT + 999}
|