Merge pull request #3576 from derekpierre/eip1271-support

EIP1271 Support
pull/3581/head lynx
Derek Pierre 2025-02-10 08:51:12 -05:00 committed by GitHub
commit 07e13b9930
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 667 additions and 200 deletions

View File

@ -0,0 +1 @@
Add support for EIP1271 signature verification for smart contract wallets.

View File

@ -4,7 +4,7 @@ import time
import traceback
from collections import defaultdict
from decimal import Decimal
from typing import DefaultDict, Dict, List, Optional, Union
from typing import Dict, List, Optional, Union
import maya
from atxm.exceptions import InsufficientFunds
@ -65,7 +65,10 @@ from nucypher.crypto.powers import (
TransactingPower,
)
from nucypher.datastore.dkg import DKGStorage
from nucypher.policy.conditions.utils import evaluate_condition_lingo
from nucypher.policy.conditions.utils import (
ConditionProviderManager,
evaluate_condition_lingo,
)
from nucypher.policy.payment import ContractPayment
from nucypher.types import PhaseId
from nucypher.utilities.emitters import StdoutEmitter
@ -247,7 +250,7 @@ class Operator(BaseActor):
ThresholdRequestDecryptingPower
) # used for secure decryption request channel
self.condition_providers = self.connect_condition_providers(
self.condition_provider_manager = self.get_condition_provider_manager(
condition_blockchain_endpoints
)
@ -269,9 +272,9 @@ class Operator(BaseActor):
provider = HTTPProvider(endpoint_uri=uri)
return provider
def connect_condition_providers(
def get_condition_provider_manager(
self, operator_configured_endpoints: Dict[int, List[str]]
) -> DefaultDict[int, List[HTTPProvider]]:
) -> ConditionProviderManager:
# check that we have mandatory user configured endpoints
mandatory_configured_chains = {
@ -336,7 +339,7 @@ class Operator(BaseActor):
f"checking on chain IDs {providers.keys()}"
)
return providers
return ConditionProviderManager(providers=providers)
def _resolve_ritual(self, ritual_id: int) -> Coordinator.Ritual:
if not self.coordinator_agent.is_ritual_active(ritual_id=ritual_id):
@ -845,7 +848,7 @@ class Operator(BaseActor):
evaluate_condition_lingo(
condition_lingo=condition_lingo,
context=context,
providers=self.condition_providers,
providers=self.condition_provider_manager,
)
def _verify_decryption_request_authorization(

View File

@ -154,8 +154,9 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
"""
# TODO: When non-evm chains are supported, bump the version.
# this can return a list of chain names or other verifiable identifiers.
payload = {"version": 1.0, "evm": list(this_node.condition_providers)}
providers = this_node.condition_provider_manager.providers
sorted_chain_ids = sorted(list(providers))
payload = {"version": 1.0, "evm": sorted_chain_ids}
return Response(json.dumps(payload), mimetype="application/json")
@rest_app.route('/decrypt', methods=["POST"])
@ -260,7 +261,7 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
try:
evaluate_condition_lingo(
condition_lingo=condition_lingo,
providers=this_node.condition_providers,
providers=this_node.condition_provider_manager,
context=context,
)
except ConditionEvalError as error:

View File

@ -1,16 +1,22 @@
from enum import Enum
from typing import List
from typing import List, Optional
import maya
from eth_account.account import Account
from eth_account.messages import HexBytes, encode_typed_data
from eth_typing import ChecksumAddress
from siwe import SiweMessage, VerificationError
from nucypher.policy.conditions.exceptions import NoConnectionToChain
from nucypher.policy.conditions.utils import ConditionProviderManager
from nucypher.utilities.logging import Logger
class EvmAuth:
class AuthScheme(Enum):
EIP712 = "EIP712"
EIP4361 = "EIP4361"
EIP1271 = "EIP1271"
@classmethod
def values(cls) -> List[str]:
@ -26,7 +32,13 @@ class EvmAuth:
"""The message is too old."""
@classmethod
def authenticate(cls, data, signature, expected_address):
def authenticate(
cls,
data,
signature: str,
expected_address: str,
providers: Optional[ConditionProviderManager] = None,
):
raise NotImplementedError
@classmethod
@ -35,13 +47,21 @@ class EvmAuth:
return EIP712Auth
elif scheme == cls.AuthScheme.EIP4361.value:
return EIP4361Auth
elif scheme == cls.AuthScheme.EIP1271.value:
return EIP1271Auth
raise ValueError(f"Invalid authentication scheme: {scheme}")
class EIP712Auth(EvmAuth):
@classmethod
def authenticate(cls, data, signature, expected_address):
def authenticate(
cls,
data,
signature: str,
expected_address: str,
providers: Optional[ConditionProviderManager] = None,
):
try:
# convert hex data for byte fields - bytes are expected by underlying library
# 1. salt
@ -72,7 +92,13 @@ class EIP4361Auth(EvmAuth):
FRESHNESS_IN_HOURS = 2
@classmethod
def authenticate(cls, data, signature, expected_address):
def authenticate(
cls,
data,
signature: str,
expected_address: str,
providers: Optional[ConditionProviderManager] = None,
):
try:
siwe_message = SiweMessage.from_message(message=data)
except Exception as e:
@ -106,3 +132,113 @@ class EIP4361Auth(EvmAuth):
raise cls.AuthenticationFailed(
f"Invalid EIP4361 signature; signature not valid for expected address, {expected_address}"
)
class EIP1271Auth(EvmAuth):
EIP1271_ABI = """[
{
"inputs":[
{
"internalType":"bytes32",
"name":"_hash",
"type":"bytes32"
},
{
"internalType":"bytes",
"name":"_signature",
"type":"bytes"
}
],
"name":"isValidSignature",
"outputs":[
{
"internalType":"bytes4",
"name":"",
"type":"bytes4"
}
],
"stateMutability":"view",
"type":"function"
}
]"""
MAGIC_VALUE_BYTES = bytes(HexBytes("0x1626ba7e"))
LOG = Logger("EIP1271Auth")
@classmethod
def _extract_typed_data(cls, data):
try:
data_hash = bytes(HexBytes(data["dataHash"]))
chain = data["chain"]
return data_hash, chain
except Exception as e:
# data could not be processed
raise cls.InvalidData(
f"Invalid EIP1271 authentication data: {str(e) or e.__class__.__name__}"
)
@classmethod
def _validate_auth_data(
cls, data_hash, signature_bytes, expected_address, chain, providers
):
web3_endpoints = providers.web3_endpoints(chain_id=chain)
last_error = None
for web3_instance in web3_endpoints:
try:
# Interact with the EIP1271 contract
eip1271_contract = web3_instance.eth.contract(
address=expected_address, abi=cls.EIP1271_ABI
)
result = eip1271_contract.functions.isValidSignature(
data_hash,
signature_bytes,
).call()
if result == cls.MAGIC_VALUE_BYTES:
return # Successful authentication
break
except Exception as e:
last_error = f"EIP1271 contract call failed ({expected_address}): {e}"
cls.LOG.warn(f"{last_error}; attempting next provider")
else:
# If all providers fail
if last_error:
raise cls.AuthenticationFailed(
f"EIP1271 verification failed; {last_error}"
)
raise cls.AuthenticationFailed(
f"EIP1271 verification failed; signature not valid for contract address, {expected_address}"
)
@classmethod
def authenticate(
cls,
data,
signature: str,
expected_address: ChecksumAddress,
providers: Optional[ConditionProviderManager] = None,
):
if not providers:
# should never happen
raise cls.AuthenticationFailed(
"EIP1271 verification failed; no endpoints provided"
)
# Extract and validate input data
data_hash, chain = cls._extract_typed_data(data)
# Validate the signature
signature_bytes = bytes(HexBytes(signature))
try:
cls._validate_auth_data(
data_hash, signature_bytes, expected_address, chain, providers
)
except NoConnectionToChain:
raise cls.AuthenticationFailed(
f"EIP1271 verification failed; No connection to chain ID {chain}"
)
except cls.AuthenticationFailed:
raise
except Exception as e:
# catch all
raise cls.AuthenticationFailed(f"EIP1271 verification failed; {e}")

View File

@ -1,6 +1,6 @@
import re
from functools import partial
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union
from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address
@ -11,6 +11,7 @@ from nucypher.policy.conditions.exceptions import (
InvalidContextVariableData,
RequiredContextVariable,
)
from nucypher.policy.conditions.utils import ConditionProviderManager
USER_ADDRESS_CONTEXT = ":userAddress"
USER_ADDRESS_EIP4361_EXTERNAL_CONTEXT = ":userAddressExternalEIP4361"
@ -19,7 +20,7 @@ CONTEXT_PREFIX = ":"
CONTEXT_REGEX = re.compile(":[a-zA-Z_][a-zA-Z0-9_]*")
USER_ADDRESS_SCHEMES = {
USER_ADDRESS_CONTEXT: None, # allow any scheme (EIP4361, EIP712) for now; eventually EIP712 will be deprecated
USER_ADDRESS_CONTEXT: None, # allow any scheme (EIP4361, EIP1271, EIP712) for now; eventually EIP712 will be deprecated
USER_ADDRESS_EIP4361_EXTERNAL_CONTEXT: EvmAuth.AuthScheme.EIP4361.value,
}
@ -28,7 +29,11 @@ class UnexpectedScheme(Exception):
pass
def _resolve_user_address(user_address_context_variable, **context) -> ChecksumAddress:
def _resolve_user_address(
user_address_context_variable: str,
providers: Optional[ConditionProviderManager] = None,
**context,
) -> ChecksumAddress:
"""
Recovers a checksum address from a signed message.
@ -38,7 +43,7 @@ def _resolve_user_address(user_address_context_variable, **context) -> ChecksumA
{
"signature": "<signature>",
"address": "<address>",
"scheme": "EIP4361" | ...
"scheme": "EIP4361" | "EIP1271" | ...
"typedData": ...
}
}
@ -59,7 +64,10 @@ def _resolve_user_address(user_address_context_variable, **context) -> ChecksumA
auth = EvmAuth.from_scheme(scheme)
auth.authenticate(
data=typed_data, signature=signature, expected_address=expected_address
data=typed_data,
signature=signature,
expected_address=expected_address,
providers=providers,
)
except EvmAuth.InvalidData as e:
raise InvalidContextVariableData(
@ -98,7 +106,11 @@ def string_contains_context_variable(variable: str) -> bool:
return bool(matches)
def get_context_value(context_variable: str, **context) -> Any:
def get_context_value(
context_variable: str,
providers: Optional[ConditionProviderManager] = None,
**context,
) -> Any:
try:
# DIRECTIVES are special context vars that will pre-processed by ursula
func = _DIRECTIVES[context_variable]
@ -111,33 +123,40 @@ def get_context_value(context_variable: str, **context) -> Any:
f'No value provided for unrecognized context variable "{context_variable}"'
)
else:
value = func(**context) # required inputs here
value = func(providers=providers, **context) # required inputs here
return value
def resolve_any_context_variables(
param: Union[Any, List[Any], Dict[Any, Any]], **context
param: Union[Any, List[Any], Dict[Any, Any]],
providers: Optional[ConditionProviderManager] = None,
**context,
):
if isinstance(param, list):
return [resolve_any_context_variables(item, **context) for item in param]
return [
resolve_any_context_variables(item, providers, **context) for item in param
]
elif isinstance(param, dict):
return {
k: resolve_any_context_variables(v, **context) for k, v in param.items()
k: resolve_any_context_variables(v, providers, **context)
for k, v in param.items()
}
elif isinstance(param, str):
# either it is a context variable OR contains a context variable within it
# TODO separating the two cases for now out of concern of regex searching
# within strings (case 2)
if is_context_variable(param):
return get_context_value(context_variable=param, **context)
return get_context_value(
context_variable=param, providers=providers, **context
)
else:
matches = re.findall(CONTEXT_REGEX, param)
for context_var in matches:
# checking out of concern for faulty regex search within string
if context_var in context:
resolved_var = get_context_value(
context_variable=context_var, **context
context_variable=context_var, providers=providers, **context
)
param = param.replace(context_var, str(resolved_var))
return param

View File

@ -1,10 +1,7 @@
from typing import (
Any,
Dict,
Iterator,
List,
Optional,
Set,
Tuple,
)
@ -20,9 +17,7 @@ from marshmallow import (
)
from marshmallow.validate import OneOf
from typing_extensions import override
from web3 import HTTPProvider, Web3
from web3.middleware import geth_poa_middleware
from web3.providers import BaseProvider
from web3 import Web3
from web3.types import ABIFunction
from nucypher.policy.conditions import STANDARD_ABI_CONTRACT_TYPES
@ -34,7 +29,6 @@ from nucypher.policy.conditions.context import (
resolve_any_context_variables,
)
from nucypher.policy.conditions.exceptions import (
NoConnectionToChain,
RequiredContextVariable,
RPCExecutionFailed,
)
@ -43,7 +37,10 @@ from nucypher.policy.conditions.lingo import (
ExecutionCallAccessControlCondition,
ReturnValueTest,
)
from nucypher.policy.conditions.utils import camel_case_to_snake
from nucypher.policy.conditions.utils import (
ConditionProviderManager,
camel_case_to_snake,
)
from nucypher.policy.conditions.validation import (
align_comparator_value_with_abi,
get_unbound_contract_function,
@ -105,56 +102,17 @@ class RPCCall(ExecutionCall):
) # bind contract function (only exposes the eth API)
return rpc_function
def _configure_w3(self, provider: BaseProvider) -> Web3:
# Instantiate a local web3 instance
w3 = Web3(provider)
# inject web3 middleware to handle POA chain extra_data field.
w3.middleware_onion.inject(geth_poa_middleware, layer=0, name="poa")
return w3
def _check_chain_id(self, w3: Web3) -> None:
"""
Validates that the actual web3 provider is *actually*
connected to the condition's chain ID by reading its RPC endpoint.
"""
provider_chain = w3.eth.chain_id
if provider_chain != self.chain:
raise NoConnectionToChain(
chain=self.chain,
message=f"This rpc call can only be evaluated on chain ID {self.chain} but the provider's "
f"connection is to chain ID {provider_chain}",
)
def _configure_provider(self, provider: BaseProvider):
"""Binds the condition's contract function to a blockchain provider for evaluation"""
w3 = self._configure_w3(provider=provider)
self._check_chain_id(w3)
return w3
def _next_endpoint(
self, providers: Dict[int, Set[HTTPProvider]]
) -> Iterator[HTTPProvider]:
"""Yields the next web3 provider to try for a given chain ID"""
rpc_providers = providers.get(self.chain, None)
if not rpc_providers:
raise NoConnectionToChain(chain=self.chain)
for provider in rpc_providers:
# Someday, we might make this whole function async, and then we can knock on
# each endpoint here to see if it's alive and only yield it if it is.
yield provider
def execute(self, providers: Dict[int, Set[HTTPProvider]], **context) -> Any:
def execute(self, providers: ConditionProviderManager, **context) -> Any:
resolved_parameters = []
if self.parameters:
resolved_parameters = resolve_any_context_variables(
self.parameters, **context
param=self.parameters, providers=providers, **context
)
endpoints = self._next_endpoint(providers=providers)
endpoints = providers.web3_endpoints(self.chain)
latest_error = ""
for provider in endpoints:
w3 = self._configure_provider(provider)
for w3 in endpoints:
try:
result = self._execute(w3, resolved_parameters)
break
@ -255,10 +213,10 @@ class RPCCondition(ExecutionCallAccessControlCondition):
return return_value_test
def verify(
self, providers: Dict[int, Set[HTTPProvider]], **context
self, providers: ConditionProviderManager, **context
) -> Tuple[bool, Any]:
resolved_return_value_test = self.return_value_test.with_resolved_context(
**context
providers=providers, **context
)
return_value_test = self._align_comparator_value_with_abi(
resolved_return_value_test

View File

@ -13,6 +13,19 @@ class NoConnectionToChain(RuntimeError):
super().__init__(message)
class InvalidConnectionToChain(RuntimeError):
"""Raised when a node does not have a valid provider for a chain."""
def __init__(self, expected_chain: int, actual_chain: int, message: str = None):
self.expected_chain = expected_chain
self.actual_chain = actual_chain
message = (
message
or f"Invalid blockchain connection; expected chain ID {expected_chain}, but detected {actual_chain}"
)
super().__init__(message)
class ReturnValueEvaluationError(Exception):
"""Issue with Return Value and Key"""

View File

@ -4,7 +4,7 @@ import json
import operator as pyoperator
from enum import Enum
from hashlib import md5
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, List, Optional, Tuple, Type, Union
from hexbytes import HexBytes
from marshmallow import (
@ -19,7 +19,6 @@ from marshmallow import (
)
from marshmallow.validate import OneOf, Range
from packaging.version import parse as parse_version
from web3 import HTTPProvider
from nucypher.policy.conditions.base import (
AccessControlCondition,
@ -37,7 +36,7 @@ from nucypher.policy.conditions.exceptions import (
ReturnValueEvaluationError,
)
from nucypher.policy.conditions.types import ConditionDict, Lingo
from nucypher.policy.conditions.utils import CamelCaseSchema
from nucypher.policy.conditions.utils import CamelCaseSchema, ConditionProviderManager
class _ConditionField(fields.Dict):
@ -339,7 +338,7 @@ class SequentialAccessControlCondition(MultiConditionAccessControl):
# TODO - think about not dereferencing context but using a dict;
# may allows more freedom for params
def verify(
self, providers: Dict[int, Set[HTTPProvider]], **context
self, providers: ConditionProviderManager, **context
) -> Tuple[bool, Any]:
values = []
latest_success = False
@ -611,8 +610,10 @@ class ReturnValueTest:
result = _COMPARATOR_FUNCTIONS[self.comparator](left_operand, right_operand)
return result
def with_resolved_context(self, **context):
value = resolve_any_context_variables(self.value, **context)
def with_resolved_context(
self, providers: Optional[ConditionProviderManager] = None, **context
):
value = resolve_any_context_variables(self.value, providers, **context)
return ReturnValueTest(self.comparator, value=value, index=self.index)

View File

@ -74,6 +74,7 @@ class JsonRpcConditionDict(BaseExecConditionDict):
query: NotRequired[str]
authorizationToken: NotRequired[str]
#
# CompoundCondition represents:
# {

View File

@ -1,9 +1,11 @@
import re
from http import HTTPStatus
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, Iterator, List, Optional, Tuple
from marshmallow import Schema, post_dump
from marshmallow.exceptions import SCHEMA
from web3 import HTTPProvider, Web3
from web3.middleware import geth_poa_middleware
from web3.providers import BaseProvider
from nucypher.policy.conditions.exceptions import (
@ -11,6 +13,7 @@ from nucypher.policy.conditions.exceptions import (
ContextVariableVerificationFailed,
InvalidCondition,
InvalidConditionLingo,
InvalidConnectionToChain,
InvalidContextVariableData,
NoConnectionToChain,
RequiredContextVariable,
@ -22,6 +25,57 @@ from nucypher.utilities.logging import Logger
__LOGGER = Logger("condition-eval")
class ConditionProviderManager:
def __init__(self, providers: Dict[int, List[HTTPProvider]]):
self.providers = providers
self.logger = Logger(__name__)
def web3_endpoints(self, chain_id: int) -> Iterator[Web3]:
rpc_providers = self.providers.get(chain_id, None)
if not rpc_providers:
raise NoConnectionToChain(chain=chain_id)
iterator_returned_at_least_one = False
for provider in rpc_providers:
try:
w3 = self._configure_w3(provider=provider)
self._check_chain_id(chain_id, w3)
yield w3
iterator_returned_at_least_one = True
except InvalidConnectionToChain as e:
# don't expect to happen but must account
# for any misconfigurations of public endpoints
self.logger.warn(str(e))
# if we get here, it is because there were endpoints, but issue with configuring them
if not iterator_returned_at_least_one:
raise NoConnectionToChain(
chain=chain_id,
message=f"Problematic provider endpoints for chain ID {chain_id}",
)
@staticmethod
def _configure_w3(provider: BaseProvider) -> Web3:
# Instantiate a local web3 instance
w3 = Web3(provider)
# inject web3 middleware to handle POA chain extra_data field.
w3.middleware_onion.inject(geth_poa_middleware, layer=0, name="poa")
return w3
@staticmethod
def _check_chain_id(chain_id: int, w3: Web3) -> None:
"""
Validates that the actual web3 provider is *actually*
connected to the condition's chain ID by reading its RPC endpoint.
"""
provider_chain = w3.eth.chain_id
if provider_chain != chain_id:
raise InvalidConnectionToChain(
expected_chain=chain_id,
actual_chain=provider_chain,
)
class ConditionEvalError(Exception):
"""Exception when execution condition evaluation."""
def __init__(self, message: str, status_code: int):
@ -58,7 +112,7 @@ class CamelCaseSchema(Schema):
def evaluate_condition_lingo(
condition_lingo: Lingo,
providers: Optional[Dict[int, Set[BaseProvider]]] = None,
providers: Optional[ConditionProviderManager] = None,
context: Optional[ContextDict] = None,
log: Logger = __LOGGER,
):
@ -74,7 +128,7 @@ def evaluate_condition_lingo(
# Setup (don't use mutable defaults)
context = context or dict()
providers = providers or dict()
providers = providers or ConditionProviderManager(providers=dict())
error = None
# Evaluate
@ -142,7 +196,7 @@ def evaluate_condition_lingo(
def extract_single_error_message_from_schema_errors(
errors: Dict[str, List[str]]
errors: Dict[str, List[str]],
) -> str:
"""
Extract single error message from Schema().validate() errors result.

View File

@ -11,12 +11,15 @@ from nucypher.policy.conditions.lingo import (
OrCompoundCondition,
ReturnValueTest,
)
from nucypher.policy.conditions.utils import ConditionProviderManager
from tests.constants import TEST_ETH_PROVIDER_URI, TESTERCHAIN_CHAIN_ID
@pytest.fixture()
def condition_providers(testerchain):
providers = {testerchain.client.chain_id: {testerchain.provider}}
providers = ConditionProviderManager(
{testerchain.client.chain_id: {testerchain.provider}}
)
return providers
@pytest.fixture()
@ -54,14 +57,12 @@ def erc20_evm_condition_balanceof(testerchain, test_registry, ritual_token):
@pytest.fixture
def erc721_contract(accounts, project):
account = accounts[0]
def erc721_contract(project, deployer_account):
# deploy contract
deployed_contract = project.ConditionNFT.deploy(sender=account)
deployed_contract = project.ConditionNFT.deploy(sender=deployer_account)
# mint nft with token id = 1
deployed_contract.mint(account.address, 1, sender=account)
deployed_contract.mint(deployer_account.address, 1, sender=deployer_account)
return deployed_contract
@ -151,3 +152,11 @@ def custom_context_variable_erc20_condition(
parameters=[":addressToUse"],
)
return condition
@pytest.fixture
def eip1271_contract_wallet(project, deployer_account):
_eip1271_contract_wallet = deployer_account.deploy(
project.SmartContractWallet, deployer_account.address
)
return _eip1271_contract_wallet

View File

@ -3,6 +3,7 @@ import os
from unittest import mock
import pytest
from eth_account.messages import defunct_hash_message, encode_defunct
from hexbytes import HexBytes
from web3 import Web3
from web3.providers import BaseProvider
@ -13,6 +14,7 @@ from nucypher.blockchain.eth.agents import (
SubscriptionManagerAgent,
)
from nucypher.blockchain.eth.constants import NULL_ADDRESS
from nucypher.policy.conditions.auth.evm import EvmAuth
from nucypher.policy.conditions.context import (
USER_ADDRESS_CONTEXT,
get_context_value,
@ -33,6 +35,7 @@ from nucypher.policy.conditions.lingo import (
NotCompoundCondition,
ReturnValueTest,
)
from nucypher.policy.conditions.utils import ConditionProviderManager
from tests.constants import (
TEST_ETH_PROVIDER_URI,
TEST_POLYGON_PROVIDER_URI,
@ -67,11 +70,12 @@ def test_rpc_condition_evaluation_no_providers(
):
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
with pytest.raises(NoConnectionToChain):
_ = rpc_condition.verify(providers={}, **context)
_ = rpc_condition.verify(providers=ConditionProviderManager({}), **context)
with pytest.raises(NoConnectionToChain):
_ = rpc_condition.verify(
providers={testerchain.client.chain_id: set()}, **context
providers=ConditionProviderManager({testerchain.client.chain_id: list()}),
**context,
)
@ -85,9 +89,10 @@ def test_rpc_condition_evaluation_invalid_provider_for_chain(
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
new_chain = 23
rpc_condition.execution_call.chain = new_chain
condition_providers = {new_chain: {testerchain.provider}}
condition_providers = ConditionProviderManager({new_chain: [testerchain.provider]})
with pytest.raises(
NoConnectionToChain, match=f"can only be evaluated on chain ID {new_chain}"
NoConnectionToChain,
match=f"Problematic provider endpoints for chain ID {new_chain}",
):
_ = rpc_condition.verify(providers=condition_providers, **context)
@ -118,13 +123,15 @@ def test_rpc_condition_evaluation_multiple_chain_providers(
):
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
condition_providers = {
"1": {"fake1a", "fake1b"},
"2": {"fake2"},
"3": {"fake3"},
"4": {"fake4"},
TESTERCHAIN_CHAIN_ID: {testerchain.provider},
}
condition_providers = ConditionProviderManager(
{
"1": ["fake1a", "fake1b"],
"2": ["fake2"],
"3": ["fake3"],
"4": ["fake4"],
TESTERCHAIN_CHAIN_ID: [testerchain.provider],
}
)
condition_result, call_result = rpc_condition.verify(
providers=condition_providers, **context
@ -144,20 +151,17 @@ def test_rpc_condition_evaluation_multiple_providers_no_valid_fallback(
):
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
def my_configure_w3(provider: BaseProvider):
return Web3(provider)
condition_providers = {
TESTERCHAIN_CHAIN_ID: {
mocker.Mock(spec=BaseProvider),
mocker.Mock(spec=BaseProvider),
mocker.Mock(spec=BaseProvider),
condition_providers = ConditionProviderManager(
{
TESTERCHAIN_CHAIN_ID: [
mocker.Mock(spec=BaseProvider),
mocker.Mock(spec=BaseProvider),
mocker.Mock(spec=BaseProvider),
]
}
}
mocker.patch.object(
rpc_condition.execution_call, "_configure_provider", my_configure_w3
)
mocker.patch.object(condition_providers, "_check_chain_id", return_value=None)
with pytest.raises(RPCExecutionFailed):
_ = rpc_condition.verify(providers=condition_providers, **context)
@ -171,22 +175,19 @@ def test_rpc_condition_evaluation_multiple_providers_valid_fallback(
):
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
def my_configure_w3(provider: BaseProvider):
return Web3(provider)
condition_providers = {
TESTERCHAIN_CHAIN_ID: {
mocker.Mock(spec=BaseProvider),
mocker.Mock(spec=BaseProvider),
mocker.Mock(spec=BaseProvider),
testerchain.provider,
condition_providers = ConditionProviderManager(
{
TESTERCHAIN_CHAIN_ID: [
mocker.Mock(spec=BaseProvider),
mocker.Mock(spec=BaseProvider),
mocker.Mock(spec=BaseProvider),
testerchain.provider,
]
}
}
mocker.patch.object(
rpc_condition.execution_call, "_configure_provider", my_configure_w3
)
mocker.patch.object(condition_providers, "_check_chain_id", return_value=None)
condition_result, call_result = rpc_condition.verify(
providers=condition_providers, **context
)
@ -208,10 +209,12 @@ def test_rpc_condition_evaluation_no_connection_to_chain(
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
# condition providers for other unrelated chains
providers = {
1: mock.Mock(), # mainnet
11155111: mock.Mock(), # Sepolia
}
providers = ConditionProviderManager(
{
1: [mock.Mock()], # mainnet
11155111: [mock.Mock()], # Sepolia
}
)
with pytest.raises(NoConnectionToChain):
rpc_condition.verify(providers=providers, **context)
@ -250,7 +253,10 @@ def test_rpc_condition_evaluation_with_context_var_in_return_value_test(
invalid_balance = balance + 1
context[":balanceContextVar"] = invalid_balance
condition_result, call_result = rpc_condition.verify(
providers={testerchain.client.chain_id: [testerchain.provider]}, **context
providers=ConditionProviderManager(
{testerchain.client.chain_id: [testerchain.provider]}
),
**context,
)
assert condition_result is False
assert call_result != invalid_balance
@ -926,3 +932,56 @@ def test_json_rpc_condition_non_evm_prototyping_example():
)
success, _ = condition.verify()
assert success
def test_rpc_condition_using_eip1271(
deployer_account, eip1271_contract_wallet, condition_providers
):
# send some ETH to the smart contract wallet
eth_amount = Web3.to_wei(2.25, "ether")
encoded_deposit_function = eip1271_contract_wallet.deposit.encode_input().hex()
deployer_account.transfer(
account=eip1271_contract_wallet.address,
value=eth_amount,
data=encoded_deposit_function,
)
rpc_condition = RPCCondition(
method="eth_getBalance",
chain=TESTERCHAIN_CHAIN_ID,
parameters=[USER_ADDRESS_CONTEXT],
return_value_test=ReturnValueTest("==", eth_amount),
)
data = f"I'm the owner of the smart contract wallet address {eip1271_contract_wallet.address}"
signable_message = encode_defunct(text=data)
hash = defunct_hash_message(text=data)
message_signature = deployer_account.sign_message(signable_message)
hex_signature = HexBytes(message_signature.encode_rsv()).hex()
typedData = {"chain": TESTERCHAIN_CHAIN_ID, "dataHash": hash.hex()}
auth_message = {
"signature": f"{hex_signature}",
"address": f"{eip1271_contract_wallet.address}",
"scheme": EvmAuth.AuthScheme.EIP1271.value,
"typedData": typedData,
}
context = {
USER_ADDRESS_CONTEXT: auth_message,
}
condition_result, call_result = rpc_condition.verify(
providers=condition_providers, **context
)
assert condition_result is True
assert call_result == eth_amount
# withdraw some ETH and check condition again
withdraw_amount = Web3.to_wei(1, "ether")
eip1271_contract_wallet.withdraw(withdraw_amount, sender=deployer_account)
condition_result, call_result = rpc_condition.verify(
providers=condition_providers, **context
)
assert condition_result is False
assert call_result != eth_amount
assert call_result == (eth_amount - withdraw_amount)

View File

@ -1,7 +1,6 @@
from collections import defaultdict
import pytest
from web3 import Web3
from nucypher.policy.conditions.evm import RPCCall, RPCCondition
from nucypher.policy.conditions.lingo import (
@ -10,7 +9,8 @@ from nucypher.policy.conditions.lingo import (
ConditionType,
ReturnValueTest,
)
from nucypher.policy.conditions.time import TimeCondition, TimeRPCCall
from nucypher.policy.conditions.time import TimeCondition
from nucypher.policy.conditions.utils import ConditionProviderManager
from nucypher.utilities.logging import GlobalLoggerSettings
from tests.utils.policy import make_message_kits
@ -62,7 +62,7 @@ def conditions(bob, multichain_ids):
def test_single_retrieve_with_multichain_conditions(
enacted_policy, bob, multichain_ursulas, conditions, mock_rpc_condition
enacted_policy, bob, multichain_ursulas, conditions, monkeymodule, testerchain
):
bob.remember_node(multichain_ursulas[0])
bob.start_learning_loop()
@ -72,6 +72,11 @@ def test_single_retrieve_with_multichain_conditions(
encrypted_treasure_map=enacted_policy.treasure_map,
alice_verifying_key=enacted_policy.publisher_verifying_key,
)
monkeymodule.setattr(
ConditionProviderManager,
"web3_endpoints",
lambda *args, **kwargs: [testerchain.w3],
)
cleartexts = bob.retrieve_and_decrypt(
message_kits=message_kits,
@ -93,43 +98,30 @@ def test_single_decryption_request_with_faulty_rpc_endpoint(
alice_verifying_key=enacted_policy.publisher_verifying_key,
)
def _mock_configure_provider(*args, **kwargs):
rpc_call_type = args[0]
if isinstance(rpc_call_type, TimeRPCCall):
# time condition call - only RPCCall is made faulty
return testerchain.w3
monkeymodule.setattr(
ConditionProviderManager,
"web3_endpoints",
lambda *args, **kwargs: [testerchain.w3, testerchain.w3],
) # a base, and fallback
# rpc condition call
provider = args[1]
w3 = Web3(provider)
return w3
monkeymodule.setattr(RPCCall, "_configure_provider", _mock_configure_provider)
calls = defaultdict(int)
rpc_calls = defaultdict(int)
original_execute_call = RPCCall._execute
def faulty_execute_call(*args, **kwargs):
def faulty_rpc_execute_call(*args, **kwargs):
"""Intercept the call to the RPC endpoint and raise an exception on the second call."""
nonlocal calls
nonlocal rpc_calls
rpc_call_object = args[0]
resolved_parameters = args[2]
calls[rpc_call_object.chain] += 1
if calls[rpc_call_object.chain] % 2 == 0:
rpc_calls[rpc_call_object.chain] += 1
if rpc_calls[rpc_call_object.chain] % 2 == 0:
# simulate a network error
raise ConnectionError("Something went wrong with the network")
# replace w3 object with fake provider, with proper w3 object for actual execution
return original_execute_call(
rpc_call_object, testerchain.w3, resolved_parameters
)
RPCCall._execute = faulty_execute_call
# make original call
return original_execute_call(*args, **kwargs)
monkeymodule.setattr(RPCCall, "_execute", faulty_rpc_execute_call)
cleartexts = bob.retrieve_and_decrypt(
message_kits=message_kits,
**policy_info_kwargs,
)
assert cleartexts == messages
RPCCall._execute = original_execute_call

View File

@ -14,7 +14,6 @@ from nucypher.blockchain.eth.agents import (
)
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
from nucypher.blockchain.eth.registry import ContractRegistry, RegistrySourceManager
from nucypher.policy.conditions.evm import RPCCall
from nucypher.utilities.logging import Logger
from tests.constants import (
BONUS_TOKENS_FOR_TESTS,
@ -418,14 +417,6 @@ def taco_child_application_agent(testerchain, test_registry):
# Conditions
#
@pytest.fixture(scope="module")
def mock_rpc_condition(testerchain, monkeymodule):
def configure_mock(*args, **kwargs):
return testerchain.w3
monkeymodule.setattr(RPCCall, "_configure_provider", configure_mock)
@pytest.fixture(scope="module")
def multichain_ids(module_mocker):
ids = mock_permitted_multichain_connections(mocker=module_mocker)
@ -433,7 +424,7 @@ def multichain_ids(module_mocker):
@pytest.fixture(scope="module")
def multichain_ursulas(ursulas, multichain_ids, mock_rpc_condition):
def multichain_ursulas(ursulas, multichain_ids):
setup_multichain_ursulas(ursulas=ursulas, chain_ids=multichain_ids)
return ursulas

View File

@ -0,0 +1,33 @@
import "@openzeppelin/contracts/access/Ownable.sol";
import "@openzeppelin/contracts/interfaces/IERC1271.sol";
import "@openzeppelin/contracts/utils/cryptography/ECDSA.sol";
contract SmartContractWallet is IERC1271, Ownable {
using ECDSA for bytes32;
uint public balance;
bytes4 internal constant MAGICVALUE = 0x1626ba7e;
bytes4 constant internal INVALID_SIGNATURE = 0xffffffff;
constructor(address _owner) Ownable(_owner) public {}
function deposit() external payable {
balance += msg.value;
}
function withdraw(uint amount) external onlyOwner {
require(amount <= balance, "Amount exceeds balance");
balance -= amount;
payable(owner()).transfer(amount);
}
function isValidSignature(bytes32 _hash, bytes memory _signature) public view override returns (bytes4) {
address signer = _hash.recover(_signature);
if (signer == owner()) {
return MAGICVALUE;
} else {
return INVALID_SIGNATURE;
}
}
}

View File

@ -18,5 +18,5 @@ def test_condition_chains_endpoint_multichain(
):
response = client.get("/condition_chains")
assert response.status_code == 200
expected_payload = {"version": 1.0, "evm": multichain_ids}
expected_payload = {"version": 1.0, "evm": sorted(multichain_ids)}
assert response.get_json() == expected_payload

View File

@ -17,6 +17,7 @@ from nucypher.policy.conditions.exceptions import (
InvalidConditionLingo,
)
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
from nucypher.policy.conditions.utils import ConditionProviderManager
from tests.constants import TESTERCHAIN_CHAIN_ID
CHAIN_ID = 137
@ -52,7 +53,7 @@ class FakeExecutionContractCondition(ContractCondition):
def set_execution_return_value(self, value: Any):
self.execution_return_value = value
def execute(self, providers: Dict, **context) -> Any:
def execute(self, providers: ConditionProviderManager, **context) -> Any:
return self.execution_return_value
EXECUTION_CALL_TYPE = FakeRPCCall
@ -125,7 +126,7 @@ def _check_execution_logic(
json.dumps(condition_dict)
)
fake_execution_contract_condition.set_execution_return_value(execution_result)
fake_providers = {CHAIN_ID: {Mock(BaseProvider)}}
fake_providers = ConditionProviderManager({CHAIN_ID: {Mock(BaseProvider)}})
condition_result, call_result = fake_execution_contract_condition.verify(
fake_providers, **context
)

View File

@ -1,16 +1,32 @@
import maya
import pytest
from eth_account import Account
from eth_account.messages import defunct_hash_message
from hexbytes import HexBytes
from siwe import SiweMessage
from web3.contract import Contract
from nucypher.blockchain.eth.signers import InMemorySigner
from nucypher.policy.conditions.auth.evm import EIP712Auth, EIP4361Auth, EvmAuth
from nucypher.policy.conditions.auth.evm import (
EIP712Auth,
EIP1271Auth,
EIP4361Auth,
EvmAuth,
)
from nucypher.policy.conditions.exceptions import NoConnectionToChain
from nucypher.policy.conditions.utils import ConditionProviderManager
from tests.constants import TESTERCHAIN_CHAIN_ID
def test_auth_scheme():
expected_schemes = {
EvmAuth.AuthScheme.EIP712: EIP712Auth,
EvmAuth.AuthScheme.EIP4361: EIP4361Auth,
EvmAuth.AuthScheme.EIP1271: EIP1271Auth,
}
for scheme in EvmAuth.AuthScheme:
expected_scheme = (
EIP712Auth if scheme == EvmAuth.AuthScheme.EIP712 else EIP4361Auth
)
expected_scheme = expected_schemes.get(scheme)
assert EvmAuth.from_scheme(scheme=scheme.value) == expected_scheme
# non-existent scheme
@ -279,3 +295,119 @@ def test_authenticate_eip4361(get_random_checksum_address):
not_stale_but_past_expiry_signature.hex(),
valid_address_for_signature,
)
def test_authenticate_eip1271(mocker, get_random_checksum_address):
# smart contract wallet
eip1271_mock_contract = mocker.Mock(spec=Contract)
contract_address = get_random_checksum_address()
eip1271_mock_contract.address = contract_address
# signer for wallet
data = f"I'm the owner of the smart contract wallet address {eip1271_mock_contract.address}"
wallet_signer = InMemorySigner()
valid_message_signature = wallet_signer.sign_message(
account=wallet_signer.accounts[0], message=data.encode()
)
data_hash = defunct_hash_message(text=data)
typedData = {"chain": TESTERCHAIN_CHAIN_ID, "dataHash": data_hash.hex()}
def _isValidSignature(data_hash, signature_bytes):
class ContractCall:
def __init__(self, hash, signature):
self.hash = hash
self.signature = signature
def call(self):
recovered_address = Account._recover_hash(
message_hash=self.hash, signature=self.signature
)
if recovered_address == wallet_signer.accounts[0]:
return bytes(HexBytes("0x1626ba7e"))
return bytes(HexBytes("0xffffffff"))
return ContractCall(data_hash, signature_bytes)
eip1271_mock_contract.functions.isValidSignature.side_effect = _isValidSignature
# condition provider manager
providers = mocker.Mock(spec=ConditionProviderManager)
w3 = mocker.Mock()
w3.eth.contract.return_value = eip1271_mock_contract
providers.web3_endpoints.return_value = [w3]
# valid signature
EIP1271Auth.authenticate(
typedData, valid_message_signature, eip1271_mock_contract.address, providers
)
w3.eth.contract.assert_called_once_with(
address=eip1271_mock_contract.address, abi=EIP1271Auth.EIP1271_ABI
)
# no providers
with pytest.raises(EvmAuth.AuthenticationFailed, match="no endpoints provided"):
EIP1271Auth.authenticate(
typedData, valid_message_signature, eip1271_mock_contract.address, None
)
# invalid typed data - no chain id
with pytest.raises(EvmAuth.InvalidData):
EIP1271Auth.authenticate(
{
"dataHash": data_hash.hex(),
},
valid_message_signature,
eip1271_mock_contract.address,
providers,
)
# invalid typed data - no data hash
with pytest.raises(EvmAuth.InvalidData):
EIP1271Auth.authenticate(
{
"chainId": TESTERCHAIN_CHAIN_ID,
},
valid_message_signature,
eip1271_mock_contract.address,
providers,
)
# use invalid signer
invalid_signer = InMemorySigner()
invalid_message_signature = invalid_signer.sign_message(
account=invalid_signer.accounts[0], message=data.encode()
)
with pytest.raises(EvmAuth.AuthenticationFailed):
EIP1271Auth.authenticate(
typedData,
invalid_message_signature,
eip1271_mock_contract.address,
providers,
)
# bad w3 instance failed for some reason
w3_bad = mocker.Mock()
w3_bad.eth.contract.side_effect = ValueError("something went wrong")
providers.web3_endpoints.return_value = [w3_bad]
with pytest.raises(EvmAuth.AuthenticationFailed, match="something went wrong"):
EIP1271Auth.authenticate(
typedData, valid_message_signature, eip1271_mock_contract.address, providers
)
assert w3_bad.eth.contract.call_count == 1, "one call that failed"
# fall back to good w3 instances
providers.web3_endpoints.return_value = [w3_bad, w3_bad, w3]
EIP1271Auth.authenticate(
typedData, valid_message_signature, eip1271_mock_contract.address, providers
)
assert w3_bad.eth.contract.call_count == 3, "two more calls that failed"
# no connection to chain
providers.web3_endpoints.side_effect = NoConnectionToChain(
chain=TESTERCHAIN_CHAIN_ID
)
with pytest.raises(EvmAuth.AuthenticationFailed, match="No connection to chain ID"):
EIP1271Auth.authenticate(
typedData, valid_message_signature, eip1271_mock_contract.address, providers
)

View File

@ -12,6 +12,7 @@ from nucypher.policy.conditions.lingo import (
OrCompoundCondition,
SequentialAccessControlCondition,
)
from nucypher.policy.conditions.utils import ConditionProviderManager
@pytest.fixture(scope="function")
@ -248,7 +249,9 @@ def test_nested_multi_conditions(mock_conditions):
else_condition=False,
)
result, value = if_then_else_condition.verify(providers={})
result, value = if_then_else_condition.verify(
providers=ConditionProviderManager({})
)
assert result is True
assert value == [[1, 2], [2, 3]] # [[or result], [seq result]]
@ -277,7 +280,9 @@ def test_nested_multi_conditions(mock_conditions):
),
)
result, value = if_then_else_condition.verify(providers={})
result, value = if_then_else_condition.verify(
providers=ConditionProviderManager({})
)
assert result is False
assert value == [[1, 2], [3, 2]] # [[or result], [else if condition result]]

View File

@ -11,6 +11,7 @@ from nucypher.policy.conditions.lingo import (
OrCompoundCondition,
SequentialAccessControlCondition,
)
from nucypher.policy.conditions.utils import ConditionProviderManager
@pytest.fixture(scope="function")
@ -173,7 +174,9 @@ def test_sequential_condition(mock_condition_variables):
)
original_context = dict()
result, value = sequential_condition.verify(providers={}, **original_context)
result, value = sequential_condition.verify(
providers=ConditionProviderManager({}), **original_context
)
assert result is True
assert value == [1, 1 * 2, 1 * 2 * 3, 1 * 2 * 3 * 4]
# only a copy of the context is modified internally
@ -215,7 +218,9 @@ def test_sequential_condition_all_prior_vars_passed_to_subsequent_calls(
expected_var_3_value = expected_var_1_value + expected_var_2_value + 1
original_context = dict()
result, value = sequential_condition.verify(providers={}, **original_context)
result, value = sequential_condition.verify(
providers=ConditionProviderManager({}), **original_context
)
assert result is True
assert value == [
expected_var_1_value,
@ -238,4 +243,4 @@ def test_sequential_condition_a_call_fails(mock_condition_variables):
)
with pytest.raises(Web3Exception):
_ = sequential_condition.verify(providers={})
_ = sequential_condition.verify(providers=ConditionProviderManager({}))

View File

@ -37,6 +37,7 @@ from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.conditions.utils import (
CamelCaseSchema,
ConditionEvalError,
ConditionProviderManager,
camel_case_to_snake,
evaluate_condition_lingo,
to_camelcase,
@ -102,7 +103,9 @@ def test_evaluate_condition_eval_returns_false():
with pytest.raises(ConditionEvalError) as eval_error:
evaluate_condition_lingo(
condition_lingo=condition_lingo,
providers={1: Mock(spec=BaseProvider)}, # fake provider
providers=ConditionProviderManager(
{1: Mock(spec=BaseProvider)}
), # fake provider
context={"key": "value"}, # fake context
)
assert eval_error.value.status_code == HTTPStatus.FORBIDDEN
@ -119,10 +122,12 @@ def test_evaluate_condition_eval_returns_true():
evaluate_condition_lingo(
condition_lingo=condition_lingo,
providers={
1: Mock(spec=BaseProvider),
2: Mock(spec=BaseProvider),
}, # multiple fake provider
providers=ConditionProviderManager(
{
1: Mock(spec=BaseProvider),
2: Mock(spec=BaseProvider),
}
),
context={
"key1": "value1",
"key2": "value2",
@ -166,3 +171,48 @@ def test_camel_case_schema():
reloaded_function = schema.load(output)
assert reloaded_function == {"field_name_with_underscores": f"{value}"}
def test_condition_provider_manager(mocker):
# no condition to chain
with pytest.raises(NoConnectionToChain, match="No connection to chain ID"):
manager = ConditionProviderManager(
providers={2: [mocker.Mock(spec=BaseProvider)]}
)
_ = list(manager.web3_endpoints(chain_id=1))
# invalid provider chain
manager = ConditionProviderManager(providers={2: [mocker.Mock(spec=BaseProvider)]})
w3 = mocker.Mock()
w3.eth.chain_id = (
1 # make w3 instance created from provider have incorrect chain id
)
with patch.object(manager, "_configure_w3", return_value=w3):
with pytest.raises(
NoConnectionToChain, match="Problematic provider endpoints for chain ID"
):
_ = list(manager.web3_endpoints(chain_id=2))
# valid provider chain
manager = ConditionProviderManager(providers={2: [mocker.Mock(spec=BaseProvider)]})
with patch.object(manager, "_check_chain_id", return_value=None):
assert len(list(manager.web3_endpoints(chain_id=2))) == 1
# multiple providers
manager = ConditionProviderManager(
providers={2: [mocker.Mock(spec=BaseProvider), mocker.Mock(spec=BaseProvider)]}
)
with patch.object(manager, "_check_chain_id", return_value=None):
w3_instances = list(manager.web3_endpoints(chain_id=2))
assert len(w3_instances) == 2
for w3_instance in w3_instances:
assert w3_instance # actual object returned
assert w3_instance.middleware_onion.get("poa") # poa middleware injected
# specific w3 instances
w3_1 = mocker.Mock()
w3_1.eth.chain_id = 2
w3_2 = mocker.Mock()
w3_2.eth.chain_id = 2
with patch.object(manager, "_configure_w3", side_effect=[w3_1, w3_2]):
assert list(manager.web3_endpoints(chain_id=2)) == [w3_1, w3_2]

View File

@ -11,6 +11,7 @@ from web3 import HTTPProvider
from nucypher.blockchain.eth.signers import InMemorySigner, Signer
from nucypher.characters.lawful import Ursula
from nucypher.config.characters import UrsulaConfiguration
from nucypher.policy.conditions.utils import ConditionProviderManager
from tests.constants import TESTERCHAIN_CHAIN_ID
from tests.utils.blockchain import ReservedTestAccountManager
@ -176,14 +177,16 @@ def setup_multichain_ursulas(chain_ids: List[int], ursulas: List[Ursula]) -> Non
fallback_blockchain_endpoints = [
base_fallback_uri.format(i) for i in range(len(chain_ids))
]
mocked_condition_providers = {
cid: {HTTPProvider(uri), HTTPProvider(furi)}
for cid, uri, furi in zip(
chain_ids, blockchain_endpoints, fallback_blockchain_endpoints
)
}
mocked_condition_providers = ConditionProviderManager(
{
cid: [HTTPProvider(uri), HTTPProvider(furi)]
for cid, uri, furi in zip(
chain_ids, blockchain_endpoints, fallback_blockchain_endpoints
)
}
)
for ursula in ursulas:
ursula.condition_providers = mocked_condition_providers
ursula.condition_provider_manager = mocked_condition_providers
MOCK_KNOWN_URSULAS_CACHE = dict()