Initial implementation of EvmAuth for EIP1271.

Update use of provider manager across EvmAuth initerface and subsequent tests.
pull/3576/head
derekpierre 2025-01-24 17:19:01 -05:00
parent dcd870d49e
commit 93ba45f48b
No known key found for this signature in database
5 changed files with 145 additions and 24 deletions

View File

@ -1,16 +1,20 @@
from enum import Enum
from typing import List
from typing import List, Optional
import maya
from eth_account.account import Account
from eth_account.messages import HexBytes, encode_typed_data
from eth_typing import ChecksumAddress
from siwe import SiweMessage, VerificationError
from nucypher.policy.conditions.utils import ConditionProviderManager
class EvmAuth:
class AuthScheme(Enum):
EIP712 = "EIP712"
EIP4361 = "EIP4361"
EIP1271 = "EIP1271"
@classmethod
def values(cls) -> List[str]:
@ -26,7 +30,13 @@ class EvmAuth:
"""The message is too old."""
@classmethod
def authenticate(cls, data, signature, expected_address):
def authenticate(
cls,
data,
signature: str,
expected_address: str,
providers: Optional[ConditionProviderManager] = None,
):
raise NotImplementedError
@classmethod
@ -35,13 +45,21 @@ class EvmAuth:
return EIP712Auth
elif scheme == cls.AuthScheme.EIP4361.value:
return EIP4361Auth
elif scheme == cls.AuthScheme.EIP1271.value:
return EIP1271Auth
raise ValueError(f"Invalid authentication scheme: {scheme}")
class EIP712Auth(EvmAuth):
@classmethod
def authenticate(cls, data, signature, expected_address):
def authenticate(
cls,
data,
signature: str,
expected_address: str,
providers: Optional[ConditionProviderManager] = None,
):
try:
# convert hex data for byte fields - bytes are expected by underlying library
# 1. salt
@ -72,7 +90,13 @@ class EIP4361Auth(EvmAuth):
FRESHNESS_IN_HOURS = 2
@classmethod
def authenticate(cls, data, signature, expected_address):
def authenticate(
cls,
data,
signature: str,
expected_address: str,
providers: Optional[ConditionProviderManager] = None,
):
try:
siwe_message = SiweMessage.from_message(message=data)
except Exception as e:
@ -106,3 +130,70 @@ class EIP4361Auth(EvmAuth):
raise cls.AuthenticationFailed(
f"Invalid EIP4361 signature; signature not valid for expected address, {expected_address}"
)
class EIP1271Auth(EvmAuth):
EIP1271_ABI = """[
{
"inputs":[
{
"internalType":"bytes32",
"name":"_hash",
"type":"bytes32"
},
{
"internalType":"bytes",
"name":"_signature",
"type":"bytes"
}
],
"name":"isValidSignature",
"outputs":[
{
"internalType":"bytes4",
"name":"",
"type":"bytes4"
}
],
"stateMutability":"view",
"type":"function"
}
]"""
MAGIC_VALUE_BYTES = bytes(HexBytes("0x1626ba7e"))
@classmethod
def authenticate(
cls,
data,
signature: str,
expected_address: ChecksumAddress,
providers: Optional[ConditionProviderManager] = None,
):
result = None
try:
data_hash = bytes(HexBytes(data["dataHash"]))
chain_id = data["chain_id"]
signature_bytes = bytes(HexBytes(signature))
w3_instances = providers.web3_endpoints(chain_id=chain_id)
for w3 in w3_instances:
eip1271_contract = w3.eth.contract(
address=expected_address, abi=cls.EIP1271_ABI
)
result = eip1271_contract.functions.isValidSignature(
data_hash,
signature_bytes,
).call()
break
except Exception as e:
# data could not be processed
raise cls.InvalidData(
f"Invalid EIP1271 message: {str(e) or e.__class__.__name__}"
)
if result != cls.MAGIC_VALUE_BYTES:
raise cls.AuthenticationFailed(
f"EIP1271 verification failed; signature not valid for contract address, {expected_address}"
)

View File

@ -1,6 +1,6 @@
import re
from functools import partial
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union
from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address
@ -11,6 +11,7 @@ from nucypher.policy.conditions.exceptions import (
InvalidContextVariableData,
RequiredContextVariable,
)
from nucypher.policy.conditions.utils import ConditionProviderManager
USER_ADDRESS_CONTEXT = ":userAddress"
USER_ADDRESS_EIP4361_EXTERNAL_CONTEXT = ":userAddressExternalEIP4361"
@ -19,7 +20,7 @@ CONTEXT_PREFIX = ":"
CONTEXT_REGEX = re.compile(":[a-zA-Z_][a-zA-Z0-9_]*")
USER_ADDRESS_SCHEMES = {
USER_ADDRESS_CONTEXT: None, # allow any scheme (EIP4361, EIP712) for now; eventually EIP712 will be deprecated
USER_ADDRESS_CONTEXT: None, # allow any scheme (EIP4361, EIP1271, EIP712) for now; eventually EIP712 will be deprecated
USER_ADDRESS_EIP4361_EXTERNAL_CONTEXT: EvmAuth.AuthScheme.EIP4361.value,
}
@ -28,7 +29,11 @@ class UnexpectedScheme(Exception):
pass
def _resolve_user_address(user_address_context_variable, **context) -> ChecksumAddress:
def _resolve_user_address(
user_address_context_variable: str,
providers: Optional[ConditionProviderManager] = None,
**context,
) -> ChecksumAddress:
"""
Recovers a checksum address from a signed message.
@ -38,7 +43,7 @@ def _resolve_user_address(user_address_context_variable, **context) -> ChecksumA
{
"signature": "<signature>",
"address": "<address>",
"scheme": "EIP4361" | ...
"scheme": "EIP4361" | "EIP1271" | ...
"typedData": ...
}
}
@ -59,7 +64,10 @@ def _resolve_user_address(user_address_context_variable, **context) -> ChecksumA
auth = EvmAuth.from_scheme(scheme)
auth.authenticate(
data=typed_data, signature=signature, expected_address=expected_address
data=typed_data,
signature=signature,
expected_address=expected_address,
providers=providers,
)
except EvmAuth.InvalidData as e:
raise InvalidContextVariableData(
@ -98,7 +106,11 @@ def string_contains_context_variable(variable: str) -> bool:
return bool(matches)
def get_context_value(context_variable: str, **context) -> Any:
def get_context_value(
context_variable: str,
providers: Optional[ConditionProviderManager] = None,
**context,
) -> Any:
try:
# DIRECTIVES are special context vars that will pre-processed by ursula
func = _DIRECTIVES[context_variable]
@ -111,33 +123,40 @@ def get_context_value(context_variable: str, **context) -> Any:
f'No value provided for unrecognized context variable "{context_variable}"'
)
else:
value = func(**context) # required inputs here
value = func(providers=providers, **context) # required inputs here
return value
def resolve_any_context_variables(
param: Union[Any, List[Any], Dict[Any, Any]], **context
param: Union[Any, List[Any], Dict[Any, Any]],
providers: Optional[ConditionProviderManager] = None,
**context,
):
if isinstance(param, list):
return [resolve_any_context_variables(item, **context) for item in param]
return [
resolve_any_context_variables(item, providers, **context) for item in param
]
elif isinstance(param, dict):
return {
k: resolve_any_context_variables(v, **context) for k, v in param.items()
k: resolve_any_context_variables(v, providers, **context)
for k, v in param.items()
}
elif isinstance(param, str):
# either it is a context variable OR contains a context variable within it
# TODO separating the two cases for now out of concern of regex searching
# within strings (case 2)
if is_context_variable(param):
return get_context_value(context_variable=param, **context)
return get_context_value(
context_variable=param, providers=providers, **context
)
else:
matches = re.findall(CONTEXT_REGEX, param)
for context_var in matches:
# checking out of concern for faulty regex search within string
if context_var in context:
resolved_var = get_context_value(
context_variable=context_var, **context
context_variable=context_var, providers=providers, **context
)
param = param.replace(context_var, str(resolved_var))
return param

View File

@ -106,7 +106,7 @@ class RPCCall(ExecutionCall):
resolved_parameters = []
if self.parameters:
resolved_parameters = resolve_any_context_variables(
self.parameters, **context
param=self.parameters, providers=providers, **context
)
endpoints = providers.web3_endpoints(self.chain)
@ -216,7 +216,7 @@ class RPCCondition(ExecutionCallAccessControlCondition):
self, providers: ConditionProviderManager, **context
) -> Tuple[bool, Any]:
resolved_return_value_test = self.return_value_test.with_resolved_context(
**context
providers=providers, **context
)
return_value_test = self._align_comparator_value_with_abi(
resolved_return_value_test

View File

@ -610,8 +610,10 @@ class ReturnValueTest:
result = _COMPARATOR_FUNCTIONS[self.comparator](left_operand, right_operand)
return result
def with_resolved_context(self, **context):
value = resolve_any_context_variables(self.value, **context)
def with_resolved_context(
self, providers: Optional[ConditionProviderManager] = None, **context
):
value = resolve_any_context_variables(self.value, providers, **context)
return ReturnValueTest(self.comparator, value=value, index=self.index)

View File

@ -3,14 +3,23 @@ import pytest
from siwe import SiweMessage
from nucypher.blockchain.eth.signers import InMemorySigner
from nucypher.policy.conditions.auth.evm import EIP712Auth, EIP4361Auth, EvmAuth
from nucypher.policy.conditions.auth.evm import (
EIP712Auth,
EIP1271Auth,
EIP4361Auth,
EvmAuth,
)
def test_auth_scheme():
expected_schemes = {
EvmAuth.AuthScheme.EIP712: EIP712Auth,
EvmAuth.AuthScheme.EIP4361: EIP4361Auth,
EvmAuth.AuthScheme.EIP1271: EIP1271Auth,
}
for scheme in EvmAuth.AuthScheme:
expected_scheme = (
EIP712Auth if scheme == EvmAuth.AuthScheme.EIP712 else EIP4361Auth
)
expected_scheme = expected_schemes.get(scheme)
assert EvmAuth.from_scheme(scheme=scheme.value) == expected_scheme
# non-existent scheme