Some code reorg that leaves space for additional authentication types other than EVM.

pull/3510/head
derekpierre 2024-06-06 15:48:45 -04:00
parent 3b3263a57d
commit 59d42b7b2e
No known key found for this signature in database
6 changed files with 44 additions and 41 deletions

View File

@ -7,7 +7,7 @@ from eth_account.messages import HexBytes, encode_typed_data
from siwe import SiweMessage, VerificationError
class Auth:
class EvmAuth:
class AuthScheme(Enum):
EIP712 = "EIP712"
EIP4361 = "EIP4361"
@ -36,7 +36,7 @@ class Auth:
raise ValueError(f"Invalid authentication scheme: {scheme}")
class EIP712Auth(Auth):
class EIP712Auth(EvmAuth):
@classmethod
def authenticate(cls, data, signature, expected_address):
try:
@ -65,7 +65,7 @@ class EIP712Auth(Auth):
)
class EIP4361Auth(Auth):
class EIP4361Auth(EvmAuth):
FRESHNESS_IN_HOURS = 2
@classmethod
@ -85,12 +85,14 @@ class EIP4361Auth(Auth):
)
# enforce a freshness check
# TODO: "not-before" throws off the freshness timing; so skip if specified. Is this safe / what we want?
# TODO: "not-before" throws off the freshness timing; so skip if specified.
# Is this safe / what we want?
if not siwe_message.not_before:
issued_at = maya.MayaDT.from_iso8601(siwe_message.issued_at)
if maya.now() > issued_at.add(hours=cls.FRESHNESS_IN_HOURS):
raise cls.AuthenticationFailed(
f"EIP4361 message is stale; more than {cls.FRESHNESS_IN_HOURS} hours old (issued at {issued_at.iso8601()})"
f"EIP4361 message is stale; more than {cls.FRESHNESS_IN_HOURS} "
f"hours old (issued at {issued_at.iso8601()})"
)
if siwe_message.address != expected_address:

View File

@ -5,7 +5,7 @@ from typing import Any, List, Union
from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address
from nucypher.policy.conditions.auth import Auth
from nucypher.policy.conditions.auth.evm import EvmAuth
from nucypher.policy.conditions.exceptions import (
ContextVariableVerificationFailed,
InvalidContextVariableData,
@ -21,8 +21,8 @@ CONTEXT_REGEX = re.compile(":[a-zA-Z_][a-zA-Z0-9_]*")
USER_ADDRESS_SCHEMES = {
USER_ADDRESS_CONTEXT: None, # any of the available auth types
USER_ADDRESS_EIP712_CONTEXT: Auth.AuthScheme.EIP712.value,
USER_ADDRESS_EIP4361_CONTEXT: Auth.AuthScheme.EIP4361.value,
USER_ADDRESS_EIP712_CONTEXT: EvmAuth.AuthScheme.EIP712.value,
USER_ADDRESS_EIP4361_CONTEXT: EvmAuth.AuthScheme.EIP4361.value,
}
@ -40,8 +40,8 @@ def _resolve_user_address(user_address_context_variable, **context) -> ChecksumA
{
"signature": "<signature>",
"address": "<address>",
"scheme": "EIP712" | "SIWE" | ...
"typeData": ...
"scheme": "EIP712" | "EIP4361" | ...
"typedData": ...
}
}
"""
@ -51,22 +51,22 @@ def _resolve_user_address(user_address_context_variable, **context) -> ChecksumA
expected_address = to_checksum_address(user_address_info["address"])
typed_data = user_address_info["typedData"]
scheme = user_address_info.get("scheme", Auth.AuthScheme.EIP712.value)
scheme = user_address_info.get("scheme", EvmAuth.AuthScheme.EIP712.value)
expected_scheme = USER_ADDRESS_SCHEMES[user_address_context_variable]
if expected_scheme and scheme != expected_scheme:
raise UnexpectedScheme(
f"Expected {expected_scheme} authentication scheme, but received {scheme}"
)
auth = Auth.from_scheme(scheme)
auth = EvmAuth.from_scheme(scheme)
auth.authenticate(
data=typed_data, signature=signature, expected_address=expected_address
)
except Auth.InvalidData as e:
except EvmAuth.InvalidData as e:
raise InvalidContextVariableData(
f"Invalid context variable data for '{user_address_context_variable}'; {e}"
)
except Auth.AuthenticationFailed as e:
except EvmAuth.AuthenticationFailed as e:
raise ContextVariableVerificationFailed(
f"Authentication failed for '{user_address_context_variable}'; {e}"
)

View File

@ -37,7 +37,7 @@ from nucypher.config.constants import TEMPORARY_DOMAIN_NAME
from nucypher.crypto.ferveo import dkg
from nucypher.crypto.keystore import Keystore
from nucypher.network.nodes import TEACHER_NODES
from nucypher.policy.conditions.auth import Auth
from nucypher.policy.conditions.auth.evm import EvmAuth
from nucypher.policy.conditions.context import USER_ADDRESS_CONTEXT
from nucypher.policy.conditions.evm import RPCCondition
from nucypher.policy.conditions.lingo import (
@ -654,13 +654,13 @@ def valid_user_address_auth_message(request):
auth_message_type = request.param
if auth_message_type is None:
# pick one at random
auth_message_type = random.choice(Auth.AuthScheme.values())
auth_message_type = random.choice(EvmAuth.AuthScheme.values())
if auth_message_type == Auth.AuthScheme.EIP712.value:
if auth_message_type == EvmAuth.AuthScheme.EIP712.value:
auth_message = {
"signature": "0x488a7acefdc6d098eedf73cdfd379777c0f4a4023a660d350d3bf309a51dd4251abaad9cdd11b71c400cfb4625c14ca142f72b39165bd980c8da1ea32892ff071c",
"address": "0x5ce9454909639D2D17A3F753ce7d93fa0b9aB12E",
"scheme": f"{Auth.AuthScheme.EIP712.value}",
"scheme": f"{EvmAuth.AuthScheme.EIP712.value}",
"typedData": {
"primaryType": "Wallet",
"types": {
@ -691,7 +691,7 @@ def valid_user_address_auth_message(request):
},
},
}
elif auth_message_type == Auth.AuthScheme.EIP4361.value:
elif auth_message_type == EvmAuth.AuthScheme.EIP4361.value:
signer = InMemorySigner()
siwe_message_data = {
"domain": "login.xyz",
@ -710,7 +710,7 @@ def valid_user_address_auth_message(request):
auth_message = {
"signature": f"{signature.hex()}",
"address": f"{signer.accounts[0]}",
"scheme": f"{Auth.AuthScheme.EIP4361.value}",
"scheme": f"{EvmAuth.AuthScheme.EIP4361.value}",
"typedData": f"{siwe_message}",
}
else:

View File

@ -4,7 +4,7 @@ import re
import pytest
from nucypher.policy.conditions.auth import Auth
from nucypher.policy.conditions.auth.evm import EvmAuth
from nucypher.policy.conditions.context import (
USER_ADDRESS_EIP712_CONTEXT,
USER_ADDRESS_EIP4361_CONTEXT,
@ -138,8 +138,8 @@ def test_user_address_context_invalid_typed_data(
USER_ADDRESS_EIP4361_CONTEXT,
],
[
Auth.AuthScheme.EIP4361.value,
Auth.AuthScheme.EIP712.value,
EvmAuth.AuthScheme.EIP4361.value,
EvmAuth.AuthScheme.EIP712.value,
],
)
),

View File

@ -3,23 +3,23 @@ import pytest
from siwe import SiweMessage
from nucypher.blockchain.eth.signers import InMemorySigner
from nucypher.policy.conditions.auth import Auth, EIP712Auth, EIP4361Auth
from nucypher.policy.conditions.auth.evm import EIP712Auth, EIP4361Auth, EvmAuth
def test_auth_scheme():
for scheme in Auth.AuthScheme:
for scheme in EvmAuth.AuthScheme:
expected_scheme = (
EIP712Auth if scheme == Auth.AuthScheme.EIP712 else EIP4361Auth
EIP712Auth if scheme == EvmAuth.AuthScheme.EIP712 else EIP4361Auth
)
assert Auth.from_scheme(scheme=scheme.value) == expected_scheme
assert EvmAuth.from_scheme(scheme=scheme.value) == expected_scheme
# non-existent scheme
with pytest.raises(ValueError):
_ = Auth.from_scheme(scheme="rando")
_ = EvmAuth.from_scheme(scheme="rando")
@pytest.mark.parametrize(
"valid_user_address_auth_message", [Auth.AuthScheme.EIP712.value], indirect=True
"valid_user_address_auth_message", [EvmAuth.AuthScheme.EIP712.value], indirect=True
)
def test_authenticate_eip712(
valid_user_address_auth_message, get_random_checksum_address
@ -31,14 +31,14 @@ def test_authenticate_eip712(
# invalid data
invalid_data = dict(data) # make a copy
del invalid_data["domain"]
with pytest.raises(Auth.InvalidData):
with pytest.raises(EvmAuth.InvalidData):
EIP712Auth.authenticate(
data=invalid_data, signature=signature, expected_address=address
)
invalid_data = dict(data) # make a copy
del invalid_data["message"]
with pytest.raises(Auth.InvalidData):
with pytest.raises(EvmAuth.InvalidData):
EIP712Auth.authenticate(
data=invalid_data, signature=signature, expected_address=address
)
@ -48,20 +48,20 @@ def test_authenticate_eip712(
"0x93252ddff5f90584b27b5eef1915b23a8b01a703be56c8bf0660647c15cb75e9"
"1983bde9877eaad11da5a3ebc9b64957f1c182536931f9844d0c600f0c41293d1b"
)
with pytest.raises(Auth.AuthenticationFailed):
with pytest.raises(EvmAuth.AuthenticationFailed):
EIP712Auth.authenticate(
data=data, signature=incorrect_signature, expected_address=address
)
# invalid signature
invalid_signature = "0xdeadbeef"
with pytest.raises(Auth.InvalidData):
with pytest.raises(EvmAuth.InvalidData):
EIP712Auth.authenticate(
data=data, signature=invalid_signature, expected_address=address
)
# mismatch with expected address
with pytest.raises(Auth.AuthenticationFailed):
with pytest.raises(EvmAuth.AuthenticationFailed):
EIP712Auth.authenticate(
data=data,
signature=signature,
@ -97,7 +97,7 @@ def test_authenticate_eip4361(get_random_checksum_address):
# invalid data
invalid_data = "just a regular old string"
with pytest.raises(Auth.InvalidData):
with pytest.raises(EvmAuth.InvalidData):
EIP4361Auth.authenticate(
data=invalid_data,
signature=valid_message_signature,
@ -110,7 +110,7 @@ def test_authenticate_eip4361(get_random_checksum_address):
"1983bde9877eaad11da5a3ebc9b64957f1c182536931f9844d0c600f0c41293d1b"
)
with pytest.raises(
Auth.AuthenticationFailed,
EvmAuth.AuthenticationFailed,
match="EIP4361 verification failed - InvalidSignature",
):
EIP4361Auth.authenticate(
@ -122,7 +122,7 @@ def test_authenticate_eip4361(get_random_checksum_address):
# invalid signature
invalid_signature = "0xdeadbeef"
with pytest.raises(
Auth.AuthenticationFailed,
EvmAuth.AuthenticationFailed,
match="EIP4361 verification failed - InvalidSignature",
):
EIP4361Auth.authenticate(
@ -133,7 +133,7 @@ def test_authenticate_eip4361(get_random_checksum_address):
# mismatch with expected address
with pytest.raises(
Auth.AuthenticationFailed, match="does not match expected address"
EvmAuth.AuthenticationFailed, match="does not match expected address"
):
EIP4361Auth.authenticate(
data=valid_message,
@ -150,7 +150,7 @@ def test_authenticate_eip4361(get_random_checksum_address):
stale_message_signature = signer.sign_message(
account=valid_address_for_signature, message=stale_message.encode()
)
with pytest.raises(Auth.AuthenticationFailed, match="EIP4361 message is stale"):
with pytest.raises(EvmAuth.AuthenticationFailed, match="EIP4361 message is stale"):
EIP4361Auth.authenticate(
stale_message, stale_message_signature.hex(), valid_address_for_signature
)
@ -185,7 +185,8 @@ def test_authenticate_eip4361(get_random_checksum_address):
message=not_stale_but_past_expiry_message.encode(),
)
with pytest.raises(
Auth.AuthenticationFailed, match="EIP4361 verification failed - ExpiredMessage"
EvmAuth.AuthenticationFailed,
match="EIP4361 verification failed - ExpiredMessage",
):
EIP4361Auth.authenticate(
not_stale_but_past_expiry_message,
@ -201,7 +202,7 @@ def test_authenticate_eip4361(get_random_checksum_address):
account=valid_address_for_signature, message=not_before_message.encode()
)
with pytest.raises(
Auth.AuthenticationFailed,
EvmAuth.AuthenticationFailed,
match="EIP4361 verification failed - NotYetValidMessage",
):
EIP4361Auth.authenticate(