use pre-constructed providers, not uris.

pull/3185/head
Kieran Prasch 2023-08-17 13:53:53 +02:00
parent 9c5038c93b
commit 477a586b47
2 changed files with 42 additions and 4 deletions

View File

@ -170,7 +170,7 @@ class RPCCondition(AccessControlCondition):
)
return method
def _next_endpoint(self, providers: Dict[int, Set[str]]) -> Iterator[HTTPProvider]:
def _next_endpoint(self, providers: Dict[int, Set[HTTPProvider]]) -> Iterator[HTTPProvider]:
"""Yields the next web3 provider to try for a given chain ID"""
try:
rpc_providers = providers[self.chain]
@ -184,7 +184,7 @@ class RPCCondition(AccessControlCondition):
for provider in rpc_providers:
# Someday, we might make this whole function async, and then we can knock on
# each endpoint here to see if it's alive and only yield it if it is.
yield HTTPProvider(endpoint_uri=provider)
yield provider
def _configure_w3(self, provider: BaseProvider) -> Web3:
# Instantiate a local web3 instance

View File

@ -1,4 +1,7 @@
from collections import defaultdict
import pytest
from web3 import HTTPProvider
from nucypher.policy.conditions.evm import _CONDITION_CHAINS, RPCCondition
from nucypher.policy.conditions.lingo import ConditionLingo
@ -57,7 +60,7 @@ def multichain_ursulas(ursulas, chain_ids):
base_uri = "tester://multichain.{}"
provider_uris = [base_uri.format(i) for i in range(len(chain_ids))]
mocked_condition_providers = {
cid: {uri} for cid, uri in zip(chain_ids, provider_uris)
cid: {HTTPProvider(uri)} for cid, uri in zip(chain_ids, provider_uris)
}
for ursula in ursulas:
ursula.condition_providers = mocked_condition_providers
@ -81,7 +84,7 @@ def mock_rpc_condition(module_mocker, testerchain):
def test_single_retrieve_with_multichain_conditions(
enacted_policy, bob, multichain_ursulas, conditions, mock_rpc_condition
enacted_policy, bob, multichain_ursulas, conditions, mock_rpc_condition, mocker
):
bob.remember_node(multichain_ursulas[0])
bob.start_learning_loop()
@ -98,3 +101,38 @@ def test_single_retrieve_with_multichain_conditions(
)
assert cleartexts == messages
def test_single_decryption_request_with_faulty_rpc_endpoint(
enacted_policy, bob, multichain_ursulas, conditions, mock_rpc_condition
):
bob.remember_node(multichain_ursulas[0])
bob.start_learning_loop()
messages, message_kits = make_message_kits(enacted_policy.public_key, conditions)
policy_info_kwargs = dict(
encrypted_treasure_map=enacted_policy.treasure_map,
alice_verifying_key=enacted_policy.publisher_verifying_key,
)
calls = defaultdict(int)
original_execute_call = RPCCondition._execute_call
def faulty_execute_call(*args, **kwargs):
"""Intercept the call to the RPC endpoint and raise an exception on the second call."""
nonlocal calls
rpc_call = args[0]
calls[rpc_call.chain] += 1
if calls[rpc_call.chain] == 5:
raise Exception("Something went wrong")
return original_execute_call(*args, **kwargs)
RPCCondition._execute_call = faulty_execute_call
cleartexts = bob.retrieve_and_decrypt(
message_kits=message_kits,
**policy_info_kwargs,
)
assert cleartexts == messages
RPCCondition._execute_call = original_execute_call