mirror of https://github.com/nucypher/nucypher.git
commit
07e13b9930
|
@ -0,0 +1 @@
|
|||
Add support for EIP1271 signature verification for smart contract wallets.
|
|
@ -4,7 +4,7 @@ import time
|
|||
import traceback
|
||||
from collections import defaultdict
|
||||
from decimal import Decimal
|
||||
from typing import DefaultDict, Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import maya
|
||||
from atxm.exceptions import InsufficientFunds
|
||||
|
@ -65,7 +65,10 @@ from nucypher.crypto.powers import (
|
|||
TransactingPower,
|
||||
)
|
||||
from nucypher.datastore.dkg import DKGStorage
|
||||
from nucypher.policy.conditions.utils import evaluate_condition_lingo
|
||||
from nucypher.policy.conditions.utils import (
|
||||
ConditionProviderManager,
|
||||
evaluate_condition_lingo,
|
||||
)
|
||||
from nucypher.policy.payment import ContractPayment
|
||||
from nucypher.types import PhaseId
|
||||
from nucypher.utilities.emitters import StdoutEmitter
|
||||
|
@ -247,7 +250,7 @@ class Operator(BaseActor):
|
|||
ThresholdRequestDecryptingPower
|
||||
) # used for secure decryption request channel
|
||||
|
||||
self.condition_providers = self.connect_condition_providers(
|
||||
self.condition_provider_manager = self.get_condition_provider_manager(
|
||||
condition_blockchain_endpoints
|
||||
)
|
||||
|
||||
|
@ -269,9 +272,9 @@ class Operator(BaseActor):
|
|||
provider = HTTPProvider(endpoint_uri=uri)
|
||||
return provider
|
||||
|
||||
def connect_condition_providers(
|
||||
def get_condition_provider_manager(
|
||||
self, operator_configured_endpoints: Dict[int, List[str]]
|
||||
) -> DefaultDict[int, List[HTTPProvider]]:
|
||||
) -> ConditionProviderManager:
|
||||
|
||||
# check that we have mandatory user configured endpoints
|
||||
mandatory_configured_chains = {
|
||||
|
@ -336,7 +339,7 @@ class Operator(BaseActor):
|
|||
f"checking on chain IDs {providers.keys()}"
|
||||
)
|
||||
|
||||
return providers
|
||||
return ConditionProviderManager(providers=providers)
|
||||
|
||||
def _resolve_ritual(self, ritual_id: int) -> Coordinator.Ritual:
|
||||
if not self.coordinator_agent.is_ritual_active(ritual_id=ritual_id):
|
||||
|
@ -845,7 +848,7 @@ class Operator(BaseActor):
|
|||
evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo,
|
||||
context=context,
|
||||
providers=self.condition_providers,
|
||||
providers=self.condition_provider_manager,
|
||||
)
|
||||
|
||||
def _verify_decryption_request_authorization(
|
||||
|
|
|
@ -154,8 +154,9 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
|
|||
"""
|
||||
# TODO: When non-evm chains are supported, bump the version.
|
||||
# this can return a list of chain names or other verifiable identifiers.
|
||||
|
||||
payload = {"version": 1.0, "evm": list(this_node.condition_providers)}
|
||||
providers = this_node.condition_provider_manager.providers
|
||||
sorted_chain_ids = sorted(list(providers))
|
||||
payload = {"version": 1.0, "evm": sorted_chain_ids}
|
||||
return Response(json.dumps(payload), mimetype="application/json")
|
||||
|
||||
@rest_app.route('/decrypt', methods=["POST"])
|
||||
|
@ -260,7 +261,7 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
|
|||
try:
|
||||
evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo,
|
||||
providers=this_node.condition_providers,
|
||||
providers=this_node.condition_provider_manager,
|
||||
context=context,
|
||||
)
|
||||
except ConditionEvalError as error:
|
||||
|
|
|
@ -1,16 +1,22 @@
|
|||
from enum import Enum
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import maya
|
||||
from eth_account.account import Account
|
||||
from eth_account.messages import HexBytes, encode_typed_data
|
||||
from eth_typing import ChecksumAddress
|
||||
from siwe import SiweMessage, VerificationError
|
||||
|
||||
from nucypher.policy.conditions.exceptions import NoConnectionToChain
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
from nucypher.utilities.logging import Logger
|
||||
|
||||
|
||||
class EvmAuth:
|
||||
class AuthScheme(Enum):
|
||||
EIP712 = "EIP712"
|
||||
EIP4361 = "EIP4361"
|
||||
EIP1271 = "EIP1271"
|
||||
|
||||
@classmethod
|
||||
def values(cls) -> List[str]:
|
||||
|
@ -26,7 +32,13 @@ class EvmAuth:
|
|||
"""The message is too old."""
|
||||
|
||||
@classmethod
|
||||
def authenticate(cls, data, signature, expected_address):
|
||||
def authenticate(
|
||||
cls,
|
||||
data,
|
||||
signature: str,
|
||||
expected_address: str,
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
|
@ -35,13 +47,21 @@ class EvmAuth:
|
|||
return EIP712Auth
|
||||
elif scheme == cls.AuthScheme.EIP4361.value:
|
||||
return EIP4361Auth
|
||||
elif scheme == cls.AuthScheme.EIP1271.value:
|
||||
return EIP1271Auth
|
||||
|
||||
raise ValueError(f"Invalid authentication scheme: {scheme}")
|
||||
|
||||
|
||||
class EIP712Auth(EvmAuth):
|
||||
@classmethod
|
||||
def authenticate(cls, data, signature, expected_address):
|
||||
def authenticate(
|
||||
cls,
|
||||
data,
|
||||
signature: str,
|
||||
expected_address: str,
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
):
|
||||
try:
|
||||
# convert hex data for byte fields - bytes are expected by underlying library
|
||||
# 1. salt
|
||||
|
@ -72,7 +92,13 @@ class EIP4361Auth(EvmAuth):
|
|||
FRESHNESS_IN_HOURS = 2
|
||||
|
||||
@classmethod
|
||||
def authenticate(cls, data, signature, expected_address):
|
||||
def authenticate(
|
||||
cls,
|
||||
data,
|
||||
signature: str,
|
||||
expected_address: str,
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
):
|
||||
try:
|
||||
siwe_message = SiweMessage.from_message(message=data)
|
||||
except Exception as e:
|
||||
|
@ -106,3 +132,113 @@ class EIP4361Auth(EvmAuth):
|
|||
raise cls.AuthenticationFailed(
|
||||
f"Invalid EIP4361 signature; signature not valid for expected address, {expected_address}"
|
||||
)
|
||||
|
||||
|
||||
class EIP1271Auth(EvmAuth):
|
||||
EIP1271_ABI = """[
|
||||
{
|
||||
"inputs":[
|
||||
{
|
||||
"internalType":"bytes32",
|
||||
"name":"_hash",
|
||||
"type":"bytes32"
|
||||
},
|
||||
{
|
||||
"internalType":"bytes",
|
||||
"name":"_signature",
|
||||
"type":"bytes"
|
||||
}
|
||||
],
|
||||
"name":"isValidSignature",
|
||||
"outputs":[
|
||||
{
|
||||
"internalType":"bytes4",
|
||||
"name":"",
|
||||
"type":"bytes4"
|
||||
}
|
||||
],
|
||||
"stateMutability":"view",
|
||||
"type":"function"
|
||||
}
|
||||
]"""
|
||||
MAGIC_VALUE_BYTES = bytes(HexBytes("0x1626ba7e"))
|
||||
LOG = Logger("EIP1271Auth")
|
||||
|
||||
@classmethod
|
||||
def _extract_typed_data(cls, data):
|
||||
try:
|
||||
data_hash = bytes(HexBytes(data["dataHash"]))
|
||||
chain = data["chain"]
|
||||
return data_hash, chain
|
||||
except Exception as e:
|
||||
# data could not be processed
|
||||
raise cls.InvalidData(
|
||||
f"Invalid EIP1271 authentication data: {str(e) or e.__class__.__name__}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _validate_auth_data(
|
||||
cls, data_hash, signature_bytes, expected_address, chain, providers
|
||||
):
|
||||
web3_endpoints = providers.web3_endpoints(chain_id=chain)
|
||||
last_error = None
|
||||
for web3_instance in web3_endpoints:
|
||||
try:
|
||||
# Interact with the EIP1271 contract
|
||||
eip1271_contract = web3_instance.eth.contract(
|
||||
address=expected_address, abi=cls.EIP1271_ABI
|
||||
)
|
||||
result = eip1271_contract.functions.isValidSignature(
|
||||
data_hash,
|
||||
signature_bytes,
|
||||
).call()
|
||||
if result == cls.MAGIC_VALUE_BYTES:
|
||||
return # Successful authentication
|
||||
|
||||
break
|
||||
except Exception as e:
|
||||
last_error = f"EIP1271 contract call failed ({expected_address}): {e}"
|
||||
cls.LOG.warn(f"{last_error}; attempting next provider")
|
||||
else:
|
||||
# If all providers fail
|
||||
if last_error:
|
||||
raise cls.AuthenticationFailed(
|
||||
f"EIP1271 verification failed; {last_error}"
|
||||
)
|
||||
|
||||
raise cls.AuthenticationFailed(
|
||||
f"EIP1271 verification failed; signature not valid for contract address, {expected_address}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def authenticate(
|
||||
cls,
|
||||
data,
|
||||
signature: str,
|
||||
expected_address: ChecksumAddress,
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
):
|
||||
if not providers:
|
||||
# should never happen
|
||||
raise cls.AuthenticationFailed(
|
||||
"EIP1271 verification failed; no endpoints provided"
|
||||
)
|
||||
|
||||
# Extract and validate input data
|
||||
data_hash, chain = cls._extract_typed_data(data)
|
||||
|
||||
# Validate the signature
|
||||
signature_bytes = bytes(HexBytes(signature))
|
||||
try:
|
||||
cls._validate_auth_data(
|
||||
data_hash, signature_bytes, expected_address, chain, providers
|
||||
)
|
||||
except NoConnectionToChain:
|
||||
raise cls.AuthenticationFailed(
|
||||
f"EIP1271 verification failed; No connection to chain ID {chain}"
|
||||
)
|
||||
except cls.AuthenticationFailed:
|
||||
raise
|
||||
except Exception as e:
|
||||
# catch all
|
||||
raise cls.AuthenticationFailed(f"EIP1271 verification failed; {e}")
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import re
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from eth_typing import ChecksumAddress
|
||||
from eth_utils import to_checksum_address
|
||||
|
@ -11,6 +11,7 @@ from nucypher.policy.conditions.exceptions import (
|
|||
InvalidContextVariableData,
|
||||
RequiredContextVariable,
|
||||
)
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
|
||||
USER_ADDRESS_CONTEXT = ":userAddress"
|
||||
USER_ADDRESS_EIP4361_EXTERNAL_CONTEXT = ":userAddressExternalEIP4361"
|
||||
|
@ -19,7 +20,7 @@ CONTEXT_PREFIX = ":"
|
|||
CONTEXT_REGEX = re.compile(":[a-zA-Z_][a-zA-Z0-9_]*")
|
||||
|
||||
USER_ADDRESS_SCHEMES = {
|
||||
USER_ADDRESS_CONTEXT: None, # allow any scheme (EIP4361, EIP712) for now; eventually EIP712 will be deprecated
|
||||
USER_ADDRESS_CONTEXT: None, # allow any scheme (EIP4361, EIP1271, EIP712) for now; eventually EIP712 will be deprecated
|
||||
USER_ADDRESS_EIP4361_EXTERNAL_CONTEXT: EvmAuth.AuthScheme.EIP4361.value,
|
||||
}
|
||||
|
||||
|
@ -28,7 +29,11 @@ class UnexpectedScheme(Exception):
|
|||
pass
|
||||
|
||||
|
||||
def _resolve_user_address(user_address_context_variable, **context) -> ChecksumAddress:
|
||||
def _resolve_user_address(
|
||||
user_address_context_variable: str,
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
**context,
|
||||
) -> ChecksumAddress:
|
||||
"""
|
||||
Recovers a checksum address from a signed message.
|
||||
|
||||
|
@ -38,7 +43,7 @@ def _resolve_user_address(user_address_context_variable, **context) -> ChecksumA
|
|||
{
|
||||
"signature": "<signature>",
|
||||
"address": "<address>",
|
||||
"scheme": "EIP4361" | ...
|
||||
"scheme": "EIP4361" | "EIP1271" | ...
|
||||
"typedData": ...
|
||||
}
|
||||
}
|
||||
|
@ -59,7 +64,10 @@ def _resolve_user_address(user_address_context_variable, **context) -> ChecksumA
|
|||
|
||||
auth = EvmAuth.from_scheme(scheme)
|
||||
auth.authenticate(
|
||||
data=typed_data, signature=signature, expected_address=expected_address
|
||||
data=typed_data,
|
||||
signature=signature,
|
||||
expected_address=expected_address,
|
||||
providers=providers,
|
||||
)
|
||||
except EvmAuth.InvalidData as e:
|
||||
raise InvalidContextVariableData(
|
||||
|
@ -98,7 +106,11 @@ def string_contains_context_variable(variable: str) -> bool:
|
|||
return bool(matches)
|
||||
|
||||
|
||||
def get_context_value(context_variable: str, **context) -> Any:
|
||||
def get_context_value(
|
||||
context_variable: str,
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
**context,
|
||||
) -> Any:
|
||||
try:
|
||||
# DIRECTIVES are special context vars that will pre-processed by ursula
|
||||
func = _DIRECTIVES[context_variable]
|
||||
|
@ -111,33 +123,40 @@ def get_context_value(context_variable: str, **context) -> Any:
|
|||
f'No value provided for unrecognized context variable "{context_variable}"'
|
||||
)
|
||||
else:
|
||||
value = func(**context) # required inputs here
|
||||
value = func(providers=providers, **context) # required inputs here
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def resolve_any_context_variables(
|
||||
param: Union[Any, List[Any], Dict[Any, Any]], **context
|
||||
param: Union[Any, List[Any], Dict[Any, Any]],
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
**context,
|
||||
):
|
||||
if isinstance(param, list):
|
||||
return [resolve_any_context_variables(item, **context) for item in param]
|
||||
return [
|
||||
resolve_any_context_variables(item, providers, **context) for item in param
|
||||
]
|
||||
elif isinstance(param, dict):
|
||||
return {
|
||||
k: resolve_any_context_variables(v, **context) for k, v in param.items()
|
||||
k: resolve_any_context_variables(v, providers, **context)
|
||||
for k, v in param.items()
|
||||
}
|
||||
elif isinstance(param, str):
|
||||
# either it is a context variable OR contains a context variable within it
|
||||
# TODO separating the two cases for now out of concern of regex searching
|
||||
# within strings (case 2)
|
||||
if is_context_variable(param):
|
||||
return get_context_value(context_variable=param, **context)
|
||||
return get_context_value(
|
||||
context_variable=param, providers=providers, **context
|
||||
)
|
||||
else:
|
||||
matches = re.findall(CONTEXT_REGEX, param)
|
||||
for context_var in matches:
|
||||
# checking out of concern for faulty regex search within string
|
||||
if context_var in context:
|
||||
resolved_var = get_context_value(
|
||||
context_variable=context_var, **context
|
||||
context_variable=context_var, providers=providers, **context
|
||||
)
|
||||
param = param.replace(context_var, str(resolved_var))
|
||||
return param
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
|
@ -20,9 +17,7 @@ from marshmallow import (
|
|||
)
|
||||
from marshmallow.validate import OneOf
|
||||
from typing_extensions import override
|
||||
from web3 import HTTPProvider, Web3
|
||||
from web3.middleware import geth_poa_middleware
|
||||
from web3.providers import BaseProvider
|
||||
from web3 import Web3
|
||||
from web3.types import ABIFunction
|
||||
|
||||
from nucypher.policy.conditions import STANDARD_ABI_CONTRACT_TYPES
|
||||
|
@ -34,7 +29,6 @@ from nucypher.policy.conditions.context import (
|
|||
resolve_any_context_variables,
|
||||
)
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
NoConnectionToChain,
|
||||
RequiredContextVariable,
|
||||
RPCExecutionFailed,
|
||||
)
|
||||
|
@ -43,7 +37,10 @@ from nucypher.policy.conditions.lingo import (
|
|||
ExecutionCallAccessControlCondition,
|
||||
ReturnValueTest,
|
||||
)
|
||||
from nucypher.policy.conditions.utils import camel_case_to_snake
|
||||
from nucypher.policy.conditions.utils import (
|
||||
ConditionProviderManager,
|
||||
camel_case_to_snake,
|
||||
)
|
||||
from nucypher.policy.conditions.validation import (
|
||||
align_comparator_value_with_abi,
|
||||
get_unbound_contract_function,
|
||||
|
@ -105,56 +102,17 @@ class RPCCall(ExecutionCall):
|
|||
) # bind contract function (only exposes the eth API)
|
||||
return rpc_function
|
||||
|
||||
def _configure_w3(self, provider: BaseProvider) -> Web3:
|
||||
# Instantiate a local web3 instance
|
||||
w3 = Web3(provider)
|
||||
# inject web3 middleware to handle POA chain extra_data field.
|
||||
w3.middleware_onion.inject(geth_poa_middleware, layer=0, name="poa")
|
||||
return w3
|
||||
|
||||
def _check_chain_id(self, w3: Web3) -> None:
|
||||
"""
|
||||
Validates that the actual web3 provider is *actually*
|
||||
connected to the condition's chain ID by reading its RPC endpoint.
|
||||
"""
|
||||
provider_chain = w3.eth.chain_id
|
||||
if provider_chain != self.chain:
|
||||
raise NoConnectionToChain(
|
||||
chain=self.chain,
|
||||
message=f"This rpc call can only be evaluated on chain ID {self.chain} but the provider's "
|
||||
f"connection is to chain ID {provider_chain}",
|
||||
)
|
||||
|
||||
def _configure_provider(self, provider: BaseProvider):
|
||||
"""Binds the condition's contract function to a blockchain provider for evaluation"""
|
||||
w3 = self._configure_w3(provider=provider)
|
||||
self._check_chain_id(w3)
|
||||
return w3
|
||||
|
||||
def _next_endpoint(
|
||||
self, providers: Dict[int, Set[HTTPProvider]]
|
||||
) -> Iterator[HTTPProvider]:
|
||||
"""Yields the next web3 provider to try for a given chain ID"""
|
||||
rpc_providers = providers.get(self.chain, None)
|
||||
if not rpc_providers:
|
||||
raise NoConnectionToChain(chain=self.chain)
|
||||
|
||||
for provider in rpc_providers:
|
||||
# Someday, we might make this whole function async, and then we can knock on
|
||||
# each endpoint here to see if it's alive and only yield it if it is.
|
||||
yield provider
|
||||
|
||||
def execute(self, providers: Dict[int, Set[HTTPProvider]], **context) -> Any:
|
||||
def execute(self, providers: ConditionProviderManager, **context) -> Any:
|
||||
resolved_parameters = []
|
||||
if self.parameters:
|
||||
resolved_parameters = resolve_any_context_variables(
|
||||
self.parameters, **context
|
||||
param=self.parameters, providers=providers, **context
|
||||
)
|
||||
|
||||
endpoints = self._next_endpoint(providers=providers)
|
||||
endpoints = providers.web3_endpoints(self.chain)
|
||||
|
||||
latest_error = ""
|
||||
for provider in endpoints:
|
||||
w3 = self._configure_provider(provider)
|
||||
for w3 in endpoints:
|
||||
try:
|
||||
result = self._execute(w3, resolved_parameters)
|
||||
break
|
||||
|
@ -255,10 +213,10 @@ class RPCCondition(ExecutionCallAccessControlCondition):
|
|||
return return_value_test
|
||||
|
||||
def verify(
|
||||
self, providers: Dict[int, Set[HTTPProvider]], **context
|
||||
self, providers: ConditionProviderManager, **context
|
||||
) -> Tuple[bool, Any]:
|
||||
resolved_return_value_test = self.return_value_test.with_resolved_context(
|
||||
**context
|
||||
providers=providers, **context
|
||||
)
|
||||
return_value_test = self._align_comparator_value_with_abi(
|
||||
resolved_return_value_test
|
||||
|
|
|
@ -13,6 +13,19 @@ class NoConnectionToChain(RuntimeError):
|
|||
super().__init__(message)
|
||||
|
||||
|
||||
class InvalidConnectionToChain(RuntimeError):
|
||||
"""Raised when a node does not have a valid provider for a chain."""
|
||||
|
||||
def __init__(self, expected_chain: int, actual_chain: int, message: str = None):
|
||||
self.expected_chain = expected_chain
|
||||
self.actual_chain = actual_chain
|
||||
message = (
|
||||
message
|
||||
or f"Invalid blockchain connection; expected chain ID {expected_chain}, but detected {actual_chain}"
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ReturnValueEvaluationError(Exception):
|
||||
"""Issue with Return Value and Key"""
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import json
|
|||
import operator as pyoperator
|
||||
from enum import Enum
|
||||
from hashlib import md5
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, List, Optional, Tuple, Type, Union
|
||||
|
||||
from hexbytes import HexBytes
|
||||
from marshmallow import (
|
||||
|
@ -19,7 +19,6 @@ from marshmallow import (
|
|||
)
|
||||
from marshmallow.validate import OneOf, Range
|
||||
from packaging.version import parse as parse_version
|
||||
from web3 import HTTPProvider
|
||||
|
||||
from nucypher.policy.conditions.base import (
|
||||
AccessControlCondition,
|
||||
|
@ -37,7 +36,7 @@ from nucypher.policy.conditions.exceptions import (
|
|||
ReturnValueEvaluationError,
|
||||
)
|
||||
from nucypher.policy.conditions.types import ConditionDict, Lingo
|
||||
from nucypher.policy.conditions.utils import CamelCaseSchema
|
||||
from nucypher.policy.conditions.utils import CamelCaseSchema, ConditionProviderManager
|
||||
|
||||
|
||||
class _ConditionField(fields.Dict):
|
||||
|
@ -339,7 +338,7 @@ class SequentialAccessControlCondition(MultiConditionAccessControl):
|
|||
# TODO - think about not dereferencing context but using a dict;
|
||||
# may allows more freedom for params
|
||||
def verify(
|
||||
self, providers: Dict[int, Set[HTTPProvider]], **context
|
||||
self, providers: ConditionProviderManager, **context
|
||||
) -> Tuple[bool, Any]:
|
||||
values = []
|
||||
latest_success = False
|
||||
|
@ -611,8 +610,10 @@ class ReturnValueTest:
|
|||
result = _COMPARATOR_FUNCTIONS[self.comparator](left_operand, right_operand)
|
||||
return result
|
||||
|
||||
def with_resolved_context(self, **context):
|
||||
value = resolve_any_context_variables(self.value, **context)
|
||||
def with_resolved_context(
|
||||
self, providers: Optional[ConditionProviderManager] = None, **context
|
||||
):
|
||||
value = resolve_any_context_variables(self.value, providers, **context)
|
||||
return ReturnValueTest(self.comparator, value=value, index=self.index)
|
||||
|
||||
|
||||
|
|
|
@ -74,6 +74,7 @@ class JsonRpcConditionDict(BaseExecConditionDict):
|
|||
query: NotRequired[str]
|
||||
authorizationToken: NotRequired[str]
|
||||
|
||||
|
||||
#
|
||||
# CompoundCondition represents:
|
||||
# {
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
import re
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from typing import Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
from marshmallow import Schema, post_dump
|
||||
from marshmallow.exceptions import SCHEMA
|
||||
from web3 import HTTPProvider, Web3
|
||||
from web3.middleware import geth_poa_middleware
|
||||
from web3.providers import BaseProvider
|
||||
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
|
@ -11,6 +13,7 @@ from nucypher.policy.conditions.exceptions import (
|
|||
ContextVariableVerificationFailed,
|
||||
InvalidCondition,
|
||||
InvalidConditionLingo,
|
||||
InvalidConnectionToChain,
|
||||
InvalidContextVariableData,
|
||||
NoConnectionToChain,
|
||||
RequiredContextVariable,
|
||||
|
@ -22,6 +25,57 @@ from nucypher.utilities.logging import Logger
|
|||
__LOGGER = Logger("condition-eval")
|
||||
|
||||
|
||||
class ConditionProviderManager:
|
||||
def __init__(self, providers: Dict[int, List[HTTPProvider]]):
|
||||
self.providers = providers
|
||||
self.logger = Logger(__name__)
|
||||
|
||||
def web3_endpoints(self, chain_id: int) -> Iterator[Web3]:
|
||||
rpc_providers = self.providers.get(chain_id, None)
|
||||
if not rpc_providers:
|
||||
raise NoConnectionToChain(chain=chain_id)
|
||||
|
||||
iterator_returned_at_least_one = False
|
||||
for provider in rpc_providers:
|
||||
try:
|
||||
w3 = self._configure_w3(provider=provider)
|
||||
self._check_chain_id(chain_id, w3)
|
||||
yield w3
|
||||
iterator_returned_at_least_one = True
|
||||
except InvalidConnectionToChain as e:
|
||||
# don't expect to happen but must account
|
||||
# for any misconfigurations of public endpoints
|
||||
self.logger.warn(str(e))
|
||||
|
||||
# if we get here, it is because there were endpoints, but issue with configuring them
|
||||
if not iterator_returned_at_least_one:
|
||||
raise NoConnectionToChain(
|
||||
chain=chain_id,
|
||||
message=f"Problematic provider endpoints for chain ID {chain_id}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_w3(provider: BaseProvider) -> Web3:
|
||||
# Instantiate a local web3 instance
|
||||
w3 = Web3(provider)
|
||||
# inject web3 middleware to handle POA chain extra_data field.
|
||||
w3.middleware_onion.inject(geth_poa_middleware, layer=0, name="poa")
|
||||
return w3
|
||||
|
||||
@staticmethod
|
||||
def _check_chain_id(chain_id: int, w3: Web3) -> None:
|
||||
"""
|
||||
Validates that the actual web3 provider is *actually*
|
||||
connected to the condition's chain ID by reading its RPC endpoint.
|
||||
"""
|
||||
provider_chain = w3.eth.chain_id
|
||||
if provider_chain != chain_id:
|
||||
raise InvalidConnectionToChain(
|
||||
expected_chain=chain_id,
|
||||
actual_chain=provider_chain,
|
||||
)
|
||||
|
||||
|
||||
class ConditionEvalError(Exception):
|
||||
"""Exception when execution condition evaluation."""
|
||||
def __init__(self, message: str, status_code: int):
|
||||
|
@ -58,7 +112,7 @@ class CamelCaseSchema(Schema):
|
|||
|
||||
def evaluate_condition_lingo(
|
||||
condition_lingo: Lingo,
|
||||
providers: Optional[Dict[int, Set[BaseProvider]]] = None,
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
context: Optional[ContextDict] = None,
|
||||
log: Logger = __LOGGER,
|
||||
):
|
||||
|
@ -74,7 +128,7 @@ def evaluate_condition_lingo(
|
|||
|
||||
# Setup (don't use mutable defaults)
|
||||
context = context or dict()
|
||||
providers = providers or dict()
|
||||
providers = providers or ConditionProviderManager(providers=dict())
|
||||
error = None
|
||||
|
||||
# Evaluate
|
||||
|
@ -142,7 +196,7 @@ def evaluate_condition_lingo(
|
|||
|
||||
|
||||
def extract_single_error_message_from_schema_errors(
|
||||
errors: Dict[str, List[str]]
|
||||
errors: Dict[str, List[str]],
|
||||
) -> str:
|
||||
"""
|
||||
Extract single error message from Schema().validate() errors result.
|
||||
|
|
|
@ -11,12 +11,15 @@ from nucypher.policy.conditions.lingo import (
|
|||
OrCompoundCondition,
|
||||
ReturnValueTest,
|
||||
)
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
from tests.constants import TEST_ETH_PROVIDER_URI, TESTERCHAIN_CHAIN_ID
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def condition_providers(testerchain):
|
||||
providers = {testerchain.client.chain_id: {testerchain.provider}}
|
||||
providers = ConditionProviderManager(
|
||||
{testerchain.client.chain_id: {testerchain.provider}}
|
||||
)
|
||||
return providers
|
||||
|
||||
@pytest.fixture()
|
||||
|
@ -54,14 +57,12 @@ def erc20_evm_condition_balanceof(testerchain, test_registry, ritual_token):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def erc721_contract(accounts, project):
|
||||
account = accounts[0]
|
||||
|
||||
def erc721_contract(project, deployer_account):
|
||||
# deploy contract
|
||||
deployed_contract = project.ConditionNFT.deploy(sender=account)
|
||||
deployed_contract = project.ConditionNFT.deploy(sender=deployer_account)
|
||||
|
||||
# mint nft with token id = 1
|
||||
deployed_contract.mint(account.address, 1, sender=account)
|
||||
deployed_contract.mint(deployer_account.address, 1, sender=deployer_account)
|
||||
return deployed_contract
|
||||
|
||||
|
||||
|
@ -151,3 +152,11 @@ def custom_context_variable_erc20_condition(
|
|||
parameters=[":addressToUse"],
|
||||
)
|
||||
return condition
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def eip1271_contract_wallet(project, deployer_account):
|
||||
_eip1271_contract_wallet = deployer_account.deploy(
|
||||
project.SmartContractWallet, deployer_account.address
|
||||
)
|
||||
return _eip1271_contract_wallet
|
||||
|
|
|
@ -3,6 +3,7 @@ import os
|
|||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from eth_account.messages import defunct_hash_message, encode_defunct
|
||||
from hexbytes import HexBytes
|
||||
from web3 import Web3
|
||||
from web3.providers import BaseProvider
|
||||
|
@ -13,6 +14,7 @@ from nucypher.blockchain.eth.agents import (
|
|||
SubscriptionManagerAgent,
|
||||
)
|
||||
from nucypher.blockchain.eth.constants import NULL_ADDRESS
|
||||
from nucypher.policy.conditions.auth.evm import EvmAuth
|
||||
from nucypher.policy.conditions.context import (
|
||||
USER_ADDRESS_CONTEXT,
|
||||
get_context_value,
|
||||
|
@ -33,6 +35,7 @@ from nucypher.policy.conditions.lingo import (
|
|||
NotCompoundCondition,
|
||||
ReturnValueTest,
|
||||
)
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
from tests.constants import (
|
||||
TEST_ETH_PROVIDER_URI,
|
||||
TEST_POLYGON_PROVIDER_URI,
|
||||
|
@ -67,11 +70,12 @@ def test_rpc_condition_evaluation_no_providers(
|
|||
):
|
||||
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
|
||||
with pytest.raises(NoConnectionToChain):
|
||||
_ = rpc_condition.verify(providers={}, **context)
|
||||
_ = rpc_condition.verify(providers=ConditionProviderManager({}), **context)
|
||||
|
||||
with pytest.raises(NoConnectionToChain):
|
||||
_ = rpc_condition.verify(
|
||||
providers={testerchain.client.chain_id: set()}, **context
|
||||
providers=ConditionProviderManager({testerchain.client.chain_id: list()}),
|
||||
**context,
|
||||
)
|
||||
|
||||
|
||||
|
@ -85,9 +89,10 @@ def test_rpc_condition_evaluation_invalid_provider_for_chain(
|
|||
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
|
||||
new_chain = 23
|
||||
rpc_condition.execution_call.chain = new_chain
|
||||
condition_providers = {new_chain: {testerchain.provider}}
|
||||
condition_providers = ConditionProviderManager({new_chain: [testerchain.provider]})
|
||||
with pytest.raises(
|
||||
NoConnectionToChain, match=f"can only be evaluated on chain ID {new_chain}"
|
||||
NoConnectionToChain,
|
||||
match=f"Problematic provider endpoints for chain ID {new_chain}",
|
||||
):
|
||||
_ = rpc_condition.verify(providers=condition_providers, **context)
|
||||
|
||||
|
@ -118,13 +123,15 @@ def test_rpc_condition_evaluation_multiple_chain_providers(
|
|||
):
|
||||
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
|
||||
|
||||
condition_providers = {
|
||||
"1": {"fake1a", "fake1b"},
|
||||
"2": {"fake2"},
|
||||
"3": {"fake3"},
|
||||
"4": {"fake4"},
|
||||
TESTERCHAIN_CHAIN_ID: {testerchain.provider},
|
||||
}
|
||||
condition_providers = ConditionProviderManager(
|
||||
{
|
||||
"1": ["fake1a", "fake1b"],
|
||||
"2": ["fake2"],
|
||||
"3": ["fake3"],
|
||||
"4": ["fake4"],
|
||||
TESTERCHAIN_CHAIN_ID: [testerchain.provider],
|
||||
}
|
||||
)
|
||||
|
||||
condition_result, call_result = rpc_condition.verify(
|
||||
providers=condition_providers, **context
|
||||
|
@ -144,20 +151,17 @@ def test_rpc_condition_evaluation_multiple_providers_no_valid_fallback(
|
|||
):
|
||||
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
|
||||
|
||||
def my_configure_w3(provider: BaseProvider):
|
||||
return Web3(provider)
|
||||
|
||||
condition_providers = {
|
||||
TESTERCHAIN_CHAIN_ID: {
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
condition_providers = ConditionProviderManager(
|
||||
{
|
||||
TESTERCHAIN_CHAIN_ID: [
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
mocker.patch.object(
|
||||
rpc_condition.execution_call, "_configure_provider", my_configure_w3
|
||||
)
|
||||
|
||||
mocker.patch.object(condition_providers, "_check_chain_id", return_value=None)
|
||||
with pytest.raises(RPCExecutionFailed):
|
||||
_ = rpc_condition.verify(providers=condition_providers, **context)
|
||||
|
||||
|
@ -171,22 +175,19 @@ def test_rpc_condition_evaluation_multiple_providers_valid_fallback(
|
|||
):
|
||||
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
|
||||
|
||||
def my_configure_w3(provider: BaseProvider):
|
||||
return Web3(provider)
|
||||
|
||||
condition_providers = {
|
||||
TESTERCHAIN_CHAIN_ID: {
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
testerchain.provider,
|
||||
condition_providers = ConditionProviderManager(
|
||||
{
|
||||
TESTERCHAIN_CHAIN_ID: [
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
testerchain.provider,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
mocker.patch.object(
|
||||
rpc_condition.execution_call, "_configure_provider", my_configure_w3
|
||||
)
|
||||
|
||||
mocker.patch.object(condition_providers, "_check_chain_id", return_value=None)
|
||||
|
||||
condition_result, call_result = rpc_condition.verify(
|
||||
providers=condition_providers, **context
|
||||
)
|
||||
|
@ -208,10 +209,12 @@ def test_rpc_condition_evaluation_no_connection_to_chain(
|
|||
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
|
||||
|
||||
# condition providers for other unrelated chains
|
||||
providers = {
|
||||
1: mock.Mock(), # mainnet
|
||||
11155111: mock.Mock(), # Sepolia
|
||||
}
|
||||
providers = ConditionProviderManager(
|
||||
{
|
||||
1: [mock.Mock()], # mainnet
|
||||
11155111: [mock.Mock()], # Sepolia
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(NoConnectionToChain):
|
||||
rpc_condition.verify(providers=providers, **context)
|
||||
|
@ -250,7 +253,10 @@ def test_rpc_condition_evaluation_with_context_var_in_return_value_test(
|
|||
invalid_balance = balance + 1
|
||||
context[":balanceContextVar"] = invalid_balance
|
||||
condition_result, call_result = rpc_condition.verify(
|
||||
providers={testerchain.client.chain_id: [testerchain.provider]}, **context
|
||||
providers=ConditionProviderManager(
|
||||
{testerchain.client.chain_id: [testerchain.provider]}
|
||||
),
|
||||
**context,
|
||||
)
|
||||
assert condition_result is False
|
||||
assert call_result != invalid_balance
|
||||
|
@ -926,3 +932,56 @@ def test_json_rpc_condition_non_evm_prototyping_example():
|
|||
)
|
||||
success, _ = condition.verify()
|
||||
assert success
|
||||
|
||||
|
||||
def test_rpc_condition_using_eip1271(
|
||||
deployer_account, eip1271_contract_wallet, condition_providers
|
||||
):
|
||||
# send some ETH to the smart contract wallet
|
||||
eth_amount = Web3.to_wei(2.25, "ether")
|
||||
|
||||
encoded_deposit_function = eip1271_contract_wallet.deposit.encode_input().hex()
|
||||
deployer_account.transfer(
|
||||
account=eip1271_contract_wallet.address,
|
||||
value=eth_amount,
|
||||
data=encoded_deposit_function,
|
||||
)
|
||||
|
||||
rpc_condition = RPCCondition(
|
||||
method="eth_getBalance",
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
parameters=[USER_ADDRESS_CONTEXT],
|
||||
return_value_test=ReturnValueTest("==", eth_amount),
|
||||
)
|
||||
|
||||
data = f"I'm the owner of the smart contract wallet address {eip1271_contract_wallet.address}"
|
||||
signable_message = encode_defunct(text=data)
|
||||
hash = defunct_hash_message(text=data)
|
||||
message_signature = deployer_account.sign_message(signable_message)
|
||||
hex_signature = HexBytes(message_signature.encode_rsv()).hex()
|
||||
|
||||
typedData = {"chain": TESTERCHAIN_CHAIN_ID, "dataHash": hash.hex()}
|
||||
auth_message = {
|
||||
"signature": f"{hex_signature}",
|
||||
"address": f"{eip1271_contract_wallet.address}",
|
||||
"scheme": EvmAuth.AuthScheme.EIP1271.value,
|
||||
"typedData": typedData,
|
||||
}
|
||||
context = {
|
||||
USER_ADDRESS_CONTEXT: auth_message,
|
||||
}
|
||||
condition_result, call_result = rpc_condition.verify(
|
||||
providers=condition_providers, **context
|
||||
)
|
||||
assert condition_result is True
|
||||
assert call_result == eth_amount
|
||||
|
||||
# withdraw some ETH and check condition again
|
||||
withdraw_amount = Web3.to_wei(1, "ether")
|
||||
eip1271_contract_wallet.withdraw(withdraw_amount, sender=deployer_account)
|
||||
condition_result, call_result = rpc_condition.verify(
|
||||
providers=condition_providers, **context
|
||||
)
|
||||
assert condition_result is False
|
||||
assert call_result != eth_amount
|
||||
assert call_result == (eth_amount - withdraw_amount)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from collections import defaultdict
|
||||
|
||||
import pytest
|
||||
from web3 import Web3
|
||||
|
||||
from nucypher.policy.conditions.evm import RPCCall, RPCCondition
|
||||
from nucypher.policy.conditions.lingo import (
|
||||
|
@ -10,7 +9,8 @@ from nucypher.policy.conditions.lingo import (
|
|||
ConditionType,
|
||||
ReturnValueTest,
|
||||
)
|
||||
from nucypher.policy.conditions.time import TimeCondition, TimeRPCCall
|
||||
from nucypher.policy.conditions.time import TimeCondition
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
from nucypher.utilities.logging import GlobalLoggerSettings
|
||||
from tests.utils.policy import make_message_kits
|
||||
|
||||
|
@ -62,7 +62,7 @@ def conditions(bob, multichain_ids):
|
|||
|
||||
|
||||
def test_single_retrieve_with_multichain_conditions(
|
||||
enacted_policy, bob, multichain_ursulas, conditions, mock_rpc_condition
|
||||
enacted_policy, bob, multichain_ursulas, conditions, monkeymodule, testerchain
|
||||
):
|
||||
bob.remember_node(multichain_ursulas[0])
|
||||
bob.start_learning_loop()
|
||||
|
@ -72,6 +72,11 @@ def test_single_retrieve_with_multichain_conditions(
|
|||
encrypted_treasure_map=enacted_policy.treasure_map,
|
||||
alice_verifying_key=enacted_policy.publisher_verifying_key,
|
||||
)
|
||||
monkeymodule.setattr(
|
||||
ConditionProviderManager,
|
||||
"web3_endpoints",
|
||||
lambda *args, **kwargs: [testerchain.w3],
|
||||
)
|
||||
|
||||
cleartexts = bob.retrieve_and_decrypt(
|
||||
message_kits=message_kits,
|
||||
|
@ -93,43 +98,30 @@ def test_single_decryption_request_with_faulty_rpc_endpoint(
|
|||
alice_verifying_key=enacted_policy.publisher_verifying_key,
|
||||
)
|
||||
|
||||
def _mock_configure_provider(*args, **kwargs):
|
||||
rpc_call_type = args[0]
|
||||
if isinstance(rpc_call_type, TimeRPCCall):
|
||||
# time condition call - only RPCCall is made faulty
|
||||
return testerchain.w3
|
||||
monkeymodule.setattr(
|
||||
ConditionProviderManager,
|
||||
"web3_endpoints",
|
||||
lambda *args, **kwargs: [testerchain.w3, testerchain.w3],
|
||||
) # a base, and fallback
|
||||
|
||||
# rpc condition call
|
||||
provider = args[1]
|
||||
w3 = Web3(provider)
|
||||
return w3
|
||||
|
||||
monkeymodule.setattr(RPCCall, "_configure_provider", _mock_configure_provider)
|
||||
|
||||
calls = defaultdict(int)
|
||||
rpc_calls = defaultdict(int)
|
||||
original_execute_call = RPCCall._execute
|
||||
|
||||
def faulty_execute_call(*args, **kwargs):
|
||||
def faulty_rpc_execute_call(*args, **kwargs):
|
||||
"""Intercept the call to the RPC endpoint and raise an exception on the second call."""
|
||||
nonlocal calls
|
||||
nonlocal rpc_calls
|
||||
rpc_call_object = args[0]
|
||||
resolved_parameters = args[2]
|
||||
calls[rpc_call_object.chain] += 1
|
||||
if calls[rpc_call_object.chain] % 2 == 0:
|
||||
rpc_calls[rpc_call_object.chain] += 1
|
||||
if rpc_calls[rpc_call_object.chain] % 2 == 0:
|
||||
# simulate a network error
|
||||
raise ConnectionError("Something went wrong with the network")
|
||||
|
||||
# replace w3 object with fake provider, with proper w3 object for actual execution
|
||||
return original_execute_call(
|
||||
rpc_call_object, testerchain.w3, resolved_parameters
|
||||
)
|
||||
|
||||
RPCCall._execute = faulty_execute_call
|
||||
# make original call
|
||||
return original_execute_call(*args, **kwargs)
|
||||
|
||||
monkeymodule.setattr(RPCCall, "_execute", faulty_rpc_execute_call)
|
||||
cleartexts = bob.retrieve_and_decrypt(
|
||||
message_kits=message_kits,
|
||||
**policy_info_kwargs,
|
||||
)
|
||||
assert cleartexts == messages
|
||||
|
||||
RPCCall._execute = original_execute_call
|
||||
|
|
|
@ -14,7 +14,6 @@ from nucypher.blockchain.eth.agents import (
|
|||
)
|
||||
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
|
||||
from nucypher.blockchain.eth.registry import ContractRegistry, RegistrySourceManager
|
||||
from nucypher.policy.conditions.evm import RPCCall
|
||||
from nucypher.utilities.logging import Logger
|
||||
from tests.constants import (
|
||||
BONUS_TOKENS_FOR_TESTS,
|
||||
|
@ -418,14 +417,6 @@ def taco_child_application_agent(testerchain, test_registry):
|
|||
# Conditions
|
||||
#
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mock_rpc_condition(testerchain, monkeymodule):
|
||||
def configure_mock(*args, **kwargs):
|
||||
return testerchain.w3
|
||||
|
||||
monkeymodule.setattr(RPCCall, "_configure_provider", configure_mock)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def multichain_ids(module_mocker):
|
||||
ids = mock_permitted_multichain_connections(mocker=module_mocker)
|
||||
|
@ -433,7 +424,7 @@ def multichain_ids(module_mocker):
|
|||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def multichain_ursulas(ursulas, multichain_ids, mock_rpc_condition):
|
||||
def multichain_ursulas(ursulas, multichain_ids):
|
||||
setup_multichain_ursulas(ursulas=ursulas, chain_ids=multichain_ids)
|
||||
return ursulas
|
||||
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
import "@openzeppelin/contracts/access/Ownable.sol";
|
||||
import "@openzeppelin/contracts/interfaces/IERC1271.sol";
|
||||
import "@openzeppelin/contracts/utils/cryptography/ECDSA.sol";
|
||||
|
||||
contract SmartContractWallet is IERC1271, Ownable {
|
||||
using ECDSA for bytes32;
|
||||
|
||||
uint public balance;
|
||||
|
||||
bytes4 internal constant MAGICVALUE = 0x1626ba7e;
|
||||
bytes4 constant internal INVALID_SIGNATURE = 0xffffffff;
|
||||
|
||||
constructor(address _owner) Ownable(_owner) public {}
|
||||
|
||||
function deposit() external payable {
|
||||
balance += msg.value;
|
||||
}
|
||||
|
||||
function withdraw(uint amount) external onlyOwner {
|
||||
require(amount <= balance, "Amount exceeds balance");
|
||||
balance -= amount;
|
||||
payable(owner()).transfer(amount);
|
||||
}
|
||||
|
||||
function isValidSignature(bytes32 _hash, bytes memory _signature) public view override returns (bytes4) {
|
||||
address signer = _hash.recover(_signature);
|
||||
if (signer == owner()) {
|
||||
return MAGICVALUE;
|
||||
} else {
|
||||
return INVALID_SIGNATURE;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -18,5 +18,5 @@ def test_condition_chains_endpoint_multichain(
|
|||
):
|
||||
response = client.get("/condition_chains")
|
||||
assert response.status_code == 200
|
||||
expected_payload = {"version": 1.0, "evm": multichain_ids}
|
||||
expected_payload = {"version": 1.0, "evm": sorted(multichain_ids)}
|
||||
assert response.get_json() == expected_payload
|
||||
|
|
|
@ -17,6 +17,7 @@ from nucypher.policy.conditions.exceptions import (
|
|||
InvalidConditionLingo,
|
||||
)
|
||||
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
from tests.constants import TESTERCHAIN_CHAIN_ID
|
||||
|
||||
CHAIN_ID = 137
|
||||
|
@ -52,7 +53,7 @@ class FakeExecutionContractCondition(ContractCondition):
|
|||
def set_execution_return_value(self, value: Any):
|
||||
self.execution_return_value = value
|
||||
|
||||
def execute(self, providers: Dict, **context) -> Any:
|
||||
def execute(self, providers: ConditionProviderManager, **context) -> Any:
|
||||
return self.execution_return_value
|
||||
|
||||
EXECUTION_CALL_TYPE = FakeRPCCall
|
||||
|
@ -125,7 +126,7 @@ def _check_execution_logic(
|
|||
json.dumps(condition_dict)
|
||||
)
|
||||
fake_execution_contract_condition.set_execution_return_value(execution_result)
|
||||
fake_providers = {CHAIN_ID: {Mock(BaseProvider)}}
|
||||
fake_providers = ConditionProviderManager({CHAIN_ID: {Mock(BaseProvider)}})
|
||||
condition_result, call_result = fake_execution_contract_condition.verify(
|
||||
fake_providers, **context
|
||||
)
|
||||
|
|
|
@ -1,16 +1,32 @@
|
|||
import maya
|
||||
import pytest
|
||||
from eth_account import Account
|
||||
from eth_account.messages import defunct_hash_message
|
||||
from hexbytes import HexBytes
|
||||
from siwe import SiweMessage
|
||||
from web3.contract import Contract
|
||||
|
||||
from nucypher.blockchain.eth.signers import InMemorySigner
|
||||
from nucypher.policy.conditions.auth.evm import EIP712Auth, EIP4361Auth, EvmAuth
|
||||
from nucypher.policy.conditions.auth.evm import (
|
||||
EIP712Auth,
|
||||
EIP1271Auth,
|
||||
EIP4361Auth,
|
||||
EvmAuth,
|
||||
)
|
||||
from nucypher.policy.conditions.exceptions import NoConnectionToChain
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
from tests.constants import TESTERCHAIN_CHAIN_ID
|
||||
|
||||
|
||||
def test_auth_scheme():
|
||||
expected_schemes = {
|
||||
EvmAuth.AuthScheme.EIP712: EIP712Auth,
|
||||
EvmAuth.AuthScheme.EIP4361: EIP4361Auth,
|
||||
EvmAuth.AuthScheme.EIP1271: EIP1271Auth,
|
||||
}
|
||||
|
||||
for scheme in EvmAuth.AuthScheme:
|
||||
expected_scheme = (
|
||||
EIP712Auth if scheme == EvmAuth.AuthScheme.EIP712 else EIP4361Auth
|
||||
)
|
||||
expected_scheme = expected_schemes.get(scheme)
|
||||
assert EvmAuth.from_scheme(scheme=scheme.value) == expected_scheme
|
||||
|
||||
# non-existent scheme
|
||||
|
@ -279,3 +295,119 @@ def test_authenticate_eip4361(get_random_checksum_address):
|
|||
not_stale_but_past_expiry_signature.hex(),
|
||||
valid_address_for_signature,
|
||||
)
|
||||
|
||||
|
||||
def test_authenticate_eip1271(mocker, get_random_checksum_address):
|
||||
# smart contract wallet
|
||||
eip1271_mock_contract = mocker.Mock(spec=Contract)
|
||||
contract_address = get_random_checksum_address()
|
||||
eip1271_mock_contract.address = contract_address
|
||||
|
||||
# signer for wallet
|
||||
data = f"I'm the owner of the smart contract wallet address {eip1271_mock_contract.address}"
|
||||
wallet_signer = InMemorySigner()
|
||||
valid_message_signature = wallet_signer.sign_message(
|
||||
account=wallet_signer.accounts[0], message=data.encode()
|
||||
)
|
||||
data_hash = defunct_hash_message(text=data)
|
||||
typedData = {"chain": TESTERCHAIN_CHAIN_ID, "dataHash": data_hash.hex()}
|
||||
|
||||
def _isValidSignature(data_hash, signature_bytes):
|
||||
class ContractCall:
|
||||
def __init__(self, hash, signature):
|
||||
self.hash = hash
|
||||
self.signature = signature
|
||||
|
||||
def call(self):
|
||||
recovered_address = Account._recover_hash(
|
||||
message_hash=self.hash, signature=self.signature
|
||||
)
|
||||
if recovered_address == wallet_signer.accounts[0]:
|
||||
return bytes(HexBytes("0x1626ba7e"))
|
||||
|
||||
return bytes(HexBytes("0xffffffff"))
|
||||
|
||||
return ContractCall(data_hash, signature_bytes)
|
||||
|
||||
eip1271_mock_contract.functions.isValidSignature.side_effect = _isValidSignature
|
||||
|
||||
# condition provider manager
|
||||
providers = mocker.Mock(spec=ConditionProviderManager)
|
||||
w3 = mocker.Mock()
|
||||
w3.eth.contract.return_value = eip1271_mock_contract
|
||||
providers.web3_endpoints.return_value = [w3]
|
||||
|
||||
# valid signature
|
||||
EIP1271Auth.authenticate(
|
||||
typedData, valid_message_signature, eip1271_mock_contract.address, providers
|
||||
)
|
||||
w3.eth.contract.assert_called_once_with(
|
||||
address=eip1271_mock_contract.address, abi=EIP1271Auth.EIP1271_ABI
|
||||
)
|
||||
|
||||
# no providers
|
||||
with pytest.raises(EvmAuth.AuthenticationFailed, match="no endpoints provided"):
|
||||
EIP1271Auth.authenticate(
|
||||
typedData, valid_message_signature, eip1271_mock_contract.address, None
|
||||
)
|
||||
|
||||
# invalid typed data - no chain id
|
||||
with pytest.raises(EvmAuth.InvalidData):
|
||||
EIP1271Auth.authenticate(
|
||||
{
|
||||
"dataHash": data_hash.hex(),
|
||||
},
|
||||
valid_message_signature,
|
||||
eip1271_mock_contract.address,
|
||||
providers,
|
||||
)
|
||||
|
||||
# invalid typed data - no data hash
|
||||
with pytest.raises(EvmAuth.InvalidData):
|
||||
EIP1271Auth.authenticate(
|
||||
{
|
||||
"chainId": TESTERCHAIN_CHAIN_ID,
|
||||
},
|
||||
valid_message_signature,
|
||||
eip1271_mock_contract.address,
|
||||
providers,
|
||||
)
|
||||
|
||||
# use invalid signer
|
||||
invalid_signer = InMemorySigner()
|
||||
invalid_message_signature = invalid_signer.sign_message(
|
||||
account=invalid_signer.accounts[0], message=data.encode()
|
||||
)
|
||||
with pytest.raises(EvmAuth.AuthenticationFailed):
|
||||
EIP1271Auth.authenticate(
|
||||
typedData,
|
||||
invalid_message_signature,
|
||||
eip1271_mock_contract.address,
|
||||
providers,
|
||||
)
|
||||
|
||||
# bad w3 instance failed for some reason
|
||||
w3_bad = mocker.Mock()
|
||||
w3_bad.eth.contract.side_effect = ValueError("something went wrong")
|
||||
providers.web3_endpoints.return_value = [w3_bad]
|
||||
with pytest.raises(EvmAuth.AuthenticationFailed, match="something went wrong"):
|
||||
EIP1271Auth.authenticate(
|
||||
typedData, valid_message_signature, eip1271_mock_contract.address, providers
|
||||
)
|
||||
assert w3_bad.eth.contract.call_count == 1, "one call that failed"
|
||||
|
||||
# fall back to good w3 instances
|
||||
providers.web3_endpoints.return_value = [w3_bad, w3_bad, w3]
|
||||
EIP1271Auth.authenticate(
|
||||
typedData, valid_message_signature, eip1271_mock_contract.address, providers
|
||||
)
|
||||
assert w3_bad.eth.contract.call_count == 3, "two more calls that failed"
|
||||
|
||||
# no connection to chain
|
||||
providers.web3_endpoints.side_effect = NoConnectionToChain(
|
||||
chain=TESTERCHAIN_CHAIN_ID
|
||||
)
|
||||
with pytest.raises(EvmAuth.AuthenticationFailed, match="No connection to chain ID"):
|
||||
EIP1271Auth.authenticate(
|
||||
typedData, valid_message_signature, eip1271_mock_contract.address, providers
|
||||
)
|
||||
|
|
|
@ -12,6 +12,7 @@ from nucypher.policy.conditions.lingo import (
|
|||
OrCompoundCondition,
|
||||
SequentialAccessControlCondition,
|
||||
)
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
@ -248,7 +249,9 @@ def test_nested_multi_conditions(mock_conditions):
|
|||
else_condition=False,
|
||||
)
|
||||
|
||||
result, value = if_then_else_condition.verify(providers={})
|
||||
result, value = if_then_else_condition.verify(
|
||||
providers=ConditionProviderManager({})
|
||||
)
|
||||
assert result is True
|
||||
assert value == [[1, 2], [2, 3]] # [[or result], [seq result]]
|
||||
|
||||
|
@ -277,7 +280,9 @@ def test_nested_multi_conditions(mock_conditions):
|
|||
),
|
||||
)
|
||||
|
||||
result, value = if_then_else_condition.verify(providers={})
|
||||
result, value = if_then_else_condition.verify(
|
||||
providers=ConditionProviderManager({})
|
||||
)
|
||||
assert result is False
|
||||
assert value == [[1, 2], [3, 2]] # [[or result], [else if condition result]]
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ from nucypher.policy.conditions.lingo import (
|
|||
OrCompoundCondition,
|
||||
SequentialAccessControlCondition,
|
||||
)
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
@ -173,7 +174,9 @@ def test_sequential_condition(mock_condition_variables):
|
|||
)
|
||||
|
||||
original_context = dict()
|
||||
result, value = sequential_condition.verify(providers={}, **original_context)
|
||||
result, value = sequential_condition.verify(
|
||||
providers=ConditionProviderManager({}), **original_context
|
||||
)
|
||||
assert result is True
|
||||
assert value == [1, 1 * 2, 1 * 2 * 3, 1 * 2 * 3 * 4]
|
||||
# only a copy of the context is modified internally
|
||||
|
@ -215,7 +218,9 @@ def test_sequential_condition_all_prior_vars_passed_to_subsequent_calls(
|
|||
expected_var_3_value = expected_var_1_value + expected_var_2_value + 1
|
||||
|
||||
original_context = dict()
|
||||
result, value = sequential_condition.verify(providers={}, **original_context)
|
||||
result, value = sequential_condition.verify(
|
||||
providers=ConditionProviderManager({}), **original_context
|
||||
)
|
||||
assert result is True
|
||||
assert value == [
|
||||
expected_var_1_value,
|
||||
|
@ -238,4 +243,4 @@ def test_sequential_condition_a_call_fails(mock_condition_variables):
|
|||
)
|
||||
|
||||
with pytest.raises(Web3Exception):
|
||||
_ = sequential_condition.verify(providers={})
|
||||
_ = sequential_condition.verify(providers=ConditionProviderManager({}))
|
||||
|
|
|
@ -37,6 +37,7 @@ from nucypher.policy.conditions.lingo import ConditionLingo
|
|||
from nucypher.policy.conditions.utils import (
|
||||
CamelCaseSchema,
|
||||
ConditionEvalError,
|
||||
ConditionProviderManager,
|
||||
camel_case_to_snake,
|
||||
evaluate_condition_lingo,
|
||||
to_camelcase,
|
||||
|
@ -102,7 +103,9 @@ def test_evaluate_condition_eval_returns_false():
|
|||
with pytest.raises(ConditionEvalError) as eval_error:
|
||||
evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo,
|
||||
providers={1: Mock(spec=BaseProvider)}, # fake provider
|
||||
providers=ConditionProviderManager(
|
||||
{1: Mock(spec=BaseProvider)}
|
||||
), # fake provider
|
||||
context={"key": "value"}, # fake context
|
||||
)
|
||||
assert eval_error.value.status_code == HTTPStatus.FORBIDDEN
|
||||
|
@ -119,10 +122,12 @@ def test_evaluate_condition_eval_returns_true():
|
|||
|
||||
evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo,
|
||||
providers={
|
||||
1: Mock(spec=BaseProvider),
|
||||
2: Mock(spec=BaseProvider),
|
||||
}, # multiple fake provider
|
||||
providers=ConditionProviderManager(
|
||||
{
|
||||
1: Mock(spec=BaseProvider),
|
||||
2: Mock(spec=BaseProvider),
|
||||
}
|
||||
),
|
||||
context={
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
|
@ -166,3 +171,48 @@ def test_camel_case_schema():
|
|||
|
||||
reloaded_function = schema.load(output)
|
||||
assert reloaded_function == {"field_name_with_underscores": f"{value}"}
|
||||
|
||||
|
||||
def test_condition_provider_manager(mocker):
|
||||
# no condition to chain
|
||||
with pytest.raises(NoConnectionToChain, match="No connection to chain ID"):
|
||||
manager = ConditionProviderManager(
|
||||
providers={2: [mocker.Mock(spec=BaseProvider)]}
|
||||
)
|
||||
_ = list(manager.web3_endpoints(chain_id=1))
|
||||
|
||||
# invalid provider chain
|
||||
manager = ConditionProviderManager(providers={2: [mocker.Mock(spec=BaseProvider)]})
|
||||
w3 = mocker.Mock()
|
||||
w3.eth.chain_id = (
|
||||
1 # make w3 instance created from provider have incorrect chain id
|
||||
)
|
||||
with patch.object(manager, "_configure_w3", return_value=w3):
|
||||
with pytest.raises(
|
||||
NoConnectionToChain, match="Problematic provider endpoints for chain ID"
|
||||
):
|
||||
_ = list(manager.web3_endpoints(chain_id=2))
|
||||
|
||||
# valid provider chain
|
||||
manager = ConditionProviderManager(providers={2: [mocker.Mock(spec=BaseProvider)]})
|
||||
with patch.object(manager, "_check_chain_id", return_value=None):
|
||||
assert len(list(manager.web3_endpoints(chain_id=2))) == 1
|
||||
|
||||
# multiple providers
|
||||
manager = ConditionProviderManager(
|
||||
providers={2: [mocker.Mock(spec=BaseProvider), mocker.Mock(spec=BaseProvider)]}
|
||||
)
|
||||
with patch.object(manager, "_check_chain_id", return_value=None):
|
||||
w3_instances = list(manager.web3_endpoints(chain_id=2))
|
||||
assert len(w3_instances) == 2
|
||||
for w3_instance in w3_instances:
|
||||
assert w3_instance # actual object returned
|
||||
assert w3_instance.middleware_onion.get("poa") # poa middleware injected
|
||||
|
||||
# specific w3 instances
|
||||
w3_1 = mocker.Mock()
|
||||
w3_1.eth.chain_id = 2
|
||||
w3_2 = mocker.Mock()
|
||||
w3_2.eth.chain_id = 2
|
||||
with patch.object(manager, "_configure_w3", side_effect=[w3_1, w3_2]):
|
||||
assert list(manager.web3_endpoints(chain_id=2)) == [w3_1, w3_2]
|
||||
|
|
|
@ -11,6 +11,7 @@ from web3 import HTTPProvider
|
|||
from nucypher.blockchain.eth.signers import InMemorySigner, Signer
|
||||
from nucypher.characters.lawful import Ursula
|
||||
from nucypher.config.characters import UrsulaConfiguration
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
from tests.constants import TESTERCHAIN_CHAIN_ID
|
||||
from tests.utils.blockchain import ReservedTestAccountManager
|
||||
|
||||
|
@ -176,14 +177,16 @@ def setup_multichain_ursulas(chain_ids: List[int], ursulas: List[Ursula]) -> Non
|
|||
fallback_blockchain_endpoints = [
|
||||
base_fallback_uri.format(i) for i in range(len(chain_ids))
|
||||
]
|
||||
mocked_condition_providers = {
|
||||
cid: {HTTPProvider(uri), HTTPProvider(furi)}
|
||||
for cid, uri, furi in zip(
|
||||
chain_ids, blockchain_endpoints, fallback_blockchain_endpoints
|
||||
)
|
||||
}
|
||||
mocked_condition_providers = ConditionProviderManager(
|
||||
{
|
||||
cid: [HTTPProvider(uri), HTTPProvider(furi)]
|
||||
for cid, uri, furi in zip(
|
||||
chain_ids, blockchain_endpoints, fallback_blockchain_endpoints
|
||||
)
|
||||
}
|
||||
)
|
||||
for ursula in ursulas:
|
||||
ursula.condition_providers = mocked_condition_providers
|
||||
ursula.condition_provider_manager = mocked_condition_providers
|
||||
|
||||
|
||||
MOCK_KNOWN_URSULAS_CACHE = dict()
|
||||
|
|
Loading…
Reference in New Issue