First iteration on JWTConditions

pull/3570/head
David Núñez 2024-12-12 18:21:20 +01:00
parent 2c805ee0d9
commit 665ade0658
No known key found for this signature in database
GPG Key ID: 53A9D83EF4C6332A
2 changed files with 234 additions and 0 deletions

View File

@ -0,0 +1,168 @@
from typing import Any, Optional, Tuple
import jwt
from marshmallow import ValidationError, fields, post_load, validate, validates
from nucypher.policy.conditions.base import ExecutionCall
from nucypher.policy.conditions.context import (
is_context_variable,
resolve_any_context_variables,
)
from nucypher.policy.conditions.lingo import (
ConditionType,
ExecutionCallAccessControlCondition,
ReturnValueTest,
)
from nucypher.utilities.logging import Logger
class JWTVerificationCall(ExecutionCall):
_valid_jwt_algorithms = (
"ES256",
"RS256",
) # https://datatracker.ietf.org/doc/html/rfc7518#section-3.1
class Schema(ExecutionCall.Schema):
jwt_token = fields.Str(required=True) # TODO: validate jwt encoded format
public_key = fields.Str(
required=True
) # required? maybe a valid PK certificate passed by requester?
expected_issuer = fields.Str(required=False, allow_none=True)
# subject = fields.Str(required=False)
# expiration_window = fields.Int(
# strict=True, required=False, validate=validate.Range(min=0), allow_none=True
# )
# issued_window = fields.Int(
# strict=True, required=False, validate=validate.Range(min=0), allow_none=True
# )
# # todo: kid (https://www.rfc-editor.org/rfc/rfc7515#section-4.1.4), x5u, etc
@post_load
def make(self, data, **kwargs):
return JWTVerificationCall(**data)
@validates("jwt_token")
def validate_jwt_token(self, value):
if value and not is_context_variable(value):
raise ValidationError(
f"Invalid value for JWT token; expected a context variable, but got '{value}'"
)
def __init__(
self,
jwt_token: Optional[str] = None,
public_key: Optional[str] = None,
expected_issuer: Optional[str] = None,
# subject: Optional[str] = None,
# expiration_window: Optional[int] = None,
# issued_window: Optional[int] = None,
):
self.jwt_token = jwt_token
self.public_key = public_key
self.expected_issuer = expected_issuer
# self.subject = subject
# self.expiration = expiration_window
# self.issued_window = issued_window
self.logger = Logger(__name__)
super().__init__()
def execute(self, **context) -> Any:
jwt_token = resolve_any_context_variables(self.jwt_token, **context)
# header = jwt.get_unverified_header(self.jwt_token)
# algorithm = header['alg']
try:
payload = jwt.decode(
jwt_token, self.public_key, algorithms=self._valid_jwt_algorithms
)
except jwt.exceptions.InvalidAlgorithmError:
raise # TODO: raise something specific
except jwt.exceptions.DecodeError:
raise
return payload
class JWTCondition(ExecutionCallAccessControlCondition):
"""
A JWT condition can be satisfied by presenting a valid JWT token, which not only is
required to be cryptographically verifiable, but also must fulfill certain additional
restrictions defined in the condition.
"""
EXECUTION_CALL_TYPE = JWTVerificationCall
CONDITION_TYPE = ConditionType.JWT.value
class Schema(
ExecutionCallAccessControlCondition.Schema, JWTVerificationCall.Schema
):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.JWT.value), required=True
)
# jwt_token = fields.Str(required=True) # TODO: validate jwt encoded format
# public_key = fields.Str(required=True) # required? maybe a valid PK certificate passed by requester?
# expected_issuer = fields.Str(required=False)
# subject = fields.Str(required=False)
# expiration_window = fields.Int(
# strict=True, required=False, validate=validate.Range(min=0), allow_none=True
# )
# issued_window = fields.Int(
# strict=True, required=False, validate=validate.Range(min=0), allow_none=True
# )
# todo: kid (https://www.rfc-editor.org/rfc/rfc7515#section-4.1.4), x5u, etc
@post_load
def make(self, data, **kwargs):
return JWTCondition(**data)
def __init__(
self,
condition_type: str = ConditionType.JWT.value,
name: Optional[str] = None,
jwt_token: Optional[str] = None,
public_key: Optional[str] = None,
expected_issuer: Optional[str] = None,
# subject: Optional[str] = None,
# expiration_window: Optional[int] = None,
# issued_window: Optional[int] = None,
):
super().__init__(
jwt_token=jwt_token,
public_key=public_key,
expected_issuer=expected_issuer,
# subject=subject,
# expiration=expiration_window,
# issued_window=issued_window,
condition_type=condition_type,
name=name,
return_value_test=ReturnValueTest(
comparator="==", value=True
), # TODO: Workaround for now
)
@property
def jwt_token(self):
return self.execution_call.jwt_token
@property
def public_key(self):
return self.execution_call.public_key
@property
def expected_issuer(self):
return self.execution_call.expected_issuer
#
def verify(self, **context) -> Tuple[bool, Any]:
try:
payload = self.execution_call.execute(**context)
result = True # TODO: Additional condition checks
except Exception: # TODO: specific exceptions
payload = None
result = False
return result, payload

View File

@ -0,0 +1,66 @@
import jwt
import pytest
from marshmallow import validates
from nucypher.policy.conditions.base import ExecutionCall
from nucypher.policy.conditions.jwt import JWTCondition, JWTVerificationCall
TEST_ECDSA_PUBLIC_KEY = (
"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEXHVxB7s5SR7I9cWwry"
"/JkECIReka\nCwG3uOLCYbw5gVzn4dRmwMyYUJFcQWuFSfECRK+uQOOXD0YSEucBq0p5tA==\n-----END PUBLIC "
"KEY-----\n "
)
TEST_JWT_TOKEN = (
"eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9"
".eyJpYXQiOjE3MzM0MjQ3MTd9"
".uc2Av6f4yibXRLtmCmvhbRiNfYTrkHPS3vAGHaamX1CQ4mQR8iGyE8X3TvseCclkgsbKBBKZG8nQXhA5hsXLRg"
)
class TestJWTVerificationCall(JWTVerificationCall):
class Schema(JWTVerificationCall.Schema):
@validates("jwt_token")
def validate_jwt_token(self, value):
pass
def test_raw_jwt_decode():
# Valid JWT
jwt.decode(TEST_JWT_TOKEN, TEST_ECDSA_PUBLIC_KEY, algorithms=["ES256"])
# Invalid JWT
with pytest.raises(jwt.exceptions.InvalidTokenError):
jwt.decode(TEST_JWT_TOKEN[1:], TEST_ECDSA_PUBLIC_KEY, algorithms=["ES256"])
def test_jwt_verification_call_invalid():
message = r"Invalid value for JWT token; expected a context variable"
with pytest.raises(ExecutionCall.InvalidExecutionCall, match=message):
JWTVerificationCall(jwt_token=TEST_JWT_TOKEN, public_key=TEST_ECDSA_PUBLIC_KEY)
def test_jwt_verification_call_invalid2():
TestJWTVerificationCall(jwt_token=TEST_JWT_TOKEN, public_key=TEST_ECDSA_PUBLIC_KEY)
def test_jwt_condition_initialization():
condition = JWTCondition(
jwt_token=":aContextVariableForJWTs",
public_key=TEST_ECDSA_PUBLIC_KEY,
)
assert condition.jwt_token == ":aContextVariableForJWTs"
assert condition.public_key == TEST_ECDSA_PUBLIC_KEY
assert condition.condition_type == JWTCondition.CONDITION_TYPE
def test_jwt_condition_verify():
condition = JWTCondition(
jwt_token=":anotherContextVariableForJWTs",
public_key=TEST_ECDSA_PUBLIC_KEY,
)
context = {":anotherContextVariableForJWTs": TEST_JWT_TOKEN}
success, result = condition.verify(**context)
assert success
assert result is not None