mirror of https://github.com/nucypher/nucypher.git
94 lines
3.1 KiB
Python
94 lines
3.1 KiB
Python
from collections import defaultdict
|
|
from contextlib import contextmanager
|
|
from typing import List
|
|
|
|
from ape.contracts import ContractInstance
|
|
from eth_utils import to_checksum_address
|
|
|
|
from nucypher.blockchain.eth.domains import TACoDomain
|
|
from nucypher.blockchain.eth.registry import (
|
|
RegistryData,
|
|
RegistrySource,
|
|
RegistrySourceManager,
|
|
)
|
|
from nucypher.config.constants import TEMPORARY_DOMAIN_NAME
|
|
from tests.constants import TEMPORARY_DOMAIN
|
|
|
|
|
|
@contextmanager
def mock_registry_sources(mocker, _domains: List[TACoDomain] = None):
    """Route registry lookups for the given domains to ``MockRegistrySource``.

    The following patches are applied through ``mocker``:

    * ``nucypher.blockchain.eth.domains.SUPPORTED_DOMAINS`` gains an entry
      ``str(domain) -> domain`` for every requested domain;
    * ``MockRegistrySource.ALLOWED_DOMAINS`` is replaced with the string
      names of the requested domains;
    * ``RegistrySourceManager._FALLBACK_CHAIN`` is replaced with a chain
      containing only ``MockRegistrySource``.

    :param mocker: the pytest-mock ``mocker`` fixture that owns the patches.
    :param _domains: domains to support; when ``None`` or empty, defaults to
        ``[TEMPORARY_DOMAIN]``.

    Nothing is undone after the ``yield``: the patches stay registered with
    ``mocker`` and are reverted whenever ``mocker`` reverts its patches
    (presumably at test teardown), not when the ``with`` block exits.
    """
    if not _domains:
        _domains = [TEMPORARY_DOMAIN]

    # patch.dict merges these entries into the existing mapping (clear=False by default).
    mocker.patch.dict(
        "nucypher.blockchain.eth.domains.SUPPORTED_DOMAINS",
        {str(domain): domain for domain in _domains},
    )
    mocker.patch.object(
        MockRegistrySource, "ALLOWED_DOMAINS", [str(domain) for domain in _domains]
    )
    # Make the mock source the only registry source the manager will consult.
    mocker.patch.object(
        RegistrySourceManager, "_FALLBACK_CHAIN", (MockRegistrySource,)
    )
    yield
|
|
|
|
|
|
class MockRegistrySource(RegistrySource):
    """Test registry source that always serves an empty registry.

    Instances may only be created for domains whose string name appears in
    ``ALLOWED_DOMAINS``; any other domain raises ``ValueError``.
    """

    # String names (``str(domain)``) of the domains this source may be built for.
    ALLOWED_DOMAINS = [TEMPORARY_DOMAIN_NAME]

    name = "Mock Registry Source"
    is_primary = False

    def __init__(self, *args, **kwargs):
        """Initialise through ``RegistrySource`` and reject disallowed domains.

        Assumes ``RegistrySource.__init__`` sets ``self.domain`` — TODO confirm.

        :raises ValueError: if ``str(self.domain)`` is not in ``ALLOWED_DOMAINS``.
        """
        super().__init__(*args, **kwargs)
        if str(self.domain) not in self.ALLOWED_DOMAINS:
            raise ValueError(
                f"Somehow, MockRegistrySource is trying to get a registry for '{self.domain}'. "
                f"Only '{','.join(self.ALLOWED_DOMAINS)}' are supported."
            )

    @property
    def registry_name(self) -> str:
        """Name of the registry: the string form of the domain."""
        return str(self.domain)

    def get_publication_endpoint(self) -> str:
        """Return a pseudo-endpoint identifying this mock registry (shown in ``get``'s debug log)."""
        return f":mock-registry-source:/{self.registry_name}"

    def get(self) -> RegistryData:
        """Log the pseudo-endpoint being "read" and return an empty registry."""
        self.logger.debug(f"Reading registry at {self.get_publication_endpoint()}")
        return {}
|
|
|
|
|
|
class ApeRegistrySource(RegistrySource):
    """Registry source built from contract instances supplied by an ape session.

    Only the temporary (test) domain is supported. ``set_deployments`` must be
    called before any instance is created.
    """

    name = "Ape Registry Source"
    is_primary = False

    # Contract instances registered via ``set_deployments``; stored on the class,
    # so every instance of this source sees the same deployments.
    _DEPLOYMENTS = None

    def __init__(self, *args, **kwargs):
        """Initialise through ``RegistrySource`` and validate preconditions.

        Assumes ``RegistrySource.__init__`` sets ``self.domain`` — TODO confirm.

        :raises ValueError: if the domain is not ``TEMPORARY_DOMAIN_NAME``, or
            if ``set_deployments`` has not been called yet.
        """
        super().__init__(*args, **kwargs)
        if str(self.domain) != TEMPORARY_DOMAIN_NAME:
            raise ValueError(
                f"Somehow, ApeRegistrySource is trying to get a registry for '{self.domain}'. "
                f"Only '{TEMPORARY_DOMAIN_NAME}' is supported."
            )
        if self._DEPLOYMENTS is None:
            raise ValueError(
                "ApeRegistrySource has not been initialized with deployments."
            )

    @classmethod
    def set_deployments(cls, deployments: List[ContractInstance]):
        """Register the contract instances that ``get`` will publish."""
        cls._DEPLOYMENTS = deployments

    def get_publication_endpoint(self) -> str:
        """Return the fixed pseudo-endpoint ``"ape"``."""
        return "ape"

    def get(self) -> RegistryData:
        """Build registry data from the registered deployments.

        :returns: a mapping ``chain_id -> contract_name -> entry``, where each
            entry holds the checksummed ``"address"`` and the ``"abi"`` as a
            list of dicts. If two deployments share a chain id and contract
            name, the later one wins.
        """
        data = defaultdict(dict)
        for contract_instance in self._DEPLOYMENTS:
            contract_type = contract_instance.contract_type
            chain_id = contract_instance.chain_manager.chain_id
            data[chain_id][contract_type.name] = {
                "address": to_checksum_address(contract_instance.address),
                "abi": [abi.dict() for abi in contract_type.abi],
            }
        # Return a plain dict: a defaultdict would silently create (and return)
        # an empty entry when a consumer looks up an unknown chain id.
        return dict(data)
|