Detect paths during deserialization

pull/2692/head
Piotr Roslaniec 2021-08-04 11:46:35 +02:00
parent 09e90d17fa
commit 695cc10950
5 changed files with 42 additions and 37 deletions

View File

@ -22,7 +22,7 @@ from abc import ABC, abstractmethod
from decimal import Decimal
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Callable, List, Optional, Union, get_type_hints
from typing import Callable, List, Optional, Union
from constant_sorrow.constants import (
DEVELOPMENT_CONFIGURATION,
@ -49,6 +49,7 @@ from nucypher.config.storages import (
LocalFileBasedNodeStorage,
NodeStorage
)
from nucypher.config.util import cast_paths_from
from nucypher.crypto.keystore import Keystore
from nucypher.crypto.powers import CryptoPower, CryptoPowerUp
from nucypher.crypto.umbral_adapter import Signature
@ -302,10 +303,7 @@ class BaseConfiguration(ABC):
raise cls.OldVersion(f"Configuration {label} is the wrong version "
f"Expected version {cls.VERSION}; Got version {version}")
if 'keyring_root' in deserialized_payload:
deserialized_payload['keyring_root'] = Path(deserialized_payload['keyring_root'])
if 'node_storage' in deserialized_payload:
deserialized_payload['node_storage'] = LocalFileBasedNodeStorage.format_payload(deserialized_payload['node_storage'])
deserialized_payload = cast_paths_from(cls, deserialized_payload)
return deserialized_payload
def update(self, filepath: Path = None, **updates) -> None:
@ -642,15 +640,7 @@ class CharacterConfiguration(BaseConfiguration):
# Assemble
payload.update(dict(node_storage=node_storage, max_gas_price=max_gas_price))
constructor_args = get_type_hints(cls.__init__)
constructor_args.update(get_type_hints(CharacterConfiguration.__init__))
paths_only = [
arg for (arg, type_) in constructor_args.items()
if type_ == Path or type_ == Optional[Path]
]
for key in paths_only:
if key in payload:
payload[key] = Path(payload[key]) if payload[key] else None
payload = cast_paths_from(cls, payload)
# Filter out None values from **overrides to detect, well, overrides...
# Acts as a shim for optional CLI flags.

View File

@ -19,7 +19,7 @@ import binascii
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Optional, Set, Union, get_type_hints
from typing import Any, Set, Union
import OpenSSL
from bytestring_splitter import BytestringSplittingError
@ -31,6 +31,7 @@ from cryptography.x509 import Certificate
from nucypher.blockchain.eth.decorators import validate_checksum_address
from nucypher.blockchain.eth.registry import BaseContractRegistry
from nucypher.config.constants import DEFAULT_CONFIG_ROOT
from nucypher.config.util import cast_paths_from
from nucypher.crypto.signing import SignatureStamp
from nucypher.utilities.logging import Logger
@ -188,7 +189,7 @@ class ForgetfulNodeStorage(NodeStorage):
certificate_only: bool = False):
if not bool(stamp) ^ bool(host):
message = "Either pass checksum_address or host; Not both. Got ({} {})".format(checksum_address, host)
message = "Either pass stamp or host; Not both. Got ({} {})".format(stamp, host)
raise ValueError(message)
if certificate_only is True:
@ -441,31 +442,14 @@ class LocalFileBasedNodeStorage(NodeStorage):
}
return payload
@classmethod
def format_payload(cls, payload: dict) -> dict:
output = {}
for key, value in payload.items():
if key in ['root_dir', 'metadata_dir', 'certificates_dir']:
output[key] = Path(payload[key])
else:
output[key] = payload[key]
return output
@classmethod
def from_payload(cls, payload: dict, *args, **kwargs) -> 'LocalFileBasedNodeStorage':
payload = cls.format_payload(payload)
storage_type = payload[cls._TYPE_LABEL]
if not storage_type == cls._name:
raise cls.NodeStorageError("Wrong storage type. got {}".format(storage_type))
del payload['storage_type']
paths_only = [
arg for (arg, type_) in get_type_hints(cls.__init__).items()
if type_ == Path or type_ == Optional[Path]
]
for key in paths_only:
if key in payload:
payload[key] = Path(payload[key])
payload = cast_paths_from(cls, payload)
return cls(*args, **payload, **kwargs)

33
nucypher/config/util.py Normal file
View File

@ -0,0 +1,33 @@
"""
This file is part of nucypher.
nucypher is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
nucypher is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
from pathlib import Path
from typing import Optional, get_type_hints
def cast_paths_from(cls, payload):
constructor_args = get_type_hints(cls.__init__)
for ancestor in cls.__mro__:
constructor_args.update(get_type_hints(ancestor.__init__))
paths_only = [
arg for (arg, type_) in constructor_args.items()
if type_ == Path or type_ == Optional[Path]
]
for key in paths_only:
if key in payload:
payload[key] = Path(payload[key]) if payload[key] else None
return payload

View File

@ -194,6 +194,7 @@ def test_configuration_preservation():
# Ensure file contents are JSON deserializable
deserialized_file_contents = json.loads(contents)
del deserialized_file_contents['version'] # do not test version of config serialization here.
deserialized_file_contents['config_root'] = Path(deserialized_file_contents['config_root'])
deserialized_payload = RestorableTestItem.deserialize(payload=contents)
assert deserialized_payload == deserialized_file_contents

View File

@ -14,15 +14,12 @@
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
import tempfile
from typing import List
from nucypher.blockchain.eth.registry import BaseContractRegistry
from nucypher.characters.lawful import Ursula
from nucypher.config.characters import AliceConfiguration, BobConfiguration, UrsulaConfiguration
from nucypher.config.constants import TEMPORARY_DOMAIN
from nucypher.crypto.keystore import Keystore
from tests.constants import INSECURE_DEVELOPMENT_PASSWORD
from tests.utils.middleware import MockRestMiddleware
from tests.utils.ursula import MOCK_URSULA_STARTING_PORT