mirror of https://github.com/nucypher/nucypher.git
parent
12c2034ed5
commit
8cb8f16370
|
@ -22,14 +22,11 @@ from collections import deque
|
|||
from collections.abc import KeysView
|
||||
from typing import Optional, Dict, Iterable, List, Tuple, NamedTuple, Union, Any
|
||||
|
||||
import binascii
|
||||
import itertools
|
||||
import maya
|
||||
from eth_typing import ChecksumAddress
|
||||
|
||||
from nucypher_core import FleetStateChecksum, NodeMetadata
|
||||
|
||||
from ..crypto.utils import keccak_digest
|
||||
from nucypher.utilities.logging import Logger
|
||||
from .nicknames import Nickname
|
||||
|
||||
|
|
|
@ -23,4 +23,3 @@ from nucypher.characters.control.specifications.fields.label import *
|
|||
from nucypher.characters.control.specifications.fields.cleartext import *
|
||||
from nucypher.characters.control.specifications.fields.misc import *
|
||||
from nucypher.characters.control.specifications.fields.file import *
|
||||
from nucypher.characters.control.specifications.fields.signature import *
|
||||
|
|
|
@ -1,44 +0,0 @@
|
|||
"""
|
||||
This file is part of nucypher.
|
||||
|
||||
nucypher is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
nucypher is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from base64 import b64decode, b64encode
|
||||
|
||||
from marshmallow import fields
|
||||
|
||||
from nucypher.control.specifications.exceptions import InvalidInputData, InvalidNativeDataTypes
|
||||
from nucypher.control.specifications.fields.base import BaseField
|
||||
from nucypher.crypto.umbral_adapter import Signature
|
||||
|
||||
|
||||
class UmbralSignature(BaseField, fields.Field):
|
||||
|
||||
def _serialize(self, value: Signature, attr, obj, **kwargs):
|
||||
return b64encode(bytes(value)).decode()
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
try:
|
||||
return Signature.from_bytes(b64decode(value))
|
||||
except InvalidNativeDataTypes as e:
|
||||
raise InvalidInputData(f"Could not parse {self.name}: {e}")
|
||||
|
||||
def _validate(self, value):
|
||||
try:
|
||||
Signature.from_bytes(value)
|
||||
except InvalidNativeDataTypes as e:
|
||||
raise InvalidInputData(f"Could not parse {self.name}: {e}")
|
|
@ -21,7 +21,6 @@ from http import HTTPStatus
|
|||
import json
|
||||
import time
|
||||
from base64 import b64encode
|
||||
from datetime import datetime
|
||||
from json.decoder import JSONDecodeError
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
|
@ -39,11 +38,9 @@ from constant_sorrow.constants import (
|
|||
from cryptography.hazmat.primitives.serialization import Encoding
|
||||
from cryptography.x509 import Certificate, NameOID
|
||||
from eth_typing.evm import ChecksumAddress
|
||||
from eth_utils import to_canonical_address, to_checksum_address
|
||||
from flask import Response, request
|
||||
from twisted.internet import reactor, stdio
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.task import LoopingCall
|
||||
from twisted.logger import Logger
|
||||
from web3.types import TxReceipt
|
||||
|
||||
|
@ -88,7 +85,6 @@ from nucypher.crypto.umbral_adapter import (
|
|||
reencrypt,
|
||||
VerifiedKeyFrag,
|
||||
)
|
||||
from nucypher.datastore.datastore import DatastoreTransactionError, RecordNotFound
|
||||
from nucypher.network.exceptions import NodeSeemsToBeDown
|
||||
from nucypher.network.middleware import RestMiddleware
|
||||
from nucypher.network.nodes import NodeSprout, TEACHER_NODES, Teacher
|
||||
|
@ -360,7 +356,7 @@ class Alice(Character, BlockchainPolicyAuthor):
|
|||
return policy_pubkey
|
||||
|
||||
def revoke(self,
|
||||
policy: 'Policy',
|
||||
policy: Policy,
|
||||
onchain: bool = True, # forced to False for federated mode
|
||||
offchain: bool = True
|
||||
) -> Tuple[TxReceipt, Dict[ChecksumAddress, Tuple['Revocation', Exception]]]:
|
||||
|
@ -766,8 +762,7 @@ class Ursula(Teacher, Character, Worker):
|
|||
|
||||
self.rest_server = self._make_local_server(host=rest_host,
|
||||
port=rest_port,
|
||||
db_filepath=db_filepath,
|
||||
domain=domain)
|
||||
db_filepath=db_filepath)
|
||||
|
||||
# Self-signed TLS certificate of self for Teacher.__init__
|
||||
certificate_filepath = self._crypto_power.power_ups(TLSHostingPower).keypair.certificate_filepath
|
||||
|
@ -809,11 +804,10 @@ class Ursula(Teacher, Character, Worker):
|
|||
self._crypto_power.consume_power_up(tls_hosting_power) # Consume!
|
||||
return tls_hosting_power
|
||||
|
||||
def _make_local_server(self, host, port, domain, db_filepath) -> ProxyRESTServer:
|
||||
def _make_local_server(self, host, port, db_filepath) -> ProxyRESTServer:
|
||||
rest_app, datastore = make_rest_app(
|
||||
this_node=self,
|
||||
db_filepath=db_filepath,
|
||||
domain=domain,
|
||||
)
|
||||
rest_server = ProxyRESTServer(rest_host=host,
|
||||
rest_port=port,
|
||||
|
|
1025
nucypher/core.py
1025
nucypher/core.py
File diff suppressed because it is too large
Load Diff
|
@ -26,14 +26,10 @@ from nucypher_core.umbral import (
|
|||
Signature,
|
||||
Signer,
|
||||
Capsule,
|
||||
KeyFrag,
|
||||
VerifiedKeyFrag,
|
||||
CapsuleFrag,
|
||||
VerifiedCapsuleFrag,
|
||||
VerificationError,
|
||||
encrypt,
|
||||
decrypt_original,
|
||||
generate_kfrags,
|
||||
reencrypt,
|
||||
decrypt_reencrypted,
|
||||
)
|
||||
|
|
|
@ -20,7 +20,7 @@ import random
|
|||
from typing import Dict, Sequence, List
|
||||
|
||||
from eth_typing.evm import ChecksumAddress
|
||||
from eth_utils import to_checksum_address, to_canonical_address
|
||||
from eth_utils import to_checksum_address
|
||||
from twisted.logger import Logger
|
||||
|
||||
from nucypher_core import (
|
||||
|
|
|
@ -23,7 +23,7 @@ from pathlib import Path
|
|||
from typing import Tuple
|
||||
|
||||
from constant_sorrow import constants
|
||||
from constant_sorrow.constants import RELAX, NOT_STAKING
|
||||
from constant_sorrow.constants import RELAX
|
||||
from flask import Flask, Response, jsonify, request
|
||||
from mako import exceptions as mako_exceptions
|
||||
from mako.template import Template
|
||||
|
@ -31,13 +31,11 @@ from mako.template import Template
|
|||
from nucypher_core import (
|
||||
ReencryptionRequest,
|
||||
RevocationOrder,
|
||||
NodeMetadata,
|
||||
MetadataRequest,
|
||||
MetadataResponse,
|
||||
MetadataResponsePayload,
|
||||
)
|
||||
|
||||
from nucypher.blockchain.eth.utils import period_to_epoch
|
||||
from nucypher.config.constants import MAX_UPLOAD_CONTENT_LENGTH
|
||||
from nucypher.crypto.keypairs import DecryptingKeypair
|
||||
from nucypher.crypto.signing import InvalidSignature
|
||||
|
@ -82,7 +80,6 @@ class ProxyRESTServer:
|
|||
def make_rest_app(
|
||||
db_filepath: Path,
|
||||
this_node,
|
||||
domain,
|
||||
log: Logger = Logger("http-application-layer")
|
||||
) -> Tuple[Flask, Datastore]:
|
||||
"""
|
||||
|
@ -99,12 +96,12 @@ def make_rest_app(
|
|||
|
||||
log.info("Starting datastore {}".format(db_filepath))
|
||||
datastore = Datastore(db_filepath)
|
||||
rest_app = _make_rest_app(weakref.proxy(datastore), weakref.proxy(this_node), domain, log)
|
||||
rest_app = _make_rest_app(weakref.proxy(datastore), weakref.proxy(this_node), log)
|
||||
|
||||
return rest_app, datastore
|
||||
|
||||
|
||||
def _make_rest_app(datastore: Datastore, this_node, domain: str, log: Logger) -> Flask:
|
||||
def _make_rest_app(datastore: Datastore, this_node, log: Logger) -> Flask:
|
||||
|
||||
# TODO: Avoid circular imports :-(
|
||||
from nucypher.characters.lawful import Alice, Bob, Ursula
|
||||
|
|
|
@ -1,179 +0,0 @@
|
|||
"""
|
||||
This file is part of nucypher.
|
||||
|
||||
nucypher is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
nucypher is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
|
||||
from abc import abstractmethod, ABC
|
||||
import re
|
||||
from typing import Dict, Tuple, Callable
|
||||
|
||||
|
||||
class Versioned(ABC):
|
||||
"""Base class for serializable entities"""
|
||||
|
||||
_VERSION_PARTS = 2
|
||||
_VERSION_PART_SIZE = 2 # bytes
|
||||
_BRAND_SIZE = 4
|
||||
_VERSION_SIZE = _VERSION_PART_SIZE * _VERSION_PARTS
|
||||
_HEADER_SIZE = _BRAND_SIZE + _VERSION_SIZE
|
||||
|
||||
class InvalidHeader(ValueError):
|
||||
"""Raised when an unexpected or invalid bytes header is encountered."""
|
||||
|
||||
class IncompatibleVersion(ValueError):
|
||||
"""Raised when attempting to deserialize incompatible bytes"""
|
||||
|
||||
class Empty(ValueError):
|
||||
"""Raised when 0 bytes are remaining after parsing the header."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _brand(cls) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
"""tuple(major, minor)"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def version_string(cls) -> str:
|
||||
major, minor = cls._version()
|
||||
return f'{major}.{minor}'
|
||||
|
||||
#
|
||||
# Serialize
|
||||
#
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self._header() + self._payload()
|
||||
|
||||
@classmethod
|
||||
def _header(cls) -> bytes:
|
||||
"""The entire bytes header to prepend to the instance payload."""
|
||||
major, minor = cls._version()
|
||||
major_bytes = major.to_bytes(cls._VERSION_PART_SIZE, 'big')
|
||||
minor_bytes = minor.to_bytes(cls._VERSION_PART_SIZE, 'big')
|
||||
header = cls._brand() + major_bytes + minor_bytes
|
||||
return header
|
||||
|
||||
@abstractmethod
|
||||
def _payload(self) -> bytes:
|
||||
"""The unbranded and unversioned bytes-serialized representation of this instance."""
|
||||
raise NotImplementedError
|
||||
|
||||
#
|
||||
# Deserialize
|
||||
#
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
"""The current deserializer"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _old_version_handlers(cls) -> Dict[Tuple[int, int], Callable]:
|
||||
"""Old deserializer callables keyed by version."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def take(cls, data: bytes):
|
||||
"""
|
||||
Deserializes the object from the given bytestring
|
||||
and returns the object and the remainder of the bytestring.
|
||||
"""
|
||||
brand, version, payload = cls._parse_header(data)
|
||||
version = cls._resolve_version(version=version)
|
||||
handlers = cls._deserializers()
|
||||
obj, remainder = handlers[version](payload)
|
||||
return obj, remainder
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
""""Public deserialization API"""
|
||||
obj, remainder = cls.take(data)
|
||||
if remainder:
|
||||
raise ValueError(f"{len(remainder)} bytes remaining after deserializing {cls}")
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def _resolve_version(cls, version: Tuple[int, int]) -> Tuple[int, int]:
|
||||
|
||||
# Unpack version metadata
|
||||
bytrestring_major, bytrestring_minor = version
|
||||
latest_major_version, latest_minor_version = cls._version()
|
||||
|
||||
# Enforce major version compatibility
|
||||
if not bytrestring_major == latest_major_version:
|
||||
message = f'Incompatible versioned bytes for {cls.__name__}. ' \
|
||||
f'Compatible version is {latest_major_version}.x, ' \
|
||||
f'Got {bytrestring_major}.{bytrestring_minor}.'
|
||||
raise cls.IncompatibleVersion(message)
|
||||
|
||||
# Enforce minor version compatibility.
|
||||
# Pass future minor versions to the latest minor handler.
|
||||
if bytrestring_minor >= latest_minor_version:
|
||||
version = cls._version()
|
||||
|
||||
return version
|
||||
|
||||
@classmethod
|
||||
def _parse_header(cls, data: bytes) -> Tuple[bytes, Tuple[int, int], bytes]:
|
||||
if len(data) < cls._HEADER_SIZE:
|
||||
# handles edge case when input is too short.
|
||||
raise ValueError(f'Invalid bytes for {cls.__name__}.')
|
||||
brand = cls._parse_brand(data)
|
||||
version = cls._parse_version(data)
|
||||
payload = cls._parse_payload(data)
|
||||
return brand, version, payload
|
||||
|
||||
@classmethod
|
||||
def _parse_brand(cls, data: bytes) -> bytes:
|
||||
brand = data[:cls._BRAND_SIZE]
|
||||
if brand != cls._brand():
|
||||
error = f"Incorrect brand. Expected {cls._brand()}, Got {brand}."
|
||||
if not re.fullmatch(rb'\w+', brand):
|
||||
# unversioned entities for older versions will most likely land here.
|
||||
error = f"Incompatible bytes for {cls.__name__}."
|
||||
raise cls.InvalidHeader(error)
|
||||
return brand
|
||||
|
||||
@classmethod
|
||||
def _parse_version(cls, data: bytes) -> Tuple[int, int]:
|
||||
version_data = data[cls._BRAND_SIZE:cls._HEADER_SIZE]
|
||||
major, minor = version_data[:cls._VERSION_PART_SIZE], version_data[cls._VERSION_PART_SIZE:]
|
||||
major, minor = int.from_bytes(major, 'big'), int.from_bytes(minor, 'big')
|
||||
version = major, minor
|
||||
return version
|
||||
|
||||
@classmethod
|
||||
def _parse_payload(cls, data: bytes) -> bytes:
|
||||
payload = data[cls._HEADER_SIZE:]
|
||||
if len(payload) == 0:
|
||||
raise ValueError(f'No content to deserialize {cls.__name__}.')
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
def _deserializers(cls) -> Dict[Tuple[int, int], Callable]:
|
||||
"""Return a dict of all known deserialization handlers for this class keyed by version"""
|
||||
return {cls._version(): cls._from_bytes_current, **cls._old_version_handlers()}
|
||||
|
||||
|
||||
# Collects the brands of every serializable entity, potentially useful for documentation.
|
||||
# SERIALIZABLE_ENTITIES = {v.__class__.__name__: v._brand() for v in Versioned.__subclasses__()}
|
|
@ -17,9 +17,13 @@ along with nucypher. If not, see <https://www.gnu.org/licenses/>.
|
|||
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from eth_tester.exceptions import TransactionFailed
|
||||
from nucypher.crypto.umbral_adapter import Signer, SecretKey, generate_kfrags, encrypt, reencrypt
|
||||
import pytest
|
||||
|
||||
from nucypher_core import MessageKit
|
||||
|
||||
from nucypher.crypto.umbral_adapter import Signer, SecretKey, generate_kfrags, reencrypt
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
@ -44,7 +48,7 @@ def fragments():
|
|||
sign_delegating_key=False,
|
||||
sign_receiving_key=False)
|
||||
|
||||
capsule, _ciphertext = encrypt(delegating_pubkey, b'unused')
|
||||
capsule = MessageKit(delegating_pubkey, b'unused').capsule
|
||||
cfrag = reencrypt(capsule, kfrags[0])
|
||||
return capsule, cfrag
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ from flask import Response
|
|||
|
||||
from nucypher.characters.lawful import Ursula
|
||||
from nucypher.crypto.signing import SignatureStamp
|
||||
from nucypher.crypto.umbral_adapter import SecretKey, Signer, PublicKey, encrypt
|
||||
from nucypher.crypto.umbral_adapter import SecretKey, Signer, PublicKey
|
||||
from nucypher.datastore.base import RecordField
|
||||
from nucypher.network.nodes import Teacher
|
||||
from tests.markers import skip_on_circleci
|
||||
|
@ -107,11 +107,6 @@ def test_alice_verifies_ursula_just_in_time(fleet_of_highperf_mocked_ursulas,
|
|||
highperf_mocked_alice,
|
||||
highperf_mocked_bob):
|
||||
|
||||
def mock_encrypt(public_key, plaintext):
|
||||
if not isinstance(public_key, PublicKey):
|
||||
public_key = public_key.i_want_to_be_a_real_boy()
|
||||
return encrypt(public_key, plaintext)
|
||||
|
||||
mocks = (
|
||||
mock_pubkey_from_bytes(),
|
||||
mock_secret_source(),
|
||||
|
|
|
@ -173,8 +173,7 @@ class NotARestApp:
|
|||
def actual_rest_app(self):
|
||||
if self._actual_rest_app is None:
|
||||
self._actual_rest_app, self._datastore = make_rest_app(db_filepath=self.db_filepath,
|
||||
this_node=self.this_node,
|
||||
domain=None)
|
||||
this_node=self.this_node)
|
||||
_new_view_functions = self._ViewFunctions(self._actual_rest_app.view_functions)
|
||||
self._actual_rest_app.view_functions = _new_view_functions
|
||||
self._actual_rest_apps.append(
|
||||
|
|
|
@ -30,7 +30,6 @@ from nucypher.characters.control.specifications.fields import (
|
|||
FileField,
|
||||
Key,
|
||||
MessageKit,
|
||||
UmbralSignature,
|
||||
EncryptedTreasureMap
|
||||
)
|
||||
from nucypher.characters.lawful import Enrico
|
||||
|
@ -135,30 +134,6 @@ def test_message_kit(enacted_federated_policy, federated_alice):
|
|||
field._deserialize(value=b"MessageKit", attr=None, data=None)
|
||||
|
||||
|
||||
def test_umbral_signature():
|
||||
umbral_priv_key = SecretKey.random()
|
||||
signer = Signer(umbral_priv_key)
|
||||
|
||||
message = b'this is a message'
|
||||
signature = signer.sign(message)
|
||||
other_signature = signer.sign(b'this is a different message')
|
||||
|
||||
field = UmbralSignature()
|
||||
serialized = field._serialize(value=signature, attr=None, obj=None)
|
||||
assert serialized == b64encode(bytes(signature)).decode()
|
||||
assert serialized != b64encode(bytes(other_signature)).decode()
|
||||
|
||||
deserialized = field._deserialize(value=serialized, attr=None, data=None)
|
||||
assert deserialized == signature
|
||||
assert deserialized != other_signature
|
||||
|
||||
field._validate(value=bytes(signature))
|
||||
field._validate(value=bytes(other_signature))
|
||||
|
||||
with pytest.raises(InvalidInputData):
|
||||
field._validate(value=b"UmbralSignature")
|
||||
|
||||
|
||||
def test_treasure_map(enacted_federated_policy):
|
||||
treasure_map = enacted_federated_policy.treasure_map
|
||||
|
||||
|
|
|
@ -21,11 +21,11 @@ from eth_utils import to_canonical_address
|
|||
|
||||
import pytest
|
||||
|
||||
from nucypher_core import RetrievalKit as RetrievalKitClass
|
||||
from nucypher_core import RetrievalKit as RetrievalKitClass, MessageKit
|
||||
|
||||
from nucypher.control.specifications.exceptions import InvalidInputData
|
||||
from nucypher.control.specifications.fields import StringList
|
||||
from nucypher.crypto.umbral_adapter import SecretKey, encrypt
|
||||
from nucypher.crypto.umbral_adapter import SecretKey
|
||||
from nucypher.utilities.porter.control.specifications.fields import UrsulaChecksumAddress
|
||||
from nucypher.utilities.porter.control.specifications.fields.retrieve import RetrievalKit
|
||||
|
||||
|
@ -104,13 +104,13 @@ def test_retrieval_kit_field(get_random_checksum_address):
|
|||
|
||||
# kit with list of ursulas
|
||||
encrypting_key = SecretKey.random().public_key()
|
||||
capsule, _ = encrypt(encrypting_key, b'testing retrieval kit with 2 ursulas')
|
||||
capsule = MessageKit(encrypting_key, b'testing retrieval kit with 2 ursulas').capsule
|
||||
ursulas = [get_random_checksum_address(), get_random_checksum_address()]
|
||||
run_tests_on_kit(kit=RetrievalKitClass(capsule, {to_canonical_address(ursula) for ursula in ursulas}))
|
||||
|
||||
# kit with no ursulas
|
||||
encrypting_key = SecretKey.random().public_key()
|
||||
capsule, _ = encrypt(encrypting_key, b'testing retrieval kit with no ursulas')
|
||||
capsule = MessageKit(encrypting_key, b'testing retrieval kit with no ursulas').capsule
|
||||
run_tests_on_kit(kit=RetrievalKitClass(capsule, set()))
|
||||
|
||||
with pytest.raises(InvalidInputData):
|
||||
|
|
|
@ -1,236 +0,0 @@
|
|||
"""
|
||||
This file is part of nucypher.
|
||||
|
||||
nucypher is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
nucypher is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
|
||||
import re
|
||||
from typing import Tuple, Any, Type
|
||||
|
||||
import pytest
|
||||
|
||||
from nucypher.utilities.versioning import Versioned
|
||||
|
||||
|
||||
def _check_valid_version_tuple(version: Any, cls: Type):
|
||||
if not isinstance(version, tuple):
|
||||
pytest.fail(f"Old version handlers keys for {cls.__name__} must be a tuple")
|
||||
if not len(version) == Versioned._VERSION_PARTS:
|
||||
pytest.fail(f"Old version handlers keys for {cls.__name__} must be a {str(Versioned._VERSION_PARTS)}-tuple")
|
||||
if not all(isinstance(part, int) for part in version):
|
||||
pytest.fail(f"Old version handlers version parts {cls.__name__} must be integers")
|
||||
|
||||
|
||||
class A(Versioned):
|
||||
|
||||
def __init__(self, x: int):
|
||||
self.x = x
|
||||
|
||||
@classmethod
|
||||
def _brand(cls):
|
||||
return b"ABCD"
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
return 2, 1
|
||||
|
||||
def _payload(self) -> bytes:
|
||||
return self.x.to_bytes(1, 'big')
|
||||
|
||||
@classmethod
|
||||
def _old_version_handlers(cls):
|
||||
return {
|
||||
(2, 0): cls._from_bytes_v2_0,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_v2_0(cls, data):
|
||||
# v2.0 saved a 4 byte integer in hex format
|
||||
int_hex, remainder = data[:2], data[2:]
|
||||
int_bytes = bytes.fromhex(int_hex.decode())
|
||||
return cls(int.from_bytes(int_bytes, 'big')), remainder
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
# v2.1 saves a 4 byte integer as 4 bytes
|
||||
int_bytes, remainder = data[:1], data[1:]
|
||||
return cls(int.from_bytes(int_bytes, 'big')), remainder
|
||||
|
||||
|
||||
def test_unique_branding():
|
||||
brands = tuple(v._brand() for v in Versioned.__subclasses__())
|
||||
brands_set = set(brands)
|
||||
if len(brands) != len(brands_set):
|
||||
duplicate_brands = list(brands)
|
||||
for brand in brands_set:
|
||||
duplicate_brands.remove(brand)
|
||||
pytest.fail(f"Duplicated brand(s) {duplicate_brands}.")
|
||||
|
||||
|
||||
def test_valid_branding():
|
||||
for cls in Versioned.__subclasses__():
|
||||
if len(cls._brand()) != cls._BRAND_SIZE:
|
||||
pytest.fail(f"Brand must be exactly {str(Versioned._BRAND_SIZE)} bytes.")
|
||||
if not re.fullmatch(rb'\w+', cls._brand()):
|
||||
pytest.fail(f"Brand must be alphanumeric; Got {cls._brand()}")
|
||||
|
||||
def test_valid_version_implementation():
|
||||
for cls in Versioned.__subclasses__():
|
||||
_check_valid_version_tuple(version=cls._version(), cls=cls)
|
||||
|
||||
|
||||
def test_valid_old_handlers_index():
|
||||
for cls in Versioned.__subclasses__():
|
||||
for version in cls._deserializers():
|
||||
_check_valid_version_tuple(version=version, cls=cls)
|
||||
|
||||
|
||||
def test_version_metadata():
|
||||
major, minor = A._version()
|
||||
assert A.version_string() == f'{major}.{minor}'
|
||||
|
||||
|
||||
def test_versioning_header_prepend():
|
||||
a = A(1) # stake sauce
|
||||
assert a.x == 1
|
||||
|
||||
serialized = bytes(a)
|
||||
assert len(serialized) > Versioned._HEADER_SIZE
|
||||
|
||||
header = serialized[:Versioned._HEADER_SIZE]
|
||||
brand = header[:Versioned._BRAND_SIZE]
|
||||
assert brand == A._brand()
|
||||
|
||||
version = header[Versioned._BRAND_SIZE:]
|
||||
major, minor = version[:Versioned._VERSION_PART_SIZE], version[Versioned._VERSION_PART_SIZE:]
|
||||
major_number = int.from_bytes(major, 'big')
|
||||
minor_number = int.from_bytes(minor, 'big')
|
||||
assert (major_number, minor_number) == A._version()
|
||||
|
||||
|
||||
def test_versioning_input_too_short():
|
||||
empty = b'ABCD\x00\x01'
|
||||
with pytest.raises(ValueError, match='Invalid bytes for A.'):
|
||||
A.from_bytes(empty)
|
||||
|
||||
|
||||
def test_versioning_empty_payload():
|
||||
empty = b'ABCD\x00\x02\x00\x01'
|
||||
with pytest.raises(ValueError, match='No content to deserialize A.'):
|
||||
A.from_bytes(empty)
|
||||
|
||||
|
||||
def test_versioning_invalid_brand():
|
||||
invalid = b'\x01\x02\x00\x03\x00\x0112'
|
||||
with pytest.raises(Versioned.InvalidHeader, match="Incompatible bytes for A."):
|
||||
A.from_bytes(invalid)
|
||||
|
||||
# A partially invalid brand, to check that the regexp validates
|
||||
# the whole brand and not just the beginning of it.
|
||||
invalid = b'ABC \x00\x02\x00\x0112'
|
||||
with pytest.raises(Versioned.InvalidHeader, match="Incompatible bytes for A."):
|
||||
A.from_bytes(invalid)
|
||||
|
||||
|
||||
def test_versioning_incorrect_brand():
|
||||
incorrect = b'ABAB\x00\x0112'
|
||||
with pytest.raises(Versioned.InvalidHeader, match="Incorrect brand. Expected b'ABCD', Got b'ABAB'."):
|
||||
A.from_bytes(incorrect)
|
||||
|
||||
|
||||
def test_unknown_future_major_version():
|
||||
empty = b'ABCD\x00\x03\x00\x0212'
|
||||
message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 3.2.'
|
||||
with pytest.raises(ValueError, match=message):
|
||||
A.from_bytes(empty)
|
||||
|
||||
|
||||
def test_incompatible_old_major_version(mocker):
|
||||
current_spy = mocker.spy(A, "_from_bytes_current")
|
||||
v1_data = b'ABCD\x00\x01\x00\x0012'
|
||||
message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 1.0.'
|
||||
with pytest.raises(Versioned.IncompatibleVersion, match=message):
|
||||
A.from_bytes(v1_data)
|
||||
assert not current_spy.call_count
|
||||
|
||||
|
||||
def test_incompatible_future_major_version(mocker):
|
||||
current_spy = mocker.spy(A, "_from_bytes_current")
|
||||
v1_data = b'ABCD\x00\x03\x00\x0012'
|
||||
message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 3.0.'
|
||||
with pytest.raises(Versioned.IncompatibleVersion, match=message):
|
||||
A.from_bytes(v1_data)
|
||||
assert not current_spy.call_count
|
||||
|
||||
|
||||
def test_resolve_version():
|
||||
# past
|
||||
v2_0 = 2, 0
|
||||
resolved_version = A._resolve_version(version=v2_0)
|
||||
assert resolved_version == v2_0
|
||||
|
||||
# present
|
||||
v2_1 = 2, 1
|
||||
resolved_version = A._resolve_version(version=v2_1)
|
||||
assert resolved_version == v2_1
|
||||
|
||||
# future minor version resolves to the latest minor version.
|
||||
v2_2 = 2, 2
|
||||
resolved_version = A._resolve_version(version=v2_2)
|
||||
assert resolved_version == v2_1
|
||||
|
||||
|
||||
def test_old_minor_version_handler_routing(mocker):
|
||||
current_spy = mocker.spy(A, "_from_bytes_current")
|
||||
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
|
||||
|
||||
# Old minor version
|
||||
v2_0_data = b'ABCD\x00\x02\x00\x0012'
|
||||
a = A.from_bytes(v2_0_data)
|
||||
assert a.x == 18
|
||||
|
||||
# Old minor version was correctly routed to the v2.0 handler.
|
||||
assert v2_0_spy.call_count == 1
|
||||
v2_0_spy.assert_called_with(b'12')
|
||||
assert not current_spy.call_count
|
||||
|
||||
|
||||
def test_current_minor_version_handler_routing(mocker):
|
||||
current_spy = mocker.spy(A, "_from_bytes_current")
|
||||
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
|
||||
|
||||
v2_1_data = b'ABCD\x00\x02\x00\x01\x12'
|
||||
a = A.from_bytes(v2_1_data)
|
||||
assert a.x == 18
|
||||
|
||||
# Current version was correctly routed to the v2.1 handler.
|
||||
assert current_spy.call_count == 1
|
||||
current_spy.assert_called_with(b'\x12')
|
||||
assert not v2_0_spy.call_count
|
||||
|
||||
|
||||
def test_future_minor_version_handler_routing(mocker):
|
||||
current_spy = mocker.spy(A, "_from_bytes_current")
|
||||
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
|
||||
|
||||
v2_2_data = b'ABCD\x00\x02\x02\x01\x12'
|
||||
a = A.from_bytes(v2_2_data)
|
||||
assert a.x == 18
|
||||
|
||||
# Future minor version was correctly routed to
|
||||
# the current minor version handler.
|
||||
assert current_spy.call_count == 1
|
||||
current_spy.assert_called_with(b'\x12')
|
||||
assert not v2_0_spy.call_count
|
|
@ -26,7 +26,7 @@ from nucypher.blockchain.eth.interfaces import BlockchainInterface
|
|||
from nucypher.characters.lawful import Bob
|
||||
from nucypher.characters.lawful import Ursula
|
||||
from nucypher.config.characters import UrsulaConfiguration
|
||||
from nucypher.crypto.umbral_adapter import SecretKey, Signer, encrypt, generate_kfrags, reencrypt
|
||||
from nucypher.crypto.umbral_adapter import SecretKey, Signer, generate_kfrags
|
||||
from tests.constants import NUMBER_OF_URSULAS_IN_DEVELOPMENT_NETWORK
|
||||
from tests.mock.datastore import MOCK_DB
|
||||
|
||||
|
|
Loading…
Reference in New Issue