Merge pull request #3531 from derekpierre/w3-instance

Fix ATxM use of web3 instance provided by `nucypher`
pull/3524/head
Derek Pierre 2024-07-29 08:43:29 -04:00 committed by GitHub
commit 5656dead4d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 39 additions and 36 deletions

View File

@ -0,0 +1 @@
ATxM instance did not pass correct web3 instance to underlying strategies.

View File

@ -8,10 +8,12 @@ from web3 import Web3
from web3._utils.threads import Timeout
from web3.contract.contract import Contract
from web3.exceptions import TimeExhausted, TransactionNotFound
from web3.middleware import geth_poa_middleware, simple_cache_middleware
from web3.types import TxReceipt, Wei
from nucypher.blockchain.eth.constants import (
AVERAGE_BLOCK_TIME_IN_SECONDS,
POA_CHAINS,
PUBLIC_CHAINS,
)
from nucypher.blockchain.middleware.retry import (
@ -79,6 +81,7 @@ class EthereumClient:
self._add_default_middleware()
def _add_default_middleware(self):
# retry request middleware
endpoint_uri = getattr(self.w3.provider, "endpoint_uri", "")
if "infura" in endpoint_uri:
self.log.debug("Adding Infura RPC retry middleware to client")
@ -90,6 +93,22 @@ class EthereumClient:
self.log.debug("Adding RPC retry middleware to client")
self.add_middleware(RetryRequestMiddleware)
# poa middleware
chain_id = self.chain_id
is_poa = chain_id in POA_CHAINS
self.log.debug(
f"Blockchain: {self.chain_name} (chain_id={chain_id}, poa={is_poa})"
)
if is_poa:
# proof-of-authority blockchain
self.log.debug("Injecting POA middleware at layer 0")
self.inject_middleware(geth_poa_middleware, layer=0)
# simple cache middleware
self.log.debug("Adding simple_cache_middleware")
self.add_middleware(simple_cache_middleware)
@property
def chain_name(self) -> str:
name = PUBLIC_CHAINS.get(self.chain_id, UNKNOWN_DEVELOPMENT_CHAIN_ID)

View File

@ -17,12 +17,10 @@ from eth_utils import to_checksum_address
from web3 import HTTPProvider, IPCProvider, Web3, WebsocketProvider
from web3.contract.contract import Contract, ContractConstructor, ContractFunction
from web3.exceptions import TimeExhausted
from web3.middleware import geth_poa_middleware, simple_cache_middleware
from web3.providers import BaseProvider
from web3.types import TxParams, TxReceipt
from nucypher.blockchain.eth.clients import EthereumClient
from nucypher.blockchain.eth.constants import POA_CHAINS
from nucypher.blockchain.eth.decorators import validate_checksum_address
from nucypher.blockchain.eth.providers import (
_get_http_provider,
@ -241,14 +239,7 @@ class BlockchainInterface:
self.w3 = NO_BLOCKCHAIN_CONNECTION
self.client: EthereumClient = NO_BLOCKCHAIN_CONNECTION
self.is_light = light
speedup_strategy = ExponentialSpeedupStrategy(
w3=self.w3,
min_time_between_speedups=120,
) # speedup txs if not mined after 2 mins.
self.tx_machine = AutomaticTxMachine(
w3=self.w3, tx_exec_timeout=self.TIMEOUT, strategies=[speedup_strategy]
)
self.tx_machine = None
# TODO: Not ready to give users total flexibility. Let's stick for the moment to known values. See #2447
if gas_strategy not in (
@ -292,24 +283,6 @@ class BlockchainInterface:
gas_strategy = cls.GAS_STRATEGIES[cls.DEFAULT_GAS_STRATEGY]
return gas_strategy
def attach_middleware(self):
chain_id = int(self.client.chain_id)
self.poa = chain_id in POA_CHAINS
self.log.debug(
f"Blockchain: {self.client.chain_name} (chain_id={chain_id}, poa={self.poa})"
)
# For use with Proof-Of-Authority test-blockchains
if self.poa is True:
self.log.debug("Injecting POA middleware at layer 0")
self.client.inject_middleware(geth_poa_middleware, layer=0)
self.log.debug("Adding simple_cache_middleware")
self.client.add_middleware(simple_cache_middleware)
# TODO: See #2770
# self.configure_gas_strategy()
def configure_gas_strategy(self, gas_strategy: Optional[Callable] = None) -> None:
if gas_strategy:
@ -337,6 +310,10 @@ class BlockchainInterface:
# self.log.debug(f"Gas strategy currently reports a gas price of {gwei_gas_price} gwei.")
def connect(self):
if self.is_connected:
# safety check - connect was already previously called
return
endpoint = self.endpoint
self.log.info(f"Using external Web3 Provider '{self.endpoint}'")
@ -349,11 +326,19 @@ class BlockchainInterface:
if self._provider is NO_BLOCKCHAIN_CONNECTION:
raise self.NoProvider("There are no configured blockchain providers")
# Connect if not connected
try:
self.w3 = self.Web3(provider=self._provider)
self.tx_machine.w3 = self.w3 # share this web3 instance with the tracker
# client mutates w3 instance (configures middleware etc.)
self.client = EthereumClient(w3=self.w3)
# web3 instance fully configured; share instance with ATxM and respective strategies
speedup_strategy = ExponentialSpeedupStrategy(
w3=self.w3,
min_time_between_speedups=120,
) # speedup txs if not mined after 2 mins.
self.tx_machine = AutomaticTxMachine(
w3=self.w3, tx_exec_timeout=self.TIMEOUT, strategies=[speedup_strategy]
)
except requests.ConnectionError: # RPC
raise self.ConnectionFailed(
f"Connection Failed - {str(self.endpoint)} - is RPC enabled?"
@ -362,8 +347,6 @@ class BlockchainInterface:
raise self.ConnectionFailed(
f"Connection Failed - {str(self.endpoint)} - is IPC enabled?"
)
else:
self.attach_middleware()
return self.is_connected

View File

@ -10,11 +10,11 @@ CHAIN_ID = 23
@pytest.mark.parametrize("chain_id_return_value", [hex(CHAIN_ID), CHAIN_ID])
def test_cached_chain_id(mocker, chain_id_return_value):
web3_mock = mocker.MagicMock()
mock_client = EthereumClient(w3=web3_mock)
chain_id_property_mock = PropertyMock(return_value=chain_id_return_value)
type(web3_mock.eth).chain_id = chain_id_property_mock
mock_client = EthereumClient(w3=web3_mock)
assert mock_client.chain_id == CHAIN_ID
chain_id_property_mock.assert_called_once()

View File

@ -15,11 +15,11 @@ CHAIN_ID = 11155111 # pretend to be sepolia
@pytest.mark.parametrize("chain_id_return_value", [hex(CHAIN_ID), CHAIN_ID])
def test_cached_chain_id(mocker, chain_id_return_value):
web3_mock = mocker.MagicMock()
mock_client = EthereumClient(w3=web3_mock)
chain_id_property_mock = PropertyMock(return_value=chain_id_return_value)
type(web3_mock.eth).chain_id = chain_id_property_mock
mock_client = EthereumClient(w3=web3_mock)
assert mock_client.chain_id == CHAIN_ID
chain_id_property_mock.assert_called_once()