mirror of https://github.com/nucypher/nucypher.git
Initial implementation of EvmAuth for EIP1271.
Update use of provider manager across EvmAuth initerface and subsequent tests.pull/3576/head
parent
dcd870d49e
commit
93ba45f48b
|
@ -1,16 +1,20 @@
|
|||
from enum import Enum
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import maya
|
||||
from eth_account.account import Account
|
||||
from eth_account.messages import HexBytes, encode_typed_data
|
||||
from eth_typing import ChecksumAddress
|
||||
from siwe import SiweMessage, VerificationError
|
||||
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
|
||||
|
||||
class EvmAuth:
|
||||
class AuthScheme(Enum):
|
||||
EIP712 = "EIP712"
|
||||
EIP4361 = "EIP4361"
|
||||
EIP1271 = "EIP1271"
|
||||
|
||||
@classmethod
|
||||
def values(cls) -> List[str]:
|
||||
|
@ -26,7 +30,13 @@ class EvmAuth:
|
|||
"""The message is too old."""
|
||||
|
||||
@classmethod
|
||||
def authenticate(cls, data, signature, expected_address):
|
||||
def authenticate(
|
||||
cls,
|
||||
data,
|
||||
signature: str,
|
||||
expected_address: str,
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
|
@ -35,13 +45,21 @@ class EvmAuth:
|
|||
return EIP712Auth
|
||||
elif scheme == cls.AuthScheme.EIP4361.value:
|
||||
return EIP4361Auth
|
||||
elif scheme == cls.AuthScheme.EIP1271.value:
|
||||
return EIP1271Auth
|
||||
|
||||
raise ValueError(f"Invalid authentication scheme: {scheme}")
|
||||
|
||||
|
||||
class EIP712Auth(EvmAuth):
|
||||
@classmethod
|
||||
def authenticate(cls, data, signature, expected_address):
|
||||
def authenticate(
|
||||
cls,
|
||||
data,
|
||||
signature: str,
|
||||
expected_address: str,
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
):
|
||||
try:
|
||||
# convert hex data for byte fields - bytes are expected by underlying library
|
||||
# 1. salt
|
||||
|
@ -72,7 +90,13 @@ class EIP4361Auth(EvmAuth):
|
|||
FRESHNESS_IN_HOURS = 2
|
||||
|
||||
@classmethod
|
||||
def authenticate(cls, data, signature, expected_address):
|
||||
def authenticate(
|
||||
cls,
|
||||
data,
|
||||
signature: str,
|
||||
expected_address: str,
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
):
|
||||
try:
|
||||
siwe_message = SiweMessage.from_message(message=data)
|
||||
except Exception as e:
|
||||
|
@ -106,3 +130,70 @@ class EIP4361Auth(EvmAuth):
|
|||
raise cls.AuthenticationFailed(
|
||||
f"Invalid EIP4361 signature; signature not valid for expected address, {expected_address}"
|
||||
)
|
||||
|
||||
|
||||
class EIP1271Auth(EvmAuth):
|
||||
EIP1271_ABI = """[
|
||||
{
|
||||
"inputs":[
|
||||
{
|
||||
"internalType":"bytes32",
|
||||
"name":"_hash",
|
||||
"type":"bytes32"
|
||||
},
|
||||
{
|
||||
"internalType":"bytes",
|
||||
"name":"_signature",
|
||||
"type":"bytes"
|
||||
}
|
||||
],
|
||||
"name":"isValidSignature",
|
||||
"outputs":[
|
||||
{
|
||||
"internalType":"bytes4",
|
||||
"name":"",
|
||||
"type":"bytes4"
|
||||
}
|
||||
],
|
||||
"stateMutability":"view",
|
||||
"type":"function"
|
||||
}
|
||||
]"""
|
||||
MAGIC_VALUE_BYTES = bytes(HexBytes("0x1626ba7e"))
|
||||
|
||||
@classmethod
|
||||
def authenticate(
|
||||
cls,
|
||||
data,
|
||||
signature: str,
|
||||
expected_address: ChecksumAddress,
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
):
|
||||
result = None
|
||||
try:
|
||||
data_hash = bytes(HexBytes(data["dataHash"]))
|
||||
chain_id = data["chain_id"]
|
||||
signature_bytes = bytes(HexBytes(signature))
|
||||
|
||||
w3_instances = providers.web3_endpoints(chain_id=chain_id)
|
||||
|
||||
for w3 in w3_instances:
|
||||
eip1271_contract = w3.eth.contract(
|
||||
address=expected_address, abi=cls.EIP1271_ABI
|
||||
)
|
||||
result = eip1271_contract.functions.isValidSignature(
|
||||
data_hash,
|
||||
signature_bytes,
|
||||
).call()
|
||||
|
||||
break
|
||||
except Exception as e:
|
||||
# data could not be processed
|
||||
raise cls.InvalidData(
|
||||
f"Invalid EIP1271 message: {str(e) or e.__class__.__name__}"
|
||||
)
|
||||
|
||||
if result != cls.MAGIC_VALUE_BYTES:
|
||||
raise cls.AuthenticationFailed(
|
||||
f"EIP1271 verification failed; signature not valid for contract address, {expected_address}"
|
||||
)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import re
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from eth_typing import ChecksumAddress
|
||||
from eth_utils import to_checksum_address
|
||||
|
@ -11,6 +11,7 @@ from nucypher.policy.conditions.exceptions import (
|
|||
InvalidContextVariableData,
|
||||
RequiredContextVariable,
|
||||
)
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
|
||||
USER_ADDRESS_CONTEXT = ":userAddress"
|
||||
USER_ADDRESS_EIP4361_EXTERNAL_CONTEXT = ":userAddressExternalEIP4361"
|
||||
|
@ -19,7 +20,7 @@ CONTEXT_PREFIX = ":"
|
|||
CONTEXT_REGEX = re.compile(":[a-zA-Z_][a-zA-Z0-9_]*")
|
||||
|
||||
USER_ADDRESS_SCHEMES = {
|
||||
USER_ADDRESS_CONTEXT: None, # allow any scheme (EIP4361, EIP712) for now; eventually EIP712 will be deprecated
|
||||
USER_ADDRESS_CONTEXT: None, # allow any scheme (EIP4361, EIP1271, EIP712) for now; eventually EIP712 will be deprecated
|
||||
USER_ADDRESS_EIP4361_EXTERNAL_CONTEXT: EvmAuth.AuthScheme.EIP4361.value,
|
||||
}
|
||||
|
||||
|
@ -28,7 +29,11 @@ class UnexpectedScheme(Exception):
|
|||
pass
|
||||
|
||||
|
||||
def _resolve_user_address(user_address_context_variable, **context) -> ChecksumAddress:
|
||||
def _resolve_user_address(
|
||||
user_address_context_variable: str,
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
**context,
|
||||
) -> ChecksumAddress:
|
||||
"""
|
||||
Recovers a checksum address from a signed message.
|
||||
|
||||
|
@ -38,7 +43,7 @@ def _resolve_user_address(user_address_context_variable, **context) -> ChecksumA
|
|||
{
|
||||
"signature": "<signature>",
|
||||
"address": "<address>",
|
||||
"scheme": "EIP4361" | ...
|
||||
"scheme": "EIP4361" | "EIP1271" | ...
|
||||
"typedData": ...
|
||||
}
|
||||
}
|
||||
|
@ -59,7 +64,10 @@ def _resolve_user_address(user_address_context_variable, **context) -> ChecksumA
|
|||
|
||||
auth = EvmAuth.from_scheme(scheme)
|
||||
auth.authenticate(
|
||||
data=typed_data, signature=signature, expected_address=expected_address
|
||||
data=typed_data,
|
||||
signature=signature,
|
||||
expected_address=expected_address,
|
||||
providers=providers,
|
||||
)
|
||||
except EvmAuth.InvalidData as e:
|
||||
raise InvalidContextVariableData(
|
||||
|
@ -98,7 +106,11 @@ def string_contains_context_variable(variable: str) -> bool:
|
|||
return bool(matches)
|
||||
|
||||
|
||||
def get_context_value(context_variable: str, **context) -> Any:
|
||||
def get_context_value(
|
||||
context_variable: str,
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
**context,
|
||||
) -> Any:
|
||||
try:
|
||||
# DIRECTIVES are special context vars that will pre-processed by ursula
|
||||
func = _DIRECTIVES[context_variable]
|
||||
|
@ -111,33 +123,40 @@ def get_context_value(context_variable: str, **context) -> Any:
|
|||
f'No value provided for unrecognized context variable "{context_variable}"'
|
||||
)
|
||||
else:
|
||||
value = func(**context) # required inputs here
|
||||
value = func(providers=providers, **context) # required inputs here
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def resolve_any_context_variables(
|
||||
param: Union[Any, List[Any], Dict[Any, Any]], **context
|
||||
param: Union[Any, List[Any], Dict[Any, Any]],
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
**context,
|
||||
):
|
||||
if isinstance(param, list):
|
||||
return [resolve_any_context_variables(item, **context) for item in param]
|
||||
return [
|
||||
resolve_any_context_variables(item, providers, **context) for item in param
|
||||
]
|
||||
elif isinstance(param, dict):
|
||||
return {
|
||||
k: resolve_any_context_variables(v, **context) for k, v in param.items()
|
||||
k: resolve_any_context_variables(v, providers, **context)
|
||||
for k, v in param.items()
|
||||
}
|
||||
elif isinstance(param, str):
|
||||
# either it is a context variable OR contains a context variable within it
|
||||
# TODO separating the two cases for now out of concern of regex searching
|
||||
# within strings (case 2)
|
||||
if is_context_variable(param):
|
||||
return get_context_value(context_variable=param, **context)
|
||||
return get_context_value(
|
||||
context_variable=param, providers=providers, **context
|
||||
)
|
||||
else:
|
||||
matches = re.findall(CONTEXT_REGEX, param)
|
||||
for context_var in matches:
|
||||
# checking out of concern for faulty regex search within string
|
||||
if context_var in context:
|
||||
resolved_var = get_context_value(
|
||||
context_variable=context_var, **context
|
||||
context_variable=context_var, providers=providers, **context
|
||||
)
|
||||
param = param.replace(context_var, str(resolved_var))
|
||||
return param
|
||||
|
|
|
@ -106,7 +106,7 @@ class RPCCall(ExecutionCall):
|
|||
resolved_parameters = []
|
||||
if self.parameters:
|
||||
resolved_parameters = resolve_any_context_variables(
|
||||
self.parameters, **context
|
||||
param=self.parameters, providers=providers, **context
|
||||
)
|
||||
|
||||
endpoints = providers.web3_endpoints(self.chain)
|
||||
|
@ -216,7 +216,7 @@ class RPCCondition(ExecutionCallAccessControlCondition):
|
|||
self, providers: ConditionProviderManager, **context
|
||||
) -> Tuple[bool, Any]:
|
||||
resolved_return_value_test = self.return_value_test.with_resolved_context(
|
||||
**context
|
||||
providers=providers, **context
|
||||
)
|
||||
return_value_test = self._align_comparator_value_with_abi(
|
||||
resolved_return_value_test
|
||||
|
|
|
@ -610,8 +610,10 @@ class ReturnValueTest:
|
|||
result = _COMPARATOR_FUNCTIONS[self.comparator](left_operand, right_operand)
|
||||
return result
|
||||
|
||||
def with_resolved_context(self, **context):
|
||||
value = resolve_any_context_variables(self.value, **context)
|
||||
def with_resolved_context(
|
||||
self, providers: Optional[ConditionProviderManager] = None, **context
|
||||
):
|
||||
value = resolve_any_context_variables(self.value, providers, **context)
|
||||
return ReturnValueTest(self.comparator, value=value, index=self.index)
|
||||
|
||||
|
||||
|
|
|
@ -3,14 +3,23 @@ import pytest
|
|||
from siwe import SiweMessage
|
||||
|
||||
from nucypher.blockchain.eth.signers import InMemorySigner
|
||||
from nucypher.policy.conditions.auth.evm import EIP712Auth, EIP4361Auth, EvmAuth
|
||||
from nucypher.policy.conditions.auth.evm import (
|
||||
EIP712Auth,
|
||||
EIP1271Auth,
|
||||
EIP4361Auth,
|
||||
EvmAuth,
|
||||
)
|
||||
|
||||
|
||||
def test_auth_scheme():
|
||||
expected_schemes = {
|
||||
EvmAuth.AuthScheme.EIP712: EIP712Auth,
|
||||
EvmAuth.AuthScheme.EIP4361: EIP4361Auth,
|
||||
EvmAuth.AuthScheme.EIP1271: EIP1271Auth,
|
||||
}
|
||||
|
||||
for scheme in EvmAuth.AuthScheme:
|
||||
expected_scheme = (
|
||||
EIP712Auth if scheme == EvmAuth.AuthScheme.EIP712 else EIP4361Auth
|
||||
)
|
||||
expected_scheme = expected_schemes.get(scheme)
|
||||
assert EvmAuth.from_scheme(scheme=scheme.value) == expected_scheme
|
||||
|
||||
# non-existent scheme
|
||||
|
|
Loading…
Reference in New Issue