mirror of https://github.com/nucypher/nucypher.git
Iterating on spy wrapper and mock agents
parent
f583567222
commit
c8e37210fe
|
@ -1,24 +1,12 @@
|
|||
import click
|
||||
import pytest
|
||||
from io import StringIO
|
||||
|
||||
from nucypher.blockchain.eth.clients import EthereumTesterClient, PUBLIC_CHAINS
|
||||
from nucypher.blockchain.eth.token import NU
|
||||
from nucypher.characters.control.emitters import StdoutEmitter
|
||||
from nucypher.cli.actions.confirm import (
|
||||
confirm_deployment,
|
||||
confirm_enable_restaking_lock,
|
||||
confirm_enable_restaking,
|
||||
confirm_enable_winding_down,
|
||||
confirm_staged_stake,
|
||||
confirm_large_stake
|
||||
)
|
||||
from nucypher.cli.literature import (
|
||||
ABORT_DEPLOYMENT,
|
||||
RESTAKING_LOCK_AGREEMENT,
|
||||
RESTAKING_AGREEMENT,
|
||||
WINDING_DOWN_AGREEMENT
|
||||
)
|
||||
from nucypher.cli.actions.confirm import (confirm_deployment, confirm_enable_restaking, confirm_enable_restaking_lock,
|
||||
confirm_enable_winding_down, confirm_large_stake, confirm_staged_stake)
|
||||
from nucypher.cli.literature import (ABORT_DEPLOYMENT, RESTAKING_AGREEMENT, RESTAKING_LOCK_AGREEMENT,
|
||||
WINDING_DOWN_AGREEMENT)
|
||||
|
||||
|
||||
def test_confirm_deployment(mocker, mock_click_prompt, test_emitter, stdout_trap, mock_testerchain):
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
import click
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from eth_utils import is_checksum_address
|
||||
from unittest.mock import Mock
|
||||
from web3 import Web3
|
||||
|
||||
from nucypher.blockchain.eth.actors import Wallet
|
||||
from nucypher.blockchain.eth.clients import EthereumClient, EthereumTesterClient
|
||||
from nucypher.blockchain.eth.clients import EthereumClient
|
||||
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
|
||||
from nucypher.blockchain.eth.token import NU
|
||||
from nucypher.cli.actions.select import select_client_account
|
||||
|
@ -112,7 +111,7 @@ def test_select_client_account_valid_sources(mocker,
|
|||
assert selected_account == expected_account
|
||||
|
||||
|
||||
@pytest.mark.parametrize('selection,show_staking,show_eth,show_tokens,mock_stakes', (
|
||||
@pytest.mark.parametrize('selection,show_staking,show_eth,show_tokens,stake_info', (
|
||||
(0, True, True, True, []),
|
||||
(1, True, True, True, []),
|
||||
(5, True, True, True, []),
|
||||
|
@ -134,10 +133,11 @@ def test_select_client_account_with_balance_display(mock_click_prompt,
|
|||
show_staking,
|
||||
show_eth,
|
||||
show_tokens,
|
||||
mock_stakes):
|
||||
stake_info):
|
||||
|
||||
# Setup
|
||||
mock_click_prompt.return_value = selection
|
||||
mock_staking_agent.get_all_stakes.return_value = stake_info
|
||||
|
||||
# Missing network kwarg with balance display active
|
||||
blockchain_read_required = any((show_staking, show_eth, show_tokens))
|
||||
|
@ -149,7 +149,7 @@ def test_select_client_account_with_balance_display(mock_click_prompt,
|
|||
show_staking=show_staking,
|
||||
provider_uri=MOCK_PROVIDER_URI)
|
||||
|
||||
mock_staking_agent.get_all_stakes.return_value = mock_stakes
|
||||
# Good selection
|
||||
selected_account = select_client_account(emitter=test_emitter,
|
||||
network=TEMPORARY_DOMAIN,
|
||||
show_eth_balance=show_eth,
|
||||
|
@ -194,7 +194,7 @@ def test_select_client_account_with_balance_display(mock_click_prompt,
|
|||
assert str(Web3.fromWei(balance, 'ether')) in row
|
||||
|
||||
if show_staking:
|
||||
if len(mock_stakes) == 0:
|
||||
if len(stake_info) == 0:
|
||||
assert "No" in row
|
||||
else:
|
||||
assert 'Yes' in row
|
||||
|
|
|
@ -43,29 +43,29 @@ def monkeymodule():
|
|||
mpatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture(scope='module', autouse=True)
|
||||
def mock_contract_agency(monkeymodule, module_mocker, token_economics):
|
||||
monkeymodule.setattr(ContractAgency, 'get_agent', MockContractAgency.get_agent)
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_contract_agency(monkeypatch, module_mocker, token_economics):
|
||||
monkeypatch.setattr(ContractAgency, 'get_agent', MockContractAgency.get_agent)
|
||||
module_mocker.patch.object(EconomicsFactory, 'get_economics', return_value=token_economics)
|
||||
yield MockContractAgency()
|
||||
monkeymodule.delattr(ContractAgency, 'get_agent')
|
||||
monkeypatch.delattr(ContractAgency, 'get_agent')
|
||||
|
||||
|
||||
@pytest.fixture(scope='module', autouse=True)
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_token_agent(mock_testerchain, token_economics, mock_contract_agency):
|
||||
mock_agent = mock_contract_agency.get_agent(MockNucypherToken)
|
||||
yield mock_agent
|
||||
mock_agent.reset()
|
||||
|
||||
|
||||
@pytest.fixture(scope='module', autouse=True)
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_worklock_agent(mock_testerchain, token_economics, mock_contract_agency):
|
||||
mock_agent = mock_contract_agency.get_agent(MockWorkLockAgent)
|
||||
yield mock_agent
|
||||
mock_agent.reset()
|
||||
|
||||
|
||||
@pytest.fixture(scope='module', autouse=True)
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_staking_agent(mock_testerchain, token_economics, mock_contract_agency):
|
||||
mock_agent = mock_contract_agency.get_agent(MockStakingAgent)
|
||||
yield mock_agent
|
||||
|
|
|
@ -29,7 +29,7 @@ from tests.constants import CLI_TEST_ENV, MOCK_PROVIDER_URI, YES
|
|||
from tests.mock.agents import FAKE_RECEIPT, MockWorkLockAgent
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
@pytest.fixture()
|
||||
def surrogate_bidder(mock_testerchain, test_registry, mock_worklock_agent):
|
||||
address = mock_testerchain.etherbase_account
|
||||
bidder = Bidder(checksum_address=address, registry=test_registry)
|
||||
|
@ -51,7 +51,7 @@ def test_status(click_runner, mock_worklock_agent, test_registry_source_manager)
|
|||
assert result.exit_code == 0
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
@pytest.fixture()
|
||||
def bidding_command(token_economics, surrogate_bidder):
|
||||
minimum = token_economics.worklock_min_allowed_bid
|
||||
bid_value = random.randint(minimum, minimum*100)
|
||||
|
@ -110,9 +110,10 @@ def test_valid_bid(click_runner,
|
|||
|
||||
now = mock_testerchain.get_blocktime()
|
||||
sometime_later = now + 100
|
||||
mock_blocktime = mocker.patch.object(BlockchainInterface, 'get_blocktime', return_value=sometime_later)
|
||||
mocker.patch.object(BlockchainInterface, 'get_blocktime', return_value=sometime_later)
|
||||
|
||||
# Spy on the corresponding CLI function we are testing
|
||||
# TODO: Mock at the agent layer instead
|
||||
mock_ensure = mocker.spy(Bidder, 'ensure_bidding_is_open')
|
||||
mock_bidder = mocker.spy(Bidder, 'place_bid')
|
||||
|
||||
|
@ -365,7 +366,7 @@ def test_participant_status(click_runner,
|
|||
'get_bonus_deposit_rate',
|
||||
'get_bonus_refund_rate',
|
||||
'get_base_refund_rate',
|
||||
'get_completed_work',
|
||||
# 'get_completed_work', # TODO Yes or no?
|
||||
'get_refunded_work')
|
||||
# Calls
|
||||
mock_worklock_agent.assert_contract_calls(calls=expected_calls)
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from collections import defaultdict
|
||||
|
||||
from hexbytes import HexBytes
|
||||
from typing import Tuple
|
||||
from unittest.mock import Mock
|
||||
from typing import List, Tuple
|
||||
from unittest.mock import Mock, _CallList
|
||||
|
||||
from nucypher.blockchain.economics import EconomicsFactory
|
||||
from nucypher.blockchain.eth.agents import ContractAgency, NucypherTokenAgent, PolicyManagerAgent, StakingEscrowAgent, \
|
||||
|
@ -49,91 +49,83 @@ class MockContractAgent:
|
|||
# API
|
||||
# TODO: Auto generate calls and txs from class inspection
|
||||
|
||||
DEFAULT_TRANSACTION = default_fake_transaction()
|
||||
DEFAULT_CALL = default_fake_call()
|
||||
|
||||
ATTRS = dict()
|
||||
CALLS = tuple()
|
||||
TRANSACTIONS = tuple()
|
||||
|
||||
# Spy
|
||||
_SPY_TRANSACTIONS = defaultdict(list)
|
||||
_SPY_CALLS = defaultdict(list)
|
||||
|
||||
def __init__(self):
|
||||
# initial state
|
||||
self.spy = True
|
||||
self.setup_mock(agent_attrs=self.ATTRS)
|
||||
|
||||
self.spy = True # initial state
|
||||
|
||||
# Bind mock agent attributes to the *subclass*
|
||||
for agent_method, mock_value in self.ATTRS.items():
|
||||
setattr(self.__class__, agent_method, mock_value)
|
||||
|
||||
for call in self.CALLS:
|
||||
setattr(self.__class__, call, Mock(return_value=default_fake_call()))
|
||||
|
||||
for tx in self.TRANSACTIONS:
|
||||
setattr(self.__class__, tx, Mock(return_value=default_fake_transaction()))
|
||||
|
||||
def __record_tx(self, name: str, params: tuple) -> None:
|
||||
self._SPY_TRANSACTIONS[str(name)].append(params)
|
||||
|
||||
def __record_call(self, name: str, params: tuple) -> None:
|
||||
self._SPY_CALLS[str(name)].append(params)
|
||||
|
||||
def __getattribute__(self, name):
|
||||
"""Spy"""
|
||||
|
||||
get = object.__getattribute__
|
||||
attr = get(self, name)
|
||||
if not get(self, 'spy'):
|
||||
return attr
|
||||
|
||||
transaction = name in get(self, 'TRANSACTIONS')
|
||||
call = name in get(self, 'CALLS')
|
||||
if not transaction or call:
|
||||
return attr
|
||||
|
||||
spy = self.__record_tx if transaction else self.__record_call
|
||||
|
||||
class Spy(attr):
|
||||
def __call__(self, *args, **kwargs):
|
||||
result = super().__call__(*args, **kwargs)
|
||||
params = args, kwargs
|
||||
spy(name, params)
|
||||
return result
|
||||
return Spy()
|
||||
@classmethod
|
||||
def setup_mock(cls, agent_attrs: dict = None):
|
||||
"""Bind mock agent attributes to the *subclass* with default values"""
|
||||
if not agent_attrs:
|
||||
agent_attrs = dict()
|
||||
for agent_method, mock_value in agent_attrs.items():
|
||||
setattr(cls, agent_method, mock_value)
|
||||
for call in cls.CALLS:
|
||||
setattr(cls, call, Mock(return_value=cls.DEFAULT_CALL))
|
||||
for tx in cls.TRANSACTIONS:
|
||||
setattr(cls, tx, Mock(return_value=cls.DEFAULT_TRANSACTION))
|
||||
|
||||
#
|
||||
# Utils
|
||||
#
|
||||
|
||||
@classmethod
|
||||
def reset(cls) -> None:
|
||||
cls._SPY_TRANSACTIONS.clear()
|
||||
cls._SPY_CALLS.clear()
|
||||
def reset(self):
|
||||
for name in (*self.CALLS, *self.TRANSACTIONS):
|
||||
mock = getattr(self, name)
|
||||
mock.call_args_list = _CallList()
|
||||
|
||||
def __get_call_list(self, name_list: Tuple[str]) -> defaultdict:
|
||||
result = defaultdict(list)
|
||||
for name in name_list:
|
||||
mock = getattr(self, name)
|
||||
calls = mock.call_args_list
|
||||
if calls:
|
||||
result[name].extend(calls)
|
||||
return result
|
||||
|
||||
@property
|
||||
def spy_transactions(self) -> defaultdict:
|
||||
result = self.__get_call_list(name_list=self.TRANSACTIONS)
|
||||
return result
|
||||
|
||||
@property
|
||||
def spy_contract_calls(self) -> defaultdict:
|
||||
result = self.__get_call_list(name_list=self.CALLS)
|
||||
return result
|
||||
|
||||
#
|
||||
# Assertions
|
||||
#
|
||||
|
||||
def assert_any_transaction(self) -> None:
|
||||
assert self._SPY_TRANSACTIONS, 'No transactions performed'
|
||||
assert self.spy_transactions, 'No transactions performed'
|
||||
|
||||
def assert_no_transactions(self) -> None:
|
||||
assert not self._SPY_TRANSACTIONS, 'Transactions performed'
|
||||
assert not self.spy_transactions, 'Transactions performed'
|
||||
|
||||
def assert_only_one_transaction_executed(self) -> None:
|
||||
fail = f"{len(self._SPY_TRANSACTIONS)} were performed ({', '.join(self._SPY_TRANSACTIONS)})."
|
||||
assert len(self._SPY_TRANSACTIONS) == 1, fail
|
||||
fail = f"{len(self.spy_transactions)} were performed ({', '.join(self.spy_transactions)})."
|
||||
assert len(self.spy_transactions) == 1, fail
|
||||
|
||||
def assert_transaction_not_called(self, name: str) -> None:
|
||||
assert name not in self._SPY_TRANSACTIONS, f'Unexpected transaction call "{name}".'
|
||||
assert name not in self.spy_transactions, f'Unexpected transaction call "{name}".'
|
||||
|
||||
def assert_transaction(self, name: str, call_count: int = 1, **kwargs) -> None:
|
||||
|
||||
# some transaction
|
||||
assert self._SPY_TRANSACTIONS, 'No transactions performed'
|
||||
assert name in self.TRANSACTIONS, f'"{name}" was not performed. Recorded txs: ({" ,".join(self._SPY_TRANSACTIONS)})'
|
||||
assert self.spy_transactions, 'No transactions performed'
|
||||
assert name in self.TRANSACTIONS, f'"{name}" was not performed. Recorded txs: ({" ,".join(self.spy_transactions)})'
|
||||
|
||||
# this transaction
|
||||
transaction_executions = self._SPY_TRANSACTIONS[name]
|
||||
transaction_executions = self.spy_transactions[name]
|
||||
fail = f'Transaction "{name}" was called an unexpected number of times; ' \
|
||||
f'Expected {call_count} got {len(transaction_executions)}.'
|
||||
assert len(transaction_executions) == call_count, fail
|
||||
|
@ -144,7 +136,7 @@ class MockContractAgent:
|
|||
|
||||
def assert_contract_calls(self, calls: Tuple[str]) -> None:
|
||||
for call_name in calls:
|
||||
assert call_name in self._SPY_CALLS, f'"{call_name}" was not called'
|
||||
assert call_name in self.spy_contract_calls, f'"{call_name}" was not called'
|
||||
|
||||
|
||||
class MockNucypherToken(MockContractAgent, NucypherTokenAgent):
|
||||
|
|
Loading…
Reference in New Issue