diff --git a/nucypher/blockchain/eth/chains.py b/nucypher/blockchain/eth/chains.py index 38c9aef70..c1ff56f0c 100644 --- a/nucypher/blockchain/eth/chains.py +++ b/nucypher/blockchain/eth/chains.py @@ -45,13 +45,13 @@ class Blockchain: @classmethod def connect(cls, provider_uri: str = None, - registry_filepath: str = None, + registry: EthereumContractRegistry = None, deployer: bool = False, compile: bool = False, ) -> 'Blockchain': if cls._instance is NO_BLOCKCHAIN_AVAILABLE: - registry = EthereumContractRegistry(registry_filepath=registry_filepath) + registry = registry or EthereumContractRegistry() compiler = SolidityCompiler() if compile is True else None InterfaceClass = BlockchainDeployerInterface if deployer is True else BlockchainInterface interface = InterfaceClass(provider_uri=provider_uri, registry=registry, compiler=compiler) diff --git a/nucypher/blockchain/eth/registry.py b/nucypher/blockchain/eth/registry.py index 960a80aa4..851815b8b 100644 --- a/nucypher/blockchain/eth/registry.py +++ b/nucypher/blockchain/eth/registry.py @@ -19,8 +19,12 @@ class EthereumContractRegistry: WARNING: Unless you are developing NuCypher, you most likely won't ever need to use this. """ + + _multi_contract = True + _contract_name = NotImplemented + # TODO: Integrate with config classes - __default_registry_path = os.path.join(DEFAULT_CONFIG_ROOT, 'contract_registry.json') + _default_registry_filepath = os.path.join(DEFAULT_CONFIG_ROOT, 'contract_registry.json') class RegistryError(Exception): pass @@ -37,9 +41,9 @@ class EthereumContractRegistry: class IllegalRegistry(RegistryError): """Raised when invalid data is encountered in the registry""" - def __init__(self, registry_filepath: str = __default_registry_path) -> None: + def __init__(self, registry_filepath: str = None) -> None: self.log = getLogger("registry") - self.__filepath = registry_filepath + self.__filepath = registry_filepath or self._default_registry_filepath @property def filepath(self): @@ -60,7 +64,7 @@ class EthereumContractRegistry: registry_file.write(json.dumps(registry_data)) registry_file.truncate() - def read(self) -> list: + def read(self) -> Union[list, dict]: """ Reads the registry file and parses the JSON and returns a list. If the file is empty it will return an empty list. @@ -77,7 +81,7 @@ class EthereumContractRegistry: if file_data: registry_data = json.loads(file_data) else: - registry_data = list() + registry_data = list() if self._multi_contract else dict() except FileNotFoundError: raise self.NoRegistry("No registry at filepath: {}".format(self.__filepath)) @@ -109,7 +113,7 @@ class EthereumContractRegistry: def search(self, contract_name: str=None, contract_address: str=None): """ Searches the registry for a contract with the provided name or address - and returns the contracts. + and returns the contracts component data. """ if not (bool(contract_name) ^ bool(contract_address)): raise ValueError("Pass contract_name or contract_address, not both.") @@ -173,7 +177,7 @@ class InMemoryEthereumContractRegistry(EthereumContractRegistry): self.__registry_data = None # type: str def clear(self): - self.__registry_data = list() + self.__registry_data = None def _swap_registry(self, filepath: str) -> bool: raise NotImplementedError @@ -186,7 +190,95 @@ class InMemoryEthereumContractRegistry(EthereumContractRegistry): registry_data = json.loads(self.__registry_data) except TypeError: if self.__registry_data is None: - registry_data = list() + registry_data = list() if self._multi_contract else dict() + else: + raise + return registry_data + + +class AllocationRegistry(EthereumContractRegistry): + + _multi_contract = False + _contract_name = 'UserEscrow' + + _default_registry_filepath = os.path.join(DEFAULT_CONFIG_ROOT, 'allocation_registry.json') + + class NoAllocationRegistry(EthereumContractRegistry.NoRegistry): + pass + + class AllocationEnrollmentError(RuntimeError): + pass + + class UnknownBeneficiary(ValueError): + pass + + def search(self, beneficiary_address: str = None, contract_address: str=None): + if not (bool(beneficiary_address) ^ bool(contract_address)): + raise ValueError("Pass contract_owner or contract_address, not both.") + + try: + allocation_data = self.read() + except EthereumContractRegistry.NoRegistry: + raise self.NoAllocationRegistry + + if beneficiary_address: + try: + contract_data = allocation_data[beneficiary_address] + except KeyError: + raise self.UnknownBeneficiary + + elif contract_address: + records = list() + for beneficiary_address, contract_data in allocation_data.items(): + contract_address, contract_abi = contract_data['address'], contract_data['abi'] + records.append(dict(address=contract_address, abi=contract_abi)) + if len(records) > 1: + raise self.RegistryError("Multiple {} deployments for beneficiary {}".format(self._contract_name, beneficiary_address)) + else: + contract_data = records[0] + + else: + raise ValueError("Beneficiary address or contract address must be supplied.") + + return contract_data + + def enroll(self, beneficiary_address, contract_address, contract_abi) -> None: + contract_data = [contract_address, contract_abi] + try: + allocation_data = self.read() + except self.RegistryError: + self.log.info("Blank allocation registry encountered: enrolling {}:{}".format(beneficiary_address, contract_address)) + allocation_data = list() if self._multi_contract else dict() # empty registry + + if beneficiary_address in allocation_data: + raise self.AllocationEnrollmentError("There is an existing {} deployment for {}".format(self._contract_name, beneficiary_address)) + + allocation_data[beneficiary_address] = contract_data + self.write(allocation_data) + self.log.info("Enrolled {}:{} into allocation registry {}".format(beneficiary_address, contract_address, self.filepath)) + + +class InMemoryAllocationRegistry(AllocationRegistry): + + def __init__(self, *args, **kwargs) -> None: + super().__init__(registry_filepath=":memory:", *args, **kwargs) + self.__registry_data = None # type: str + + def clear(self): + self.__registry_data = dict() + + def _swap_registry(self, filepath: str) -> bool: + raise NotImplementedError + + def write(self, registry_data: list) -> None: + self.__registry_data = json.dumps(registry_data) + + def read(self) -> list: + try: + registry_data = json.loads(self.__registry_data) + except TypeError: + if self.__registry_data is None: + registry_data = dict() else: raise return registry_data