Adjust tests to use ConditionProviderManager instead of the raw dictionary of providers.

Also adjust mocking since RPCCall no longer does the configuration of w3 instances; instead the ConditionProviderManager does it.
pull/3576/head
derekpierre 2025-01-24 14:37:16 -05:00
parent 24d0669940
commit dcd870d49e
No known key found for this signature in database
9 changed files with 108 additions and 99 deletions

View File

@ -11,12 +11,15 @@ from nucypher.policy.conditions.lingo import (
OrCompoundCondition,
ReturnValueTest,
)
from nucypher.policy.conditions.utils import ConditionProviderManager
from tests.constants import TEST_ETH_PROVIDER_URI, TESTERCHAIN_CHAIN_ID
@pytest.fixture()
def condition_providers(testerchain):
providers = {testerchain.client.chain_id: {testerchain.provider}}
providers = ConditionProviderManager(
{testerchain.client.chain_id: {testerchain.provider}}
)
return providers
@pytest.fixture()

View File

@ -33,6 +33,7 @@ from nucypher.policy.conditions.lingo import (
NotCompoundCondition,
ReturnValueTest,
)
from nucypher.policy.conditions.utils import ConditionProviderManager
from tests.constants import (
TEST_ETH_PROVIDER_URI,
TEST_POLYGON_PROVIDER_URI,
@ -67,11 +68,12 @@ def test_rpc_condition_evaluation_no_providers(
):
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
with pytest.raises(NoConnectionToChain):
_ = rpc_condition.verify(providers={}, **context)
_ = rpc_condition.verify(providers=ConditionProviderManager({}), **context)
with pytest.raises(NoConnectionToChain):
_ = rpc_condition.verify(
providers={testerchain.client.chain_id: set()}, **context
providers=ConditionProviderManager({testerchain.client.chain_id: list()}),
**context,
)
@ -85,9 +87,10 @@ def test_rpc_condition_evaluation_invalid_provider_for_chain(
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
new_chain = 23
rpc_condition.execution_call.chain = new_chain
condition_providers = {new_chain: {testerchain.provider}}
condition_providers = ConditionProviderManager({new_chain: [testerchain.provider]})
with pytest.raises(
NoConnectionToChain, match=f"can only be evaluated on chain ID {new_chain}"
NoConnectionToChain,
match=f"Problematic provider connections for chain ID {new_chain}",
):
_ = rpc_condition.verify(providers=condition_providers, **context)
@ -118,13 +121,15 @@ def test_rpc_condition_evaluation_multiple_chain_providers(
):
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
condition_providers = {
"1": {"fake1a", "fake1b"},
"2": {"fake2"},
"3": {"fake3"},
"4": {"fake4"},
TESTERCHAIN_CHAIN_ID: {testerchain.provider},
}
condition_providers = ConditionProviderManager(
{
"1": ["fake1a", "fake1b"],
"2": ["fake2"],
"3": ["fake3"],
"4": ["fake4"],
TESTERCHAIN_CHAIN_ID: [testerchain.provider],
}
)
condition_result, call_result = rpc_condition.verify(
providers=condition_providers, **context
@ -144,20 +149,17 @@ def test_rpc_condition_evaluation_multiple_providers_no_valid_fallback(
):
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
def my_configure_w3(provider: BaseProvider):
return Web3(provider)
condition_providers = {
TESTERCHAIN_CHAIN_ID: {
mocker.Mock(spec=BaseProvider),
mocker.Mock(spec=BaseProvider),
mocker.Mock(spec=BaseProvider),
condition_providers = ConditionProviderManager(
{
TESTERCHAIN_CHAIN_ID: [
mocker.Mock(spec=BaseProvider),
mocker.Mock(spec=BaseProvider),
mocker.Mock(spec=BaseProvider),
]
}
}
mocker.patch.object(
rpc_condition.execution_call, "_configure_provider", my_configure_w3
)
mocker.patch.object(condition_providers, "_check_chain_id", return_value=None)
with pytest.raises(RPCExecutionFailed):
_ = rpc_condition.verify(providers=condition_providers, **context)
@ -171,22 +173,19 @@ def test_rpc_condition_evaluation_multiple_providers_valid_fallback(
):
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
def my_configure_w3(provider: BaseProvider):
return Web3(provider)
condition_providers = {
TESTERCHAIN_CHAIN_ID: {
mocker.Mock(spec=BaseProvider),
mocker.Mock(spec=BaseProvider),
mocker.Mock(spec=BaseProvider),
testerchain.provider,
condition_providers = ConditionProviderManager(
{
TESTERCHAIN_CHAIN_ID: [
mocker.Mock(spec=BaseProvider),
mocker.Mock(spec=BaseProvider),
mocker.Mock(spec=BaseProvider),
testerchain.provider,
]
}
}
mocker.patch.object(
rpc_condition.execution_call, "_configure_provider", my_configure_w3
)
mocker.patch.object(condition_providers, "_check_chain_id", return_value=None)
condition_result, call_result = rpc_condition.verify(
providers=condition_providers, **context
)
@ -208,10 +207,12 @@ def test_rpc_condition_evaluation_no_connection_to_chain(
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
# condition providers for other unrelated chains
providers = {
1: mock.Mock(), # mainnet
11155111: mock.Mock(), # Sepolia
}
providers = ConditionProviderManager(
{
1: [mock.Mock()], # mainnet
11155111: [mock.Mock()], # Sepolia
}
)
with pytest.raises(NoConnectionToChain):
rpc_condition.verify(providers=providers, **context)
@ -250,7 +251,10 @@ def test_rpc_condition_evaluation_with_context_var_in_return_value_test(
invalid_balance = balance + 1
context[":balanceContextVar"] = invalid_balance
condition_result, call_result = rpc_condition.verify(
providers={testerchain.client.chain_id: [testerchain.provider]}, **context
providers=ConditionProviderManager(
{testerchain.client.chain_id: [testerchain.provider]}
),
**context,
)
assert condition_result is False
assert call_result != invalid_balance

View File

@ -1,7 +1,6 @@
from collections import defaultdict
import pytest
from web3 import Web3
from nucypher.policy.conditions.evm import RPCCall, RPCCondition
from nucypher.policy.conditions.lingo import (
@ -10,7 +9,8 @@ from nucypher.policy.conditions.lingo import (
ConditionType,
ReturnValueTest,
)
from nucypher.policy.conditions.time import TimeCondition, TimeRPCCall
from nucypher.policy.conditions.time import TimeCondition
from nucypher.policy.conditions.utils import ConditionProviderManager
from nucypher.utilities.logging import GlobalLoggerSettings
from tests.utils.policy import make_message_kits
@ -62,7 +62,7 @@ def conditions(bob, multichain_ids):
def test_single_retrieve_with_multichain_conditions(
enacted_policy, bob, multichain_ursulas, conditions, mock_rpc_condition
enacted_policy, bob, multichain_ursulas, conditions, monkeymodule, testerchain
):
bob.remember_node(multichain_ursulas[0])
bob.start_learning_loop()
@ -72,6 +72,11 @@ def test_single_retrieve_with_multichain_conditions(
encrypted_treasure_map=enacted_policy.treasure_map,
alice_verifying_key=enacted_policy.publisher_verifying_key,
)
monkeymodule.setattr(
ConditionProviderManager,
"web3_endpoints",
lambda *args, **kwargs: [testerchain.w3],
)
cleartexts = bob.retrieve_and_decrypt(
message_kits=message_kits,
@ -93,43 +98,30 @@ def test_single_decryption_request_with_faulty_rpc_endpoint(
alice_verifying_key=enacted_policy.publisher_verifying_key,
)
def _mock_configure_provider(*args, **kwargs):
rpc_call_type = args[0]
if isinstance(rpc_call_type, TimeRPCCall):
# time condition call - only RPCCall is made faulty
return testerchain.w3
monkeymodule.setattr(
ConditionProviderManager,
"web3_endpoints",
lambda *args, **kwargs: [testerchain.w3, testerchain.w3],
) # a base, and fallback
# rpc condition call
provider = args[1]
w3 = Web3(provider)
return w3
monkeymodule.setattr(RPCCall, "_configure_provider", _mock_configure_provider)
calls = defaultdict(int)
rpc_calls = defaultdict(int)
original_execute_call = RPCCall._execute
def faulty_execute_call(*args, **kwargs):
def faulty_rpc_execute_call(*args, **kwargs):
"""Intercept the call to the RPC endpoint and raise an exception on the second call."""
nonlocal calls
nonlocal rpc_calls
rpc_call_object = args[0]
resolved_parameters = args[2]
calls[rpc_call_object.chain] += 1
if calls[rpc_call_object.chain] % 2 == 0:
rpc_calls[rpc_call_object.chain] += 1
if rpc_calls[rpc_call_object.chain] % 2 == 0:
# simulate a network error
raise ConnectionError("Something went wrong with the network")
# replace w3 object with fake provider, with proper w3 object for actual execution
return original_execute_call(
rpc_call_object, testerchain.w3, resolved_parameters
)
RPCCall._execute = faulty_execute_call
# make original call
return original_execute_call(*args, **kwargs)
monkeymodule.setattr(RPCCall, "_execute", faulty_rpc_execute_call)
cleartexts = bob.retrieve_and_decrypt(
message_kits=message_kits,
**policy_info_kwargs,
)
assert cleartexts == messages
RPCCall._execute = original_execute_call

View File

@ -14,7 +14,6 @@ from nucypher.blockchain.eth.agents import (
)
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
from nucypher.blockchain.eth.registry import ContractRegistry, RegistrySourceManager
from nucypher.policy.conditions.evm import RPCCall
from nucypher.utilities.logging import Logger
from tests.constants import (
BONUS_TOKENS_FOR_TESTS,
@ -418,14 +417,6 @@ def taco_child_application_agent(testerchain, test_registry):
# Conditions
#
@pytest.fixture(scope="module")
def mock_rpc_condition(testerchain, monkeymodule):
def configure_mock(*args, **kwargs):
return testerchain.w3
monkeymodule.setattr(RPCCall, "_configure_provider", configure_mock)
@pytest.fixture(scope="module")
def multichain_ids(module_mocker):
ids = mock_permitted_multichain_connections(mocker=module_mocker)
@ -433,7 +424,7 @@ def multichain_ids(module_mocker):
@pytest.fixture(scope="module")
def multichain_ursulas(ursulas, multichain_ids, mock_rpc_condition):
def multichain_ursulas(ursulas, multichain_ids):
setup_multichain_ursulas(ursulas=ursulas, chain_ids=multichain_ids)
return ursulas

View File

@ -17,6 +17,7 @@ from nucypher.policy.conditions.exceptions import (
InvalidConditionLingo,
)
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
from nucypher.policy.conditions.utils import ConditionProviderManager
from tests.constants import TESTERCHAIN_CHAIN_ID
CHAIN_ID = 137
@ -52,7 +53,7 @@ class FakeExecutionContractCondition(ContractCondition):
def set_execution_return_value(self, value: Any):
self.execution_return_value = value
def execute(self, providers: Dict, **context) -> Any:
def execute(self, providers: ConditionProviderManager, **context) -> Any:
return self.execution_return_value
EXECUTION_CALL_TYPE = FakeRPCCall
@ -125,7 +126,7 @@ def _check_execution_logic(
json.dumps(condition_dict)
)
fake_execution_contract_condition.set_execution_return_value(execution_result)
fake_providers = {CHAIN_ID: {Mock(BaseProvider)}}
fake_providers = ConditionProviderManager({CHAIN_ID: {Mock(BaseProvider)}})
condition_result, call_result = fake_execution_contract_condition.verify(
fake_providers, **context
)

View File

@ -12,6 +12,7 @@ from nucypher.policy.conditions.lingo import (
OrCompoundCondition,
SequentialAccessControlCondition,
)
from nucypher.policy.conditions.utils import ConditionProviderManager
@pytest.fixture(scope="function")
@ -248,7 +249,9 @@ def test_nested_multi_conditions(mock_conditions):
else_condition=False,
)
result, value = if_then_else_condition.verify(providers={})
result, value = if_then_else_condition.verify(
providers=ConditionProviderManager({})
)
assert result is True
assert value == [[1, 2], [2, 3]] # [[or result], [seq result]]
@ -277,7 +280,9 @@ def test_nested_multi_conditions(mock_conditions):
),
)
result, value = if_then_else_condition.verify(providers={})
result, value = if_then_else_condition.verify(
providers=ConditionProviderManager({})
)
assert result is False
assert value == [[1, 2], [3, 2]] # [[or result], [else if condition result]]

View File

@ -11,6 +11,7 @@ from nucypher.policy.conditions.lingo import (
OrCompoundCondition,
SequentialAccessControlCondition,
)
from nucypher.policy.conditions.utils import ConditionProviderManager
@pytest.fixture(scope="function")
@ -173,7 +174,9 @@ def test_sequential_condition(mock_condition_variables):
)
original_context = dict()
result, value = sequential_condition.verify(providers={}, **original_context)
result, value = sequential_condition.verify(
providers=ConditionProviderManager({}), **original_context
)
assert result is True
assert value == [1, 1 * 2, 1 * 2 * 3, 1 * 2 * 3 * 4]
# only a copy of the context is modified internally
@ -215,7 +218,9 @@ def test_sequential_condition_all_prior_vars_passed_to_subsequent_calls(
expected_var_3_value = expected_var_1_value + expected_var_2_value + 1
original_context = dict()
result, value = sequential_condition.verify(providers={}, **original_context)
result, value = sequential_condition.verify(
providers=ConditionProviderManager({}), **original_context
)
assert result is True
assert value == [
expected_var_1_value,
@ -238,4 +243,4 @@ def test_sequential_condition_a_call_fails(mock_condition_variables):
)
with pytest.raises(Web3Exception):
_ = sequential_condition.verify(providers={})
_ = sequential_condition.verify(providers=ConditionProviderManager({}))

View File

@ -37,6 +37,7 @@ from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.conditions.utils import (
CamelCaseSchema,
ConditionEvalError,
ConditionProviderManager,
camel_case_to_snake,
evaluate_condition_lingo,
to_camelcase,
@ -102,7 +103,9 @@ def test_evaluate_condition_eval_returns_false():
with pytest.raises(ConditionEvalError) as eval_error:
evaluate_condition_lingo(
condition_lingo=condition_lingo,
providers={1: Mock(spec=BaseProvider)}, # fake provider
providers=ConditionProviderManager(
{1: Mock(spec=BaseProvider)}
), # fake provider
context={"key": "value"}, # fake context
)
assert eval_error.value.status_code == HTTPStatus.FORBIDDEN
@ -119,10 +122,12 @@ def test_evaluate_condition_eval_returns_true():
evaluate_condition_lingo(
condition_lingo=condition_lingo,
providers={
1: Mock(spec=BaseProvider),
2: Mock(spec=BaseProvider),
}, # multiple fake provider
providers=ConditionProviderManager(
{
1: Mock(spec=BaseProvider),
2: Mock(spec=BaseProvider),
}
),
context={
"key1": "value1",
"key2": "value2",

View File

@ -11,6 +11,7 @@ from web3 import HTTPProvider
from nucypher.blockchain.eth.signers import InMemorySigner, Signer
from nucypher.characters.lawful import Ursula
from nucypher.config.characters import UrsulaConfiguration
from nucypher.policy.conditions.utils import ConditionProviderManager
from tests.constants import TESTERCHAIN_CHAIN_ID
from tests.utils.blockchain import ReservedTestAccountManager
@ -176,12 +177,14 @@ def setup_multichain_ursulas(chain_ids: List[int], ursulas: List[Ursula]) -> Non
fallback_blockchain_endpoints = [
base_fallback_uri.format(i) for i in range(len(chain_ids))
]
mocked_condition_providers = {
cid: {HTTPProvider(uri), HTTPProvider(furi)}
for cid, uri, furi in zip(
chain_ids, blockchain_endpoints, fallback_blockchain_endpoints
)
}
mocked_condition_providers = ConditionProviderManager(
{
cid: [HTTPProvider(uri), HTTPProvider(furi)]
for cid, uri, furi in zip(
chain_ids, blockchain_endpoints, fallback_blockchain_endpoints
)
}
)
for ursula in ursulas:
ursula.condition_providers = mocked_condition_providers