Move more execution logic into RPCCall/ContractCall objects.

pull/3500/head
derekpierre 2024-05-06 15:31:39 -04:00
parent d3ca9e9261
commit 9fa873630c
No known key found for this signature in database
6 changed files with 92 additions and 123 deletions

View File

@ -11,7 +11,6 @@ from typing import (
from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address
from marshmallow import ValidationError, fields, post_load, validate, validates_schema
from marshmallow.validate import OneOf
from web3 import HTTPProvider, Web3
from web3.contract.contract import ContractFunction
from web3.middleware import geth_poa_middleware
@ -104,6 +103,8 @@ def _validate_chain(chain: int) -> None:
class RPCCall:
LOG = logging.Logger(__name__)
ALLOWED_METHODS = {
# RPC
"eth_getBalance": int,
@ -113,7 +114,7 @@ class RPCCall:
SKIP_VALUES = (None,)
name = fields.Str(required=False)
chain = fields.Int(
required=True, strict=True, validate=OneOf(_CONDITION_CHAINS)
required=True, strict=True, validate=validate.OneOf(_CONDITION_CHAINS)
)
method = fields.Str(required=True)
parameters = fields.List(fields.Field, attribute="parameters", required=False)
@ -157,11 +158,79 @@ class RPCCall:
) # bind contract function (only exposes the eth API)
return rpc_function
def execute(self, w3: Web3, **context) -> Any:
"""Execute onchain read and return result."""
def _configure_w3(self, provider: BaseProvider) -> Web3:
# Instantiate a local web3 instance
w3 = Web3(provider)
# inject web3 middleware to handle POA chain extra_data field.
w3.middleware_onion.inject(geth_poa_middleware, layer=0, name="poa")
return w3
def _check_chain_id(self, w3: Web3) -> None:
"""
Validates that the actual web3 provider is *actually*
connected to the condition's chain ID by reading its RPC endpoint.
"""
provider_chain = w3.eth.chain_id
if provider_chain != self.chain:
raise NoConnectionToChain(
chain=self.chain,
message=f"This rpc call can only be evaluated on chain ID {self.chain} but the provider's "
f"connection is to chain ID {provider_chain}",
)
def _configure_provider(self, provider: BaseProvider):
"""Binds the condition's contract function to a blockchain provider for evaluation"""
w3 = self._configure_w3(provider=provider)
self._check_chain_id(w3)
return w3
def _next_endpoint(
self, providers: Dict[int, Set[HTTPProvider]]
) -> Iterator[HTTPProvider]:
"""Yields the next web3 provider to try for a given chain ID"""
try:
rpc_providers = providers[self.chain]
# if there are no entries for the chain ID, there
# is no connection to that chain available.
except KeyError:
raise NoConnectionToChain(chain=self.chain)
if not rpc_providers:
raise NoConnectionToChain(chain=self.chain) # TODO: unreachable?
for provider in rpc_providers:
# Someday, we might make this whole function async, and then we can knock on
# each endpoint here to see if it's alive and only yield it if it is.
yield provider
def execute(self, providers: Dict[int, Set[HTTPProvider]], **context) -> Any:
resolved_parameters = resolve_parameter_context_variables(
self.parameters, **context
)
endpoints = self._next_endpoint(providers=providers)
latest_error = ""
for provider in endpoints:
w3 = self._configure_provider(provider)
try:
result = self._execute(w3, resolved_parameters)
break
except RequiredContextVariable:
raise
except Exception as e:
latest_error = f"RPC call '{self.method}' failed: {e}"
self.LOG.warn(f"{latest_error}, attempting to try next endpoint.")
# Something went wrong. Try the next endpoint.
continue
else:
# Fuck.
raise RPCExecutionFailed(
f"RPC call '{self.method}' failed; latest error - {latest_error}"
)
return result
def _execute(self, w3: Web3, resolved_parameters: List[Any]) -> Any:
"""Execute onchain read and return result."""
rpc_endpoint_, rpc_method = self.method.split("_", 1)
rpc_function = self._get_web3_py_function(w3, rpc_method)
rpc_result = rpc_function(*resolved_parameters) # RPC read
@ -169,7 +238,6 @@ class RPCCall:
class RPCCondition(AccessControlCondition):
LOG = logging.Logger(__name__)
CONDITION_TYPE = ConditionType.RPC.value
class Schema(RPCCall.Schema):
@ -243,51 +311,6 @@ class RPCCondition(AccessControlCondition):
f"should be '{expected_return_type}' and not '{type(comparator_value)}'."
)
def _next_endpoint(
self, providers: Dict[int, Set[HTTPProvider]]
) -> Iterator[HTTPProvider]:
"""Yields the next web3 provider to try for a given chain ID"""
try:
rpc_providers = providers[self.chain]
# if there are no entries for the chain ID, there
# is no connection to that chain available.
except KeyError:
raise NoConnectionToChain(chain=self.chain)
if not rpc_providers:
raise NoConnectionToChain(chain=self.chain) # TODO: unreachable?
for provider in rpc_providers:
# Someday, we might make this whole function async, and then we can knock on
# each endpoint here to see if it's alive and only yield it if it is.
yield provider
def _configure_w3(self, provider: BaseProvider) -> Web3:
# Instantiate a local web3 instance
self.provider = provider
w3 = Web3(provider)
# inject web3 middleware to handle POA chain extra_data field.
w3.middleware_onion.inject(geth_poa_middleware, layer=0, name="poa")
return w3
def _check_chain_id(self) -> None:
"""
Validates that the actual web3 provider is *actually*
connected to the condition's chain ID by reading its RPC endpoint.
"""
provider_chain = self.w3.eth.chain_id
if provider_chain != self.chain:
raise InvalidCondition(
f"This condition can only be evaluated on chain ID {self.chain} but the provider's "
f"connection is to chain ID {provider_chain}"
)
def _configure_provider(self, provider: BaseProvider):
"""Binds the condition's contract function to a blockchain provider for evaluation"""
self.w3 = self._configure_w3(provider=provider)
self._check_chain_id()
return provider
def _align_comparator_value_with_abi(
self, return_value_test: ReturnValueTest
) -> ReturnValueTest:
@ -306,26 +329,7 @@ class RPCCondition(AccessControlCondition):
return_value_test = self._align_comparator_value_with_abi(
resolved_return_value_test
)
endpoints = self._next_endpoint(providers=providers)
latest_error = ""
for provider in endpoints:
self._configure_provider(provider=provider)
try:
result = self.rpc_call.execute(self.w3, **context)
break
except RequiredContextVariable:
raise
except Exception as e:
latest_error = f"RPC call '{self.method}' failed: {e}"
self.LOG.warn(f"{latest_error}, attempting to try next endpoint.")
# Something went wrong. Try the next endpoint.
continue
else:
# Fuck.
raise RPCExecutionFailed(
f"Contract call '{self.method}' failed; latest error - {latest_error}"
)
result = self.rpc_call.execute(providers=providers, **context)
eval_result = return_value_test.eval(result) # test
return eval_result, result
@ -400,11 +404,9 @@ class ContractCall(RPCCall):
f"Unable to find contract function, '{self.method}', for condition: {e}"
)
def execute(self, w3: Web3, **context) -> Any:
def _execute(self, w3: Web3, resolved_parameters: List[Any]) -> Any:
"""Execute onchain read and return result."""
resolved_parameters = resolve_parameter_context_variables(
self.parameters, **context
)
self.contract_function.w3 = w3
bound_contract_function = self.contract_function(
*resolved_parameters
) # bind contract function
@ -453,25 +455,10 @@ class ContractCondition(RPCCondition):
return self.rpc_call.contract_address
def _validate_expected_return_type(self) -> None:
output_abi_types = _get_abi_types(self.contract_function.contract_abi[0])
comparator_value = self.return_value_test.value
comparator_index = self.return_value_test.index
index_string = (
f"@index={comparator_index}" if comparator_index is not None else ""
_validate_contract_function_expected_return_type(
contract_function=self.contract_function,
return_value_test=self.return_value_test,
)
failure_message = (
f"Invalid return value comparison type '{type(comparator_value)}' for "
f"'{self.contract_function.fn_name}'{index_string} based on ABI types {output_abi_types}"
)
if len(output_abi_types) == 1:
_validate_single_output_type(
output_abi_types[0], comparator_value, comparator_index, failure_message
)
else:
_validate_multiple_output_types(
output_abi_types, comparator_value, comparator_index, failure_message
)
def __repr__(self) -> str:
r = (
@ -481,10 +468,6 @@ class ContractCondition(RPCCondition):
)
return r
def _configure_provider(self, *args, **kwargs):
super()._configure_provider(*args, **kwargs)
self.contract_function.w3 = self.w3
def _align_comparator_value_with_abi(
self, return_value_test: ReturnValueTest
) -> ReturnValueTest:

View File

@ -7,9 +7,9 @@ class InvalidConditionLingo(Exception):
class NoConnectionToChain(RuntimeError):
"""Raised when a node does not have an associated provider for a chain."""
def __init__(self, chain: int):
def __init__(self, chain: int, message: str = None):
self.chain = chain
message = f"No connection to chain ID {chain}"
message = message or f"No connection to chain ID {chain}"
super().__init__(message)

View File

@ -36,7 +36,7 @@ class TimeRPCCall(RPCCall):
def _validate_method(self, method):
return method
def execute(self, w3: Web3, **context) -> Any:
def _execute(self, w3: Web3, resolved_parameters: List[Any]) -> Any:
"""Execute onchain read and return result."""
# TODO may need to rethink as part of #3051 (multicall work).
latest_block = w3.eth.get_block("latest")

View File

@ -22,7 +22,9 @@ from nucypher.policy.conditions.evm import (
RPCCondition,
)
from nucypher.policy.conditions.exceptions import (
ContextVariableVerificationFailed,
InvalidCondition,
InvalidContextVariableData,
NoConnectionToChain,
RequiredContextVariable,
RPCExecutionFailed,
@ -87,7 +89,7 @@ def test_rpc_condition_evaluation_invalid_provider_for_chain(
rpc_condition.rpc_call.chain = new_chain
condition_providers = {new_chain: {testerchain.provider}}
with pytest.raises(
InvalidCondition, match=f"can only be evaluated on chain ID {new_chain}"
NoConnectionToChain, match=f"can only be evaluated on chain ID {new_chain}"
):
_ = rpc_condition.verify(providers=condition_providers, **context)
@ -155,11 +157,7 @@ def test_rpc_condition_evaluation_multiple_providers_no_valid_fallback(
}
}
mocker.patch.object(
rpc_condition, "_check_chain_id", return_value=None
) # skip chain check
mocker.patch.object(rpc_condition, "_configure_w3", my_configure_w3)
mocker.patch.object(rpc_condition.rpc_call, "_configure_provider", my_configure_w3)
with pytest.raises(RPCExecutionFailed):
_ = rpc_condition.verify(providers=condition_providers, **context)
@ -185,10 +183,7 @@ def test_rpc_condition_evaluation_multiple_providers_valid_fallback(
}
}
mocker.patch.object(
rpc_condition, "_check_chain_id", return_value=None
) # skip chain check
mocker.patch.object(rpc_condition, "_configure_w3", my_configure_w3)
mocker.patch.object(rpc_condition.rpc_call, "_configure_provider", my_configure_w3)
condition_result, call_result = rpc_condition.verify(
providers=condition_providers, **context

View File

@ -14,7 +14,7 @@ from nucypher.blockchain.eth.agents import (
)
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
from nucypher.blockchain.eth.registry import ContractRegistry, RegistrySourceManager
from nucypher.policy.conditions.evm import RPCCondition
from nucypher.policy.conditions.evm import RPCCall
from nucypher.utilities.logging import Logger
from tests.constants import (
BONUS_TOKENS_FOR_TESTS,
@ -430,16 +430,11 @@ def taco_child_application_agent(testerchain, test_registry):
#
@pytest.fixture(scope="module")
def mock_rpc_condition(module_mocker, testerchain, monkeymodule):
def configure_mock(condition, provider, *args, **kwargs):
condition.provider = provider
def mock_rpc_condition(testerchain, monkeymodule):
def configure_mock(*args, **kwargs):
return testerchain.w3
monkeymodule.setattr(RPCCondition, "_configure_w3", configure_mock)
configure_spy = module_mocker.spy(RPCCondition, "_configure_w3")
chain_id_check_mock = module_mocker.patch.object(RPCCondition, "_check_chain_id")
return configure_spy, chain_id_check_mock
monkeymodule.setattr(RPCCall, "_configure_provider", configure_mock)
@pytest.fixture(scope="module")

View File

@ -9,7 +9,6 @@ from unittest.mock import Mock
import pytest
from hexbytes import HexBytes
from marshmallow import post_load
from web3 import Web3
from web3.providers import BaseProvider
from nucypher.policy.conditions.evm import ContractCall, ContractCondition
@ -53,7 +52,7 @@ class FakeExecutionContractCondition(ContractCondition):
def set_execution_return_value(self, value: Any):
self.execution_return_value = value
def execute(self, w3: Web3, **context) -> Any:
def execute(self, providers: Dict, **context) -> Any:
return self.execution_return_value
class Schema(ContractCondition.Schema):
@ -70,9 +69,6 @@ class FakeExecutionContractCondition(ContractCondition):
def set_execution_return_value(self, value: Any):
self.rpc_call.set_execution_return_value(value)
def _configure_provider(self, provider: BaseProvider):
self.w3 = dict() # doesn't matter what it is
@pytest.fixture(scope="function")
def contract_condition_dict():