nucypher/tests/mock/agents.py

198 lines
7.4 KiB
Python
Raw Normal View History

"""
This file is part of nucypher.
nucypher is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
nucypher is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
from enum import Enum
from constant_sorrow.constants import (CONTRACT_ATTRIBUTE, CONTRACT_CALL, TRANSACTION)
from hexbytes import HexBytes
from typing import Callable, Generator, Iterable, List, Type, Union
from unittest.mock import Mock
2020-05-19 19:30:41 +00:00
from nucypher.blockchain.eth import agents
2020-05-21 23:01:26 +00:00
from nucypher.blockchain.eth.agents import Agent, ContractAgency, EthereumContractAgent
from nucypher.blockchain.eth.constants import NULL_ADDRESS
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
from tests.constants import MOCK_PROVIDER_URI
from tests.utils.blockchain import free_gas_price_strategy
# Blockchain interface shared by every mock agent in this module; it is fetched
# (or created on first use) for the mock provider URI with the free gas strategy.
MOCK_TESTERCHAIN = BlockchainInterfaceFactory.get_or_create_interface(
    provider_uri=MOCK_PROVIDER_URI,
    gas_strategy=free_gas_price_strategy,
)

# Latest block at import time; its number is embedded in the fake receipts below.
CURRENT_BLOCK = MOCK_TESTERCHAIN.w3.eth.getBlock(block_identifier='latest')
class MockContractAgent:
    """
    Stand-in for an ``EthereumContractAgent`` subclass whose contract API is mocked.

    Every attribute of the wrapped agent class that carries the ``contract_api``
    marker (methods, or property getters) is replaced on the instance by a
    ``Mock`` whose return value depends on the interface type recorded in that
    marker: a fake call result for contract calls/attributes, a fake receipt
    for transactions.
    """

    FAKE_RECEIPT = {'transactionHash': HexBytes(b'FAKE29890FAKE8349804'),
                    'gasUsed': 1,
                    'blockNumber': CURRENT_BLOCK.number,
                    'blockHash': HexBytes(b'FAKE43434343FAKE43443434')}

    FAKE_CALL_RESULT = 1

    # Internal
    __COLLECTION_MARKER = "contract_api"  # decorator attribute
    __DEFAULTS = {
        CONTRACT_CALL: FAKE_CALL_RESULT,
        CONTRACT_ATTRIBUTE: FAKE_CALL_RESULT,
        TRANSACTION: FAKE_RECEIPT,
    }

    # Class-level placeholders; each instance rebinds these in __setup_mock.
    _MOCK_METHODS = list()
    _REAL_METHODS = list()

    # Mock Nucypher Contract API
    contract = Mock()
    contract_address = NULL_ADDRESS

    # Mock Blockchain Interfaces
    registry = Mock()
    blockchain = MOCK_TESTERCHAIN

    def __init__(self, agent_class: Type[EthereumContractAgent]):
        """Bind mock agent attributes to the *subclass* with default values"""
        self.agent_class = agent_class
        self.__setup_mock(agent_class=agent_class)

    def __repr__(self) -> str:
        r = f'Mock{self.agent_class.__name__}(id={id(self)})'
        return r

    def __setup_mock(self, agent_class: Type[Agent]) -> None:
        """
        Create one ``Mock`` per marked contract API member of ``agent_class``
        and bind it to this instance under the member's original name.
        """
        api_methods: Iterable[Callable] = list(self.__collect_contract_api(agent_class=agent_class))
        mock_methods = list()

        for agent_interface in api_methods:

            # Handle
            try:
                # TODO: #2022: This might be a method also decorated @property
                # Get the inner function of the property
                real_method: Callable = agent_interface.fget  # Handle properties
            except AttributeError:
                real_method = agent_interface

            # Get the interface type stored by the decorator and its default return
            interface = getattr(real_method, self.__COLLECTION_MARKER)
            default_return = self.__DEFAULTS.get(interface)

            # TODO: #2022 Special handling of PropertyMocks (for CONTRACT_ATTRIBUTE)?
            mock = Mock(return_value=default_return)

            # Mark the mock with the same interface type so it can be filtered later
            setattr(mock, self.__COLLECTION_MARKER, interface)
            mock_methods.append(mock)

            # Bind
            setattr(self, real_method.__name__, mock)

        self._MOCK_METHODS = mock_methods
        self._REAL_METHODS = api_methods

    def __get_interface_calls(self, interface: Enum) -> List[Callable]:
        """Return the bound mocks whose marker equals ``interface``."""
        # Read the marker by name so this stays in sync with __COLLECTION_MARKER.
        marker = self.__COLLECTION_MARKER
        interface_calls = [mock for mock in self._MOCK_METHODS if getattr(mock, marker) == interface]
        return interface_calls

    @classmethod
    def __is_contract_method(cls, agent_class: Type[Agent], method_name: str) -> bool:
        """Tell whether ``agent_class.method_name`` (or its property getter) carries the marker."""
        method_or_property = getattr(agent_class, method_name)
        try:
            real_method: Callable = method_or_property.fget  # Property (getter)
        except AttributeError:
            real_method: Callable = method_or_property  # Method
        contract_api: bool = hasattr(real_method, cls.__COLLECTION_MARKER)
        return contract_api

    @classmethod
    def __collect_contract_api(cls, agent_class: Type[Agent]) -> Generator[Callable, None, None]:
        """Lazily yield every attribute of ``agent_class`` that is marked as contract API."""
        agent_attrs = dir(agent_class)
        predicate = cls.__is_contract_method
        methods = (getattr(agent_class, name) for name in agent_attrs if predicate(agent_class, name))
        return methods

    #
    # Test Utilities
    #

    @property
    def all_transactions(self) -> List[Callable]:
        """All bound mocks marked as ``TRANSACTION``."""
        interface = TRANSACTION
        transaction_functions = self.__get_interface_calls(interface=interface)
        return transaction_functions

    @property
    def contract_calls(self) -> List[Callable]:
        """All bound mocks marked as ``CONTRACT_CALL``."""
        interface = CONTRACT_CALL
        transaction_functions = self.__get_interface_calls(interface=interface)
        return transaction_functions

    def get_unexpected_transactions(self, allowed: Union[Iterable[Callable], None]) -> List[Callable]:
        """
        Return the transaction mocks that were called but are not in ``allowed``.

        :param allowed: Transaction mocks that are permitted to have been called;
                        ``None`` or an empty iterable means no transaction is permitted.
        :return: Called transaction mocks that are absent from ``allowed``.
        """
        # Materialize once: a one-shot iterable (e.g. a generator) would otherwise
        # be exhausted by the first membership test and later checks would be wrong.
        allowed_mocks = tuple(allowed) if allowed else ()
        unexpected_transactions = [tx for tx in self.all_transactions
                                   if tx.called and tx not in allowed_mocks]
        return unexpected_transactions

    def assert_only_transactions(self, allowed: Iterable[Callable]) -> None:
        """Assert that no transaction mock outside ``allowed`` has been called."""
        unexpected_transactions = self.get_unexpected_transactions(allowed=allowed)
        assert not unexpected_transactions, f'Unexpected transactions were called: {unexpected_transactions}'

    def assert_no_transactions(self) -> None:
        """Assert that no transaction mock at all has been called."""
        unexpected_transactions = self.get_unexpected_transactions(allowed=None)
        assert not unexpected_transactions, f'Unexpected transactions were called: {unexpected_transactions}'

    def reset(self, clear_side_effects: bool = True, clear_return_values: bool = True) -> None:
        """
        Reset the call history of every bound mock.

        :param clear_side_effects: Also drop any configured ``side_effect``.
        :param clear_return_values: Also drop custom return values and restore
                                    the default for the mock's interface type.
        """
        for mock in self._MOCK_METHODS:
            mock.reset_mock(return_value=clear_return_values, side_effect=clear_side_effects)
            if clear_return_values:
                # reset_mock() discards the configured return value, so re-apply the default.
                interface = getattr(mock, self.__COLLECTION_MARKER)
                default_return = self.__DEFAULTS.get(interface)
                mock.return_value = default_return
2020-05-19 19:30:41 +00:00
class MockContractAgency(ContractAgency):
    """``ContractAgency`` variant that hands out cached ``MockContractAgent`` instances."""

    # One mock agent per agent class, created on first request.
    __agents = dict()

    @classmethod
    def get_agent(cls, agent_class: Type[Agent], *args, **kwargs) -> MockContractAgent:
        """
        Return the mock agent for ``agent_class``, building and caching it if needed.

        Extra positional and keyword arguments are accepted but not used.
        """
        if agent_class not in cls.__agents:
            cls.__agents[agent_class] = MockContractAgent(agent_class=agent_class)
        return cls.__agents[agent_class]

    @classmethod
    def get_agent_by_contract_name(cls, contract_name: str, *args, **kwargs) -> MockContractAgent:
        """
        Return the mock agent for the agent class that corresponds to ``contract_name``.

        The contract name is translated to an agent class name by the parent
        class, and that class is looked up in the ``agents`` module.
        """
        agent_class = getattr(agents, super()._contract_name_to_agent_name(name=contract_name))
        return cls.get_agent(agent_class=agent_class)

    @classmethod
    def reset(cls) -> None:
        """Reset every mock agent created so far."""
        for cached_agent in cls.__agents.values():
            cached_agent.reset()