mirror of https://github.com/nucypher/nucypher.git
First iteration on JWTConditions
parent
c17f174501
commit
bfba37db58
|
@ -0,0 +1,168 @@
|
|||
from typing import Any, Optional, Tuple
|
||||
|
||||
import jwt
|
||||
from marshmallow import ValidationError, fields, post_load, validate, validates
|
||||
|
||||
from nucypher.policy.conditions.base import ExecutionCall
|
||||
from nucypher.policy.conditions.context import (
|
||||
is_context_variable,
|
||||
resolve_any_context_variables,
|
||||
)
|
||||
from nucypher.policy.conditions.lingo import (
|
||||
ConditionType,
|
||||
ExecutionCallAccessControlCondition,
|
||||
ReturnValueTest,
|
||||
)
|
||||
from nucypher.utilities.logging import Logger
|
||||
|
||||
|
||||
class JWTVerificationCall(ExecutionCall):
|
||||
|
||||
_valid_jwt_algorithms = (
|
||||
"ES256",
|
||||
"RS256",
|
||||
) # https://datatracker.ietf.org/doc/html/rfc7518#section-3.1
|
||||
|
||||
class Schema(ExecutionCall.Schema):
|
||||
jwt_token = fields.Str(required=True) # TODO: validate jwt encoded format
|
||||
public_key = fields.Str(
|
||||
required=True
|
||||
) # required? maybe a valid PK certificate passed by requester?
|
||||
expected_issuer = fields.Str(required=False, allow_none=True)
|
||||
# subject = fields.Str(required=False)
|
||||
# expiration_window = fields.Int(
|
||||
# strict=True, required=False, validate=validate.Range(min=0), allow_none=True
|
||||
# )
|
||||
# issued_window = fields.Int(
|
||||
# strict=True, required=False, validate=validate.Range(min=0), allow_none=True
|
||||
# )
|
||||
# # todo: kid (https://www.rfc-editor.org/rfc/rfc7515#section-4.1.4), x5u, etc
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return JWTVerificationCall(**data)
|
||||
|
||||
@validates("jwt_token")
|
||||
def validate_jwt_token(self, value):
|
||||
if value and not is_context_variable(value):
|
||||
raise ValidationError(
|
||||
f"Invalid value for JWT token; expected a context variable, but got '{value}'"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
jwt_token: Optional[str] = None,
|
||||
public_key: Optional[str] = None,
|
||||
expected_issuer: Optional[str] = None,
|
||||
# subject: Optional[str] = None,
|
||||
# expiration_window: Optional[int] = None,
|
||||
# issued_window: Optional[int] = None,
|
||||
):
|
||||
self.jwt_token = jwt_token
|
||||
self.public_key = public_key
|
||||
self.expected_issuer = expected_issuer
|
||||
# self.subject = subject
|
||||
# self.expiration = expiration_window
|
||||
# self.issued_window = issued_window
|
||||
|
||||
self.logger = Logger(__name__)
|
||||
|
||||
super().__init__()
|
||||
|
||||
def execute(self, **context) -> Any:
|
||||
|
||||
jwt_token = resolve_any_context_variables(self.jwt_token, **context)
|
||||
|
||||
# header = jwt.get_unverified_header(self.jwt_token)
|
||||
# algorithm = header['alg']
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
jwt_token, self.public_key, algorithms=self._valid_jwt_algorithms
|
||||
)
|
||||
except jwt.exceptions.InvalidAlgorithmError:
|
||||
raise # TODO: raise something specific
|
||||
except jwt.exceptions.DecodeError:
|
||||
raise
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
class JWTCondition(ExecutionCallAccessControlCondition):
|
||||
"""
|
||||
A JWT condition can be satisfied by presenting a valid JWT token, which not only is
|
||||
required to be cryptographically verifiable, but also must fulfill certain additional
|
||||
restrictions defined in the condition.
|
||||
"""
|
||||
|
||||
EXECUTION_CALL_TYPE = JWTVerificationCall
|
||||
CONDITION_TYPE = ConditionType.JWT.value
|
||||
|
||||
class Schema(
|
||||
ExecutionCallAccessControlCondition.Schema, JWTVerificationCall.Schema
|
||||
):
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.JWT.value), required=True
|
||||
)
|
||||
# jwt_token = fields.Str(required=True) # TODO: validate jwt encoded format
|
||||
# public_key = fields.Str(required=True) # required? maybe a valid PK certificate passed by requester?
|
||||
# expected_issuer = fields.Str(required=False)
|
||||
# subject = fields.Str(required=False)
|
||||
# expiration_window = fields.Int(
|
||||
# strict=True, required=False, validate=validate.Range(min=0), allow_none=True
|
||||
# )
|
||||
# issued_window = fields.Int(
|
||||
# strict=True, required=False, validate=validate.Range(min=0), allow_none=True
|
||||
# )
|
||||
# todo: kid (https://www.rfc-editor.org/rfc/rfc7515#section-4.1.4), x5u, etc
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return JWTCondition(**data)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
condition_type: str = ConditionType.JWT.value,
|
||||
name: Optional[str] = None,
|
||||
jwt_token: Optional[str] = None,
|
||||
public_key: Optional[str] = None,
|
||||
expected_issuer: Optional[str] = None,
|
||||
# subject: Optional[str] = None,
|
||||
# expiration_window: Optional[int] = None,
|
||||
# issued_window: Optional[int] = None,
|
||||
):
|
||||
super().__init__(
|
||||
jwt_token=jwt_token,
|
||||
public_key=public_key,
|
||||
expected_issuer=expected_issuer,
|
||||
# subject=subject,
|
||||
# expiration=expiration_window,
|
||||
# issued_window=issued_window,
|
||||
condition_type=condition_type,
|
||||
name=name,
|
||||
return_value_test=ReturnValueTest(
|
||||
comparator="==", value=True
|
||||
), # TODO: Workaround for now
|
||||
)
|
||||
|
||||
@property
|
||||
def jwt_token(self):
|
||||
return self.execution_call.jwt_token
|
||||
|
||||
@property
|
||||
def public_key(self):
|
||||
return self.execution_call.public_key
|
||||
|
||||
@property
|
||||
def expected_issuer(self):
|
||||
return self.execution_call.expected_issuer
|
||||
|
||||
#
|
||||
def verify(self, **context) -> Tuple[bool, Any]:
|
||||
try:
|
||||
payload = self.execution_call.execute(**context)
|
||||
result = True # TODO: Additional condition checks
|
||||
except Exception: # TODO: specific exceptions
|
||||
payload = None
|
||||
result = False
|
||||
|
||||
return result, payload
|
|
@ -0,0 +1,66 @@
|
|||
import jwt
|
||||
import pytest
|
||||
from marshmallow import validates
|
||||
|
||||
from nucypher.policy.conditions.base import ExecutionCall
|
||||
from nucypher.policy.conditions.jwt import JWTCondition, JWTVerificationCall
|
||||
|
||||
TEST_ECDSA_PUBLIC_KEY = (
|
||||
"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEXHVxB7s5SR7I9cWwry"
|
||||
"/JkECIReka\nCwG3uOLCYbw5gVzn4dRmwMyYUJFcQWuFSfECRK+uQOOXD0YSEucBq0p5tA==\n-----END PUBLIC "
|
||||
"KEY-----\n "
|
||||
)
|
||||
TEST_JWT_TOKEN = (
|
||||
"eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
".eyJpYXQiOjE3MzM0MjQ3MTd9"
|
||||
".uc2Av6f4yibXRLtmCmvhbRiNfYTrkHPS3vAGHaamX1CQ4mQR8iGyE8X3TvseCclkgsbKBBKZG8nQXhA5hsXLRg"
|
||||
)
|
||||
|
||||
|
||||
class TestJWTVerificationCall(JWTVerificationCall):
|
||||
class Schema(JWTVerificationCall.Schema):
|
||||
@validates("jwt_token")
|
||||
def validate_jwt_token(self, value):
|
||||
pass
|
||||
|
||||
|
||||
def test_raw_jwt_decode():
|
||||
# Valid JWT
|
||||
jwt.decode(TEST_JWT_TOKEN, TEST_ECDSA_PUBLIC_KEY, algorithms=["ES256"])
|
||||
|
||||
# Invalid JWT
|
||||
with pytest.raises(jwt.exceptions.InvalidTokenError):
|
||||
jwt.decode(TEST_JWT_TOKEN[1:], TEST_ECDSA_PUBLIC_KEY, algorithms=["ES256"])
|
||||
|
||||
|
||||
def test_jwt_verification_call_invalid():
|
||||
message = r"Invalid value for JWT token; expected a context variable"
|
||||
with pytest.raises(ExecutionCall.InvalidExecutionCall, match=message):
|
||||
JWTVerificationCall(jwt_token=TEST_JWT_TOKEN, public_key=TEST_ECDSA_PUBLIC_KEY)
|
||||
|
||||
|
||||
def test_jwt_verification_call_invalid2():
|
||||
TestJWTVerificationCall(jwt_token=TEST_JWT_TOKEN, public_key=TEST_ECDSA_PUBLIC_KEY)
|
||||
|
||||
|
||||
def test_jwt_condition_initialization():
|
||||
condition = JWTCondition(
|
||||
jwt_token=":aContextVariableForJWTs",
|
||||
public_key=TEST_ECDSA_PUBLIC_KEY,
|
||||
)
|
||||
|
||||
assert condition.jwt_token == ":aContextVariableForJWTs"
|
||||
assert condition.public_key == TEST_ECDSA_PUBLIC_KEY
|
||||
assert condition.condition_type == JWTCondition.CONDITION_TYPE
|
||||
|
||||
|
||||
def test_jwt_condition_verify():
|
||||
condition = JWTCondition(
|
||||
jwt_token=":anotherContextVariableForJWTs",
|
||||
public_key=TEST_ECDSA_PUBLIC_KEY,
|
||||
)
|
||||
|
||||
context = {":anotherContextVariableForJWTs": TEST_JWT_TOKEN}
|
||||
success, result = condition.verify(**context)
|
||||
assert success
|
||||
assert result is not None
|
Loading…
Reference in New Issue