mirror of https://github.com/nucypher/nucypher.git
Move more execution logic into RPCCall/ContractCall objects.
parent
d3ca9e9261
commit
9fa873630c
|
@ -11,7 +11,6 @@ from typing import (
|
|||
from eth_typing import ChecksumAddress
|
||||
from eth_utils import to_checksum_address
|
||||
from marshmallow import ValidationError, fields, post_load, validate, validates_schema
|
||||
from marshmallow.validate import OneOf
|
||||
from web3 import HTTPProvider, Web3
|
||||
from web3.contract.contract import ContractFunction
|
||||
from web3.middleware import geth_poa_middleware
|
||||
|
@ -104,6 +103,8 @@ def _validate_chain(chain: int) -> None:
|
|||
|
||||
|
||||
class RPCCall:
|
||||
LOG = logging.Logger(__name__)
|
||||
|
||||
ALLOWED_METHODS = {
|
||||
# RPC
|
||||
"eth_getBalance": int,
|
||||
|
@ -113,7 +114,7 @@ class RPCCall:
|
|||
SKIP_VALUES = (None,)
|
||||
name = fields.Str(required=False)
|
||||
chain = fields.Int(
|
||||
required=True, strict=True, validate=OneOf(_CONDITION_CHAINS)
|
||||
required=True, strict=True, validate=validate.OneOf(_CONDITION_CHAINS)
|
||||
)
|
||||
method = fields.Str(required=True)
|
||||
parameters = fields.List(fields.Field, attribute="parameters", required=False)
|
||||
|
@ -157,11 +158,79 @@ class RPCCall:
|
|||
) # bind contract function (only exposes the eth API)
|
||||
return rpc_function
|
||||
|
||||
def execute(self, w3: Web3, **context) -> Any:
|
||||
"""Execute onchain read and return result."""
|
||||
def _configure_w3(self, provider: BaseProvider) -> Web3:
|
||||
# Instantiate a local web3 instance
|
||||
w3 = Web3(provider)
|
||||
# inject web3 middleware to handle POA chain extra_data field.
|
||||
w3.middleware_onion.inject(geth_poa_middleware, layer=0, name="poa")
|
||||
return w3
|
||||
|
||||
def _check_chain_id(self, w3: Web3) -> None:
|
||||
"""
|
||||
Validates that the actual web3 provider is *actually*
|
||||
connected to the condition's chain ID by reading its RPC endpoint.
|
||||
"""
|
||||
provider_chain = w3.eth.chain_id
|
||||
if provider_chain != self.chain:
|
||||
raise NoConnectionToChain(
|
||||
chain=self.chain,
|
||||
message=f"This rpc call can only be evaluated on chain ID {self.chain} but the provider's "
|
||||
f"connection is to chain ID {provider_chain}",
|
||||
)
|
||||
|
||||
def _configure_provider(self, provider: BaseProvider):
|
||||
"""Binds the condition's contract function to a blockchain provider for evaluation"""
|
||||
w3 = self._configure_w3(provider=provider)
|
||||
self._check_chain_id(w3)
|
||||
return w3
|
||||
|
||||
def _next_endpoint(
|
||||
self, providers: Dict[int, Set[HTTPProvider]]
|
||||
) -> Iterator[HTTPProvider]:
|
||||
"""Yields the next web3 provider to try for a given chain ID"""
|
||||
try:
|
||||
rpc_providers = providers[self.chain]
|
||||
|
||||
# if there are no entries for the chain ID, there
|
||||
# is no connection to that chain available.
|
||||
except KeyError:
|
||||
raise NoConnectionToChain(chain=self.chain)
|
||||
if not rpc_providers:
|
||||
raise NoConnectionToChain(chain=self.chain) # TODO: unreachable?
|
||||
for provider in rpc_providers:
|
||||
# Someday, we might make this whole function async, and then we can knock on
|
||||
# each endpoint here to see if it's alive and only yield it if it is.
|
||||
yield provider
|
||||
|
||||
def execute(self, providers: Dict[int, Set[HTTPProvider]], **context) -> Any:
|
||||
resolved_parameters = resolve_parameter_context_variables(
|
||||
self.parameters, **context
|
||||
)
|
||||
|
||||
endpoints = self._next_endpoint(providers=providers)
|
||||
latest_error = ""
|
||||
for provider in endpoints:
|
||||
w3 = self._configure_provider(provider)
|
||||
try:
|
||||
result = self._execute(w3, resolved_parameters)
|
||||
break
|
||||
except RequiredContextVariable:
|
||||
raise
|
||||
except Exception as e:
|
||||
latest_error = f"RPC call '{self.method}' failed: {e}"
|
||||
self.LOG.warn(f"{latest_error}, attempting to try next endpoint.")
|
||||
# Something went wrong. Try the next endpoint.
|
||||
continue
|
||||
else:
|
||||
# Fuck.
|
||||
raise RPCExecutionFailed(
|
||||
f"RPC call '{self.method}' failed; latest error - {latest_error}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _execute(self, w3: Web3, resolved_parameters: List[Any]) -> Any:
|
||||
"""Execute onchain read and return result."""
|
||||
rpc_endpoint_, rpc_method = self.method.split("_", 1)
|
||||
rpc_function = self._get_web3_py_function(w3, rpc_method)
|
||||
rpc_result = rpc_function(*resolved_parameters) # RPC read
|
||||
|
@ -169,7 +238,6 @@ class RPCCall:
|
|||
|
||||
|
||||
class RPCCondition(AccessControlCondition):
|
||||
LOG = logging.Logger(__name__)
|
||||
CONDITION_TYPE = ConditionType.RPC.value
|
||||
|
||||
class Schema(RPCCall.Schema):
|
||||
|
@ -243,51 +311,6 @@ class RPCCondition(AccessControlCondition):
|
|||
f"should be '{expected_return_type}' and not '{type(comparator_value)}'."
|
||||
)
|
||||
|
||||
def _next_endpoint(
|
||||
self, providers: Dict[int, Set[HTTPProvider]]
|
||||
) -> Iterator[HTTPProvider]:
|
||||
"""Yields the next web3 provider to try for a given chain ID"""
|
||||
try:
|
||||
rpc_providers = providers[self.chain]
|
||||
|
||||
# if there are no entries for the chain ID, there
|
||||
# is no connection to that chain available.
|
||||
except KeyError:
|
||||
raise NoConnectionToChain(chain=self.chain)
|
||||
if not rpc_providers:
|
||||
raise NoConnectionToChain(chain=self.chain) # TODO: unreachable?
|
||||
for provider in rpc_providers:
|
||||
# Someday, we might make this whole function async, and then we can knock on
|
||||
# each endpoint here to see if it's alive and only yield it if it is.
|
||||
yield provider
|
||||
|
||||
def _configure_w3(self, provider: BaseProvider) -> Web3:
|
||||
# Instantiate a local web3 instance
|
||||
self.provider = provider
|
||||
w3 = Web3(provider)
|
||||
# inject web3 middleware to handle POA chain extra_data field.
|
||||
w3.middleware_onion.inject(geth_poa_middleware, layer=0, name="poa")
|
||||
|
||||
return w3
|
||||
|
||||
def _check_chain_id(self) -> None:
|
||||
"""
|
||||
Validates that the actual web3 provider is *actually*
|
||||
connected to the condition's chain ID by reading its RPC endpoint.
|
||||
"""
|
||||
provider_chain = self.w3.eth.chain_id
|
||||
if provider_chain != self.chain:
|
||||
raise InvalidCondition(
|
||||
f"This condition can only be evaluated on chain ID {self.chain} but the provider's "
|
||||
f"connection is to chain ID {provider_chain}"
|
||||
)
|
||||
|
||||
def _configure_provider(self, provider: BaseProvider):
|
||||
"""Binds the condition's contract function to a blockchain provider for evaluation"""
|
||||
self.w3 = self._configure_w3(provider=provider)
|
||||
self._check_chain_id()
|
||||
return provider
|
||||
|
||||
def _align_comparator_value_with_abi(
|
||||
self, return_value_test: ReturnValueTest
|
||||
) -> ReturnValueTest:
|
||||
|
@ -306,26 +329,7 @@ class RPCCondition(AccessControlCondition):
|
|||
return_value_test = self._align_comparator_value_with_abi(
|
||||
resolved_return_value_test
|
||||
)
|
||||
|
||||
endpoints = self._next_endpoint(providers=providers)
|
||||
latest_error = ""
|
||||
for provider in endpoints:
|
||||
self._configure_provider(provider=provider)
|
||||
try:
|
||||
result = self.rpc_call.execute(self.w3, **context)
|
||||
break
|
||||
except RequiredContextVariable:
|
||||
raise
|
||||
except Exception as e:
|
||||
latest_error = f"RPC call '{self.method}' failed: {e}"
|
||||
self.LOG.warn(f"{latest_error}, attempting to try next endpoint.")
|
||||
# Something went wrong. Try the next endpoint.
|
||||
continue
|
||||
else:
|
||||
# Fuck.
|
||||
raise RPCExecutionFailed(
|
||||
f"Contract call '{self.method}' failed; latest error - {latest_error}"
|
||||
)
|
||||
result = self.rpc_call.execute(providers=providers, **context)
|
||||
|
||||
eval_result = return_value_test.eval(result) # test
|
||||
return eval_result, result
|
||||
|
@ -400,11 +404,9 @@ class ContractCall(RPCCall):
|
|||
f"Unable to find contract function, '{self.method}', for condition: {e}"
|
||||
)
|
||||
|
||||
def execute(self, w3: Web3, **context) -> Any:
|
||||
def _execute(self, w3: Web3, resolved_parameters: List[Any]) -> Any:
|
||||
"""Execute onchain read and return result."""
|
||||
resolved_parameters = resolve_parameter_context_variables(
|
||||
self.parameters, **context
|
||||
)
|
||||
self.contract_function.w3 = w3
|
||||
bound_contract_function = self.contract_function(
|
||||
*resolved_parameters
|
||||
) # bind contract function
|
||||
|
@ -453,25 +455,10 @@ class ContractCondition(RPCCondition):
|
|||
return self.rpc_call.contract_address
|
||||
|
||||
def _validate_expected_return_type(self) -> None:
|
||||
output_abi_types = _get_abi_types(self.contract_function.contract_abi[0])
|
||||
comparator_value = self.return_value_test.value
|
||||
comparator_index = self.return_value_test.index
|
||||
index_string = (
|
||||
f"@index={comparator_index}" if comparator_index is not None else ""
|
||||
_validate_contract_function_expected_return_type(
|
||||
contract_function=self.contract_function,
|
||||
return_value_test=self.return_value_test,
|
||||
)
|
||||
failure_message = (
|
||||
f"Invalid return value comparison type '{type(comparator_value)}' for "
|
||||
f"'{self.contract_function.fn_name}'{index_string} based on ABI types {output_abi_types}"
|
||||
)
|
||||
|
||||
if len(output_abi_types) == 1:
|
||||
_validate_single_output_type(
|
||||
output_abi_types[0], comparator_value, comparator_index, failure_message
|
||||
)
|
||||
else:
|
||||
_validate_multiple_output_types(
|
||||
output_abi_types, comparator_value, comparator_index, failure_message
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
r = (
|
||||
|
@ -481,10 +468,6 @@ class ContractCondition(RPCCondition):
|
|||
)
|
||||
return r
|
||||
|
||||
def _configure_provider(self, *args, **kwargs):
|
||||
super()._configure_provider(*args, **kwargs)
|
||||
self.contract_function.w3 = self.w3
|
||||
|
||||
def _align_comparator_value_with_abi(
|
||||
self, return_value_test: ReturnValueTest
|
||||
) -> ReturnValueTest:
|
||||
|
|
|
@ -7,9 +7,9 @@ class InvalidConditionLingo(Exception):
|
|||
class NoConnectionToChain(RuntimeError):
|
||||
"""Raised when a node does not have an associated provider for a chain."""
|
||||
|
||||
def __init__(self, chain: int):
|
||||
def __init__(self, chain: int, message: str = None):
|
||||
self.chain = chain
|
||||
message = f"No connection to chain ID {chain}"
|
||||
message = message or f"No connection to chain ID {chain}"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class TimeRPCCall(RPCCall):
|
|||
def _validate_method(self, method):
|
||||
return method
|
||||
|
||||
def execute(self, w3: Web3, **context) -> Any:
|
||||
def _execute(self, w3: Web3, resolved_parameters: List[Any]) -> Any:
|
||||
"""Execute onchain read and return result."""
|
||||
# TODO may need to rethink as part of #3051 (multicall work).
|
||||
latest_block = w3.eth.get_block("latest")
|
||||
|
|
|
@ -22,7 +22,9 @@ from nucypher.policy.conditions.evm import (
|
|||
RPCCondition,
|
||||
)
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
ContextVariableVerificationFailed,
|
||||
InvalidCondition,
|
||||
InvalidContextVariableData,
|
||||
NoConnectionToChain,
|
||||
RequiredContextVariable,
|
||||
RPCExecutionFailed,
|
||||
|
@ -87,7 +89,7 @@ def test_rpc_condition_evaluation_invalid_provider_for_chain(
|
|||
rpc_condition.rpc_call.chain = new_chain
|
||||
condition_providers = {new_chain: {testerchain.provider}}
|
||||
with pytest.raises(
|
||||
InvalidCondition, match=f"can only be evaluated on chain ID {new_chain}"
|
||||
NoConnectionToChain, match=f"can only be evaluated on chain ID {new_chain}"
|
||||
):
|
||||
_ = rpc_condition.verify(providers=condition_providers, **context)
|
||||
|
||||
|
@ -155,11 +157,7 @@ def test_rpc_condition_evaluation_multiple_providers_no_valid_fallback(
|
|||
}
|
||||
}
|
||||
|
||||
mocker.patch.object(
|
||||
rpc_condition, "_check_chain_id", return_value=None
|
||||
) # skip chain check
|
||||
mocker.patch.object(rpc_condition, "_configure_w3", my_configure_w3)
|
||||
|
||||
mocker.patch.object(rpc_condition.rpc_call, "_configure_provider", my_configure_w3)
|
||||
with pytest.raises(RPCExecutionFailed):
|
||||
_ = rpc_condition.verify(providers=condition_providers, **context)
|
||||
|
||||
|
@ -185,10 +183,7 @@ def test_rpc_condition_evaluation_multiple_providers_valid_fallback(
|
|||
}
|
||||
}
|
||||
|
||||
mocker.patch.object(
|
||||
rpc_condition, "_check_chain_id", return_value=None
|
||||
) # skip chain check
|
||||
mocker.patch.object(rpc_condition, "_configure_w3", my_configure_w3)
|
||||
mocker.patch.object(rpc_condition.rpc_call, "_configure_provider", my_configure_w3)
|
||||
|
||||
condition_result, call_result = rpc_condition.verify(
|
||||
providers=condition_providers, **context
|
||||
|
|
|
@ -14,7 +14,7 @@ from nucypher.blockchain.eth.agents import (
|
|||
)
|
||||
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
|
||||
from nucypher.blockchain.eth.registry import ContractRegistry, RegistrySourceManager
|
||||
from nucypher.policy.conditions.evm import RPCCondition
|
||||
from nucypher.policy.conditions.evm import RPCCall
|
||||
from nucypher.utilities.logging import Logger
|
||||
from tests.constants import (
|
||||
BONUS_TOKENS_FOR_TESTS,
|
||||
|
@ -430,16 +430,11 @@ def taco_child_application_agent(testerchain, test_registry):
|
|||
#
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mock_rpc_condition(module_mocker, testerchain, monkeymodule):
|
||||
def configure_mock(condition, provider, *args, **kwargs):
|
||||
condition.provider = provider
|
||||
def mock_rpc_condition(testerchain, monkeymodule):
|
||||
def configure_mock(*args, **kwargs):
|
||||
return testerchain.w3
|
||||
|
||||
monkeymodule.setattr(RPCCondition, "_configure_w3", configure_mock)
|
||||
configure_spy = module_mocker.spy(RPCCondition, "_configure_w3")
|
||||
|
||||
chain_id_check_mock = module_mocker.patch.object(RPCCondition, "_check_chain_id")
|
||||
return configure_spy, chain_id_check_mock
|
||||
monkeymodule.setattr(RPCCall, "_configure_provider", configure_mock)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
|
|
@ -9,7 +9,6 @@ from unittest.mock import Mock
|
|||
import pytest
|
||||
from hexbytes import HexBytes
|
||||
from marshmallow import post_load
|
||||
from web3 import Web3
|
||||
from web3.providers import BaseProvider
|
||||
|
||||
from nucypher.policy.conditions.evm import ContractCall, ContractCondition
|
||||
|
@ -53,7 +52,7 @@ class FakeExecutionContractCondition(ContractCondition):
|
|||
def set_execution_return_value(self, value: Any):
|
||||
self.execution_return_value = value
|
||||
|
||||
def execute(self, w3: Web3, **context) -> Any:
|
||||
def execute(self, providers: Dict, **context) -> Any:
|
||||
return self.execution_return_value
|
||||
|
||||
class Schema(ContractCondition.Schema):
|
||||
|
@ -70,9 +69,6 @@ class FakeExecutionContractCondition(ContractCondition):
|
|||
def set_execution_return_value(self, value: Any):
|
||||
self.rpc_call.set_execution_return_value(value)
|
||||
|
||||
def _configure_provider(self, provider: BaseProvider):
|
||||
self.w3 = dict() # doesn't matter what it is
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def contract_condition_dict():
|
||||
|
|
Loading…
Reference in New Issue