mirror of https://github.com/nucypher/nucypher.git
Adjust tests to use ConditionProviderManager instead of the raw dictionary of providers.
Also adjust mocking since RPCCall no longer does the configuration of w3 instances; instead the ConditionProviderManager does it.pull/3576/head
parent
24d0669940
commit
dcd870d49e
|
@ -11,12 +11,15 @@ from nucypher.policy.conditions.lingo import (
|
|||
OrCompoundCondition,
|
||||
ReturnValueTest,
|
||||
)
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
from tests.constants import TEST_ETH_PROVIDER_URI, TESTERCHAIN_CHAIN_ID
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def condition_providers(testerchain):
|
||||
providers = {testerchain.client.chain_id: {testerchain.provider}}
|
||||
providers = ConditionProviderManager(
|
||||
{testerchain.client.chain_id: {testerchain.provider}}
|
||||
)
|
||||
return providers
|
||||
|
||||
@pytest.fixture()
|
||||
|
|
|
@ -33,6 +33,7 @@ from nucypher.policy.conditions.lingo import (
|
|||
NotCompoundCondition,
|
||||
ReturnValueTest,
|
||||
)
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
from tests.constants import (
|
||||
TEST_ETH_PROVIDER_URI,
|
||||
TEST_POLYGON_PROVIDER_URI,
|
||||
|
@ -67,11 +68,12 @@ def test_rpc_condition_evaluation_no_providers(
|
|||
):
|
||||
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
|
||||
with pytest.raises(NoConnectionToChain):
|
||||
_ = rpc_condition.verify(providers={}, **context)
|
||||
_ = rpc_condition.verify(providers=ConditionProviderManager({}), **context)
|
||||
|
||||
with pytest.raises(NoConnectionToChain):
|
||||
_ = rpc_condition.verify(
|
||||
providers={testerchain.client.chain_id: set()}, **context
|
||||
providers=ConditionProviderManager({testerchain.client.chain_id: list()}),
|
||||
**context,
|
||||
)
|
||||
|
||||
|
||||
|
@ -85,9 +87,10 @@ def test_rpc_condition_evaluation_invalid_provider_for_chain(
|
|||
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
|
||||
new_chain = 23
|
||||
rpc_condition.execution_call.chain = new_chain
|
||||
condition_providers = {new_chain: {testerchain.provider}}
|
||||
condition_providers = ConditionProviderManager({new_chain: [testerchain.provider]})
|
||||
with pytest.raises(
|
||||
NoConnectionToChain, match=f"can only be evaluated on chain ID {new_chain}"
|
||||
NoConnectionToChain,
|
||||
match=f"Problematic provider connections for chain ID {new_chain}",
|
||||
):
|
||||
_ = rpc_condition.verify(providers=condition_providers, **context)
|
||||
|
||||
|
@ -118,13 +121,15 @@ def test_rpc_condition_evaluation_multiple_chain_providers(
|
|||
):
|
||||
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
|
||||
|
||||
condition_providers = {
|
||||
"1": {"fake1a", "fake1b"},
|
||||
"2": {"fake2"},
|
||||
"3": {"fake3"},
|
||||
"4": {"fake4"},
|
||||
TESTERCHAIN_CHAIN_ID: {testerchain.provider},
|
||||
}
|
||||
condition_providers = ConditionProviderManager(
|
||||
{
|
||||
"1": ["fake1a", "fake1b"],
|
||||
"2": ["fake2"],
|
||||
"3": ["fake3"],
|
||||
"4": ["fake4"],
|
||||
TESTERCHAIN_CHAIN_ID: [testerchain.provider],
|
||||
}
|
||||
)
|
||||
|
||||
condition_result, call_result = rpc_condition.verify(
|
||||
providers=condition_providers, **context
|
||||
|
@ -144,20 +149,17 @@ def test_rpc_condition_evaluation_multiple_providers_no_valid_fallback(
|
|||
):
|
||||
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
|
||||
|
||||
def my_configure_w3(provider: BaseProvider):
|
||||
return Web3(provider)
|
||||
|
||||
condition_providers = {
|
||||
TESTERCHAIN_CHAIN_ID: {
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
condition_providers = ConditionProviderManager(
|
||||
{
|
||||
TESTERCHAIN_CHAIN_ID: [
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
mocker.patch.object(
|
||||
rpc_condition.execution_call, "_configure_provider", my_configure_w3
|
||||
)
|
||||
|
||||
mocker.patch.object(condition_providers, "_check_chain_id", return_value=None)
|
||||
with pytest.raises(RPCExecutionFailed):
|
||||
_ = rpc_condition.verify(providers=condition_providers, **context)
|
||||
|
||||
|
@ -171,22 +173,19 @@ def test_rpc_condition_evaluation_multiple_providers_valid_fallback(
|
|||
):
|
||||
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
|
||||
|
||||
def my_configure_w3(provider: BaseProvider):
|
||||
return Web3(provider)
|
||||
|
||||
condition_providers = {
|
||||
TESTERCHAIN_CHAIN_ID: {
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
testerchain.provider,
|
||||
condition_providers = ConditionProviderManager(
|
||||
{
|
||||
TESTERCHAIN_CHAIN_ID: [
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
mocker.Mock(spec=BaseProvider),
|
||||
testerchain.provider,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
mocker.patch.object(
|
||||
rpc_condition.execution_call, "_configure_provider", my_configure_w3
|
||||
)
|
||||
|
||||
mocker.patch.object(condition_providers, "_check_chain_id", return_value=None)
|
||||
|
||||
condition_result, call_result = rpc_condition.verify(
|
||||
providers=condition_providers, **context
|
||||
)
|
||||
|
@ -208,10 +207,12 @@ def test_rpc_condition_evaluation_no_connection_to_chain(
|
|||
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
|
||||
|
||||
# condition providers for other unrelated chains
|
||||
providers = {
|
||||
1: mock.Mock(), # mainnet
|
||||
11155111: mock.Mock(), # Sepolia
|
||||
}
|
||||
providers = ConditionProviderManager(
|
||||
{
|
||||
1: [mock.Mock()], # mainnet
|
||||
11155111: [mock.Mock()], # Sepolia
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(NoConnectionToChain):
|
||||
rpc_condition.verify(providers=providers, **context)
|
||||
|
@ -250,7 +251,10 @@ def test_rpc_condition_evaluation_with_context_var_in_return_value_test(
|
|||
invalid_balance = balance + 1
|
||||
context[":balanceContextVar"] = invalid_balance
|
||||
condition_result, call_result = rpc_condition.verify(
|
||||
providers={testerchain.client.chain_id: [testerchain.provider]}, **context
|
||||
providers=ConditionProviderManager(
|
||||
{testerchain.client.chain_id: [testerchain.provider]}
|
||||
),
|
||||
**context,
|
||||
)
|
||||
assert condition_result is False
|
||||
assert call_result != invalid_balance
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from collections import defaultdict
|
||||
|
||||
import pytest
|
||||
from web3 import Web3
|
||||
|
||||
from nucypher.policy.conditions.evm import RPCCall, RPCCondition
|
||||
from nucypher.policy.conditions.lingo import (
|
||||
|
@ -10,7 +9,8 @@ from nucypher.policy.conditions.lingo import (
|
|||
ConditionType,
|
||||
ReturnValueTest,
|
||||
)
|
||||
from nucypher.policy.conditions.time import TimeCondition, TimeRPCCall
|
||||
from nucypher.policy.conditions.time import TimeCondition
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
from nucypher.utilities.logging import GlobalLoggerSettings
|
||||
from tests.utils.policy import make_message_kits
|
||||
|
||||
|
@ -62,7 +62,7 @@ def conditions(bob, multichain_ids):
|
|||
|
||||
|
||||
def test_single_retrieve_with_multichain_conditions(
|
||||
enacted_policy, bob, multichain_ursulas, conditions, mock_rpc_condition
|
||||
enacted_policy, bob, multichain_ursulas, conditions, monkeymodule, testerchain
|
||||
):
|
||||
bob.remember_node(multichain_ursulas[0])
|
||||
bob.start_learning_loop()
|
||||
|
@ -72,6 +72,11 @@ def test_single_retrieve_with_multichain_conditions(
|
|||
encrypted_treasure_map=enacted_policy.treasure_map,
|
||||
alice_verifying_key=enacted_policy.publisher_verifying_key,
|
||||
)
|
||||
monkeymodule.setattr(
|
||||
ConditionProviderManager,
|
||||
"web3_endpoints",
|
||||
lambda *args, **kwargs: [testerchain.w3],
|
||||
)
|
||||
|
||||
cleartexts = bob.retrieve_and_decrypt(
|
||||
message_kits=message_kits,
|
||||
|
@ -93,43 +98,30 @@ def test_single_decryption_request_with_faulty_rpc_endpoint(
|
|||
alice_verifying_key=enacted_policy.publisher_verifying_key,
|
||||
)
|
||||
|
||||
def _mock_configure_provider(*args, **kwargs):
|
||||
rpc_call_type = args[0]
|
||||
if isinstance(rpc_call_type, TimeRPCCall):
|
||||
# time condition call - only RPCCall is made faulty
|
||||
return testerchain.w3
|
||||
monkeymodule.setattr(
|
||||
ConditionProviderManager,
|
||||
"web3_endpoints",
|
||||
lambda *args, **kwargs: [testerchain.w3, testerchain.w3],
|
||||
) # a base, and fallback
|
||||
|
||||
# rpc condition call
|
||||
provider = args[1]
|
||||
w3 = Web3(provider)
|
||||
return w3
|
||||
|
||||
monkeymodule.setattr(RPCCall, "_configure_provider", _mock_configure_provider)
|
||||
|
||||
calls = defaultdict(int)
|
||||
rpc_calls = defaultdict(int)
|
||||
original_execute_call = RPCCall._execute
|
||||
|
||||
def faulty_execute_call(*args, **kwargs):
|
||||
def faulty_rpc_execute_call(*args, **kwargs):
|
||||
"""Intercept the call to the RPC endpoint and raise an exception on the second call."""
|
||||
nonlocal calls
|
||||
nonlocal rpc_calls
|
||||
rpc_call_object = args[0]
|
||||
resolved_parameters = args[2]
|
||||
calls[rpc_call_object.chain] += 1
|
||||
if calls[rpc_call_object.chain] % 2 == 0:
|
||||
rpc_calls[rpc_call_object.chain] += 1
|
||||
if rpc_calls[rpc_call_object.chain] % 2 == 0:
|
||||
# simulate a network error
|
||||
raise ConnectionError("Something went wrong with the network")
|
||||
|
||||
# replace w3 object with fake provider, with proper w3 object for actual execution
|
||||
return original_execute_call(
|
||||
rpc_call_object, testerchain.w3, resolved_parameters
|
||||
)
|
||||
|
||||
RPCCall._execute = faulty_execute_call
|
||||
# make original call
|
||||
return original_execute_call(*args, **kwargs)
|
||||
|
||||
monkeymodule.setattr(RPCCall, "_execute", faulty_rpc_execute_call)
|
||||
cleartexts = bob.retrieve_and_decrypt(
|
||||
message_kits=message_kits,
|
||||
**policy_info_kwargs,
|
||||
)
|
||||
assert cleartexts == messages
|
||||
|
||||
RPCCall._execute = original_execute_call
|
||||
|
|
|
@ -14,7 +14,6 @@ from nucypher.blockchain.eth.agents import (
|
|||
)
|
||||
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
|
||||
from nucypher.blockchain.eth.registry import ContractRegistry, RegistrySourceManager
|
||||
from nucypher.policy.conditions.evm import RPCCall
|
||||
from nucypher.utilities.logging import Logger
|
||||
from tests.constants import (
|
||||
BONUS_TOKENS_FOR_TESTS,
|
||||
|
@ -418,14 +417,6 @@ def taco_child_application_agent(testerchain, test_registry):
|
|||
# Conditions
|
||||
#
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mock_rpc_condition(testerchain, monkeymodule):
|
||||
def configure_mock(*args, **kwargs):
|
||||
return testerchain.w3
|
||||
|
||||
monkeymodule.setattr(RPCCall, "_configure_provider", configure_mock)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def multichain_ids(module_mocker):
|
||||
ids = mock_permitted_multichain_connections(mocker=module_mocker)
|
||||
|
@ -433,7 +424,7 @@ def multichain_ids(module_mocker):
|
|||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def multichain_ursulas(ursulas, multichain_ids, mock_rpc_condition):
|
||||
def multichain_ursulas(ursulas, multichain_ids):
|
||||
setup_multichain_ursulas(ursulas=ursulas, chain_ids=multichain_ids)
|
||||
return ursulas
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ from nucypher.policy.conditions.exceptions import (
|
|||
InvalidConditionLingo,
|
||||
)
|
||||
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
from tests.constants import TESTERCHAIN_CHAIN_ID
|
||||
|
||||
CHAIN_ID = 137
|
||||
|
@ -52,7 +53,7 @@ class FakeExecutionContractCondition(ContractCondition):
|
|||
def set_execution_return_value(self, value: Any):
|
||||
self.execution_return_value = value
|
||||
|
||||
def execute(self, providers: Dict, **context) -> Any:
|
||||
def execute(self, providers: ConditionProviderManager, **context) -> Any:
|
||||
return self.execution_return_value
|
||||
|
||||
EXECUTION_CALL_TYPE = FakeRPCCall
|
||||
|
@ -125,7 +126,7 @@ def _check_execution_logic(
|
|||
json.dumps(condition_dict)
|
||||
)
|
||||
fake_execution_contract_condition.set_execution_return_value(execution_result)
|
||||
fake_providers = {CHAIN_ID: {Mock(BaseProvider)}}
|
||||
fake_providers = ConditionProviderManager({CHAIN_ID: {Mock(BaseProvider)}})
|
||||
condition_result, call_result = fake_execution_contract_condition.verify(
|
||||
fake_providers, **context
|
||||
)
|
||||
|
|
|
@ -12,6 +12,7 @@ from nucypher.policy.conditions.lingo import (
|
|||
OrCompoundCondition,
|
||||
SequentialAccessControlCondition,
|
||||
)
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
@ -248,7 +249,9 @@ def test_nested_multi_conditions(mock_conditions):
|
|||
else_condition=False,
|
||||
)
|
||||
|
||||
result, value = if_then_else_condition.verify(providers={})
|
||||
result, value = if_then_else_condition.verify(
|
||||
providers=ConditionProviderManager({})
|
||||
)
|
||||
assert result is True
|
||||
assert value == [[1, 2], [2, 3]] # [[or result], [seq result]]
|
||||
|
||||
|
@ -277,7 +280,9 @@ def test_nested_multi_conditions(mock_conditions):
|
|||
),
|
||||
)
|
||||
|
||||
result, value = if_then_else_condition.verify(providers={})
|
||||
result, value = if_then_else_condition.verify(
|
||||
providers=ConditionProviderManager({})
|
||||
)
|
||||
assert result is False
|
||||
assert value == [[1, 2], [3, 2]] # [[or result], [else if condition result]]
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ from nucypher.policy.conditions.lingo import (
|
|||
OrCompoundCondition,
|
||||
SequentialAccessControlCondition,
|
||||
)
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
@ -173,7 +174,9 @@ def test_sequential_condition(mock_condition_variables):
|
|||
)
|
||||
|
||||
original_context = dict()
|
||||
result, value = sequential_condition.verify(providers={}, **original_context)
|
||||
result, value = sequential_condition.verify(
|
||||
providers=ConditionProviderManager({}), **original_context
|
||||
)
|
||||
assert result is True
|
||||
assert value == [1, 1 * 2, 1 * 2 * 3, 1 * 2 * 3 * 4]
|
||||
# only a copy of the context is modified internally
|
||||
|
@ -215,7 +218,9 @@ def test_sequential_condition_all_prior_vars_passed_to_subsequent_calls(
|
|||
expected_var_3_value = expected_var_1_value + expected_var_2_value + 1
|
||||
|
||||
original_context = dict()
|
||||
result, value = sequential_condition.verify(providers={}, **original_context)
|
||||
result, value = sequential_condition.verify(
|
||||
providers=ConditionProviderManager({}), **original_context
|
||||
)
|
||||
assert result is True
|
||||
assert value == [
|
||||
expected_var_1_value,
|
||||
|
@ -238,4 +243,4 @@ def test_sequential_condition_a_call_fails(mock_condition_variables):
|
|||
)
|
||||
|
||||
with pytest.raises(Web3Exception):
|
||||
_ = sequential_condition.verify(providers={})
|
||||
_ = sequential_condition.verify(providers=ConditionProviderManager({}))
|
||||
|
|
|
@ -37,6 +37,7 @@ from nucypher.policy.conditions.lingo import ConditionLingo
|
|||
from nucypher.policy.conditions.utils import (
|
||||
CamelCaseSchema,
|
||||
ConditionEvalError,
|
||||
ConditionProviderManager,
|
||||
camel_case_to_snake,
|
||||
evaluate_condition_lingo,
|
||||
to_camelcase,
|
||||
|
@ -102,7 +103,9 @@ def test_evaluate_condition_eval_returns_false():
|
|||
with pytest.raises(ConditionEvalError) as eval_error:
|
||||
evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo,
|
||||
providers={1: Mock(spec=BaseProvider)}, # fake provider
|
||||
providers=ConditionProviderManager(
|
||||
{1: Mock(spec=BaseProvider)}
|
||||
), # fake provider
|
||||
context={"key": "value"}, # fake context
|
||||
)
|
||||
assert eval_error.value.status_code == HTTPStatus.FORBIDDEN
|
||||
|
@ -119,10 +122,12 @@ def test_evaluate_condition_eval_returns_true():
|
|||
|
||||
evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo,
|
||||
providers={
|
||||
1: Mock(spec=BaseProvider),
|
||||
2: Mock(spec=BaseProvider),
|
||||
}, # multiple fake provider
|
||||
providers=ConditionProviderManager(
|
||||
{
|
||||
1: Mock(spec=BaseProvider),
|
||||
2: Mock(spec=BaseProvider),
|
||||
}
|
||||
),
|
||||
context={
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
|
|
|
@ -11,6 +11,7 @@ from web3 import HTTPProvider
|
|||
from nucypher.blockchain.eth.signers import InMemorySigner, Signer
|
||||
from nucypher.characters.lawful import Ursula
|
||||
from nucypher.config.characters import UrsulaConfiguration
|
||||
from nucypher.policy.conditions.utils import ConditionProviderManager
|
||||
from tests.constants import TESTERCHAIN_CHAIN_ID
|
||||
from tests.utils.blockchain import ReservedTestAccountManager
|
||||
|
||||
|
@ -176,12 +177,14 @@ def setup_multichain_ursulas(chain_ids: List[int], ursulas: List[Ursula]) -> Non
|
|||
fallback_blockchain_endpoints = [
|
||||
base_fallback_uri.format(i) for i in range(len(chain_ids))
|
||||
]
|
||||
mocked_condition_providers = {
|
||||
cid: {HTTPProvider(uri), HTTPProvider(furi)}
|
||||
for cid, uri, furi in zip(
|
||||
chain_ids, blockchain_endpoints, fallback_blockchain_endpoints
|
||||
)
|
||||
}
|
||||
mocked_condition_providers = ConditionProviderManager(
|
||||
{
|
||||
cid: [HTTPProvider(uri), HTTPProvider(furi)]
|
||||
for cid, uri, furi in zip(
|
||||
chain_ids, blockchain_endpoints, fallback_blockchain_endpoints
|
||||
)
|
||||
}
|
||||
)
|
||||
for ursula in ursulas:
|
||||
ursula.condition_providers = mocked_condition_providers
|
||||
|
||||
|
|
Loading…
Reference in New Issue