mirror of https://github.com/nucypher/nucypher.git
Merge pull request #2767 from KPrasch/versioning
Versioning of bytes serializable protocol entitiespull/2804/head
commit
a9b2a8d412
|
@ -0,0 +1 @@
|
|||
Uniform versioning of bytes serializable protocol entities.
|
|
@ -17,7 +17,7 @@ along with nucypher. If not, see <https://www.gnu.org/licenses/>.
|
|||
|
||||
from collections import defaultdict
|
||||
import random
|
||||
from typing import Dict, Sequence, List
|
||||
from typing import Dict, Sequence, List, Tuple
|
||||
|
||||
from bytestring_splitter import BytestringSplitter, VariableLengthBytestring
|
||||
from eth_typing.evm import ChecksumAddress
|
||||
|
@ -43,6 +43,7 @@ from nucypher.network.nodes import Learner
|
|||
from nucypher.policy.hrac import HRAC, hrac_splitter
|
||||
from nucypher.policy.kits import MessageKit, RetrievalKit, RetrievalResult
|
||||
from nucypher.policy.maps import TreasureMap
|
||||
from nucypher.utilities.versioning import Versioned
|
||||
|
||||
|
||||
class RetrievalPlan:
|
||||
|
@ -138,16 +139,11 @@ class RetrievalWorkOrder:
|
|||
self.capsules = capsules
|
||||
|
||||
|
||||
class ReencryptionRequest:
|
||||
class ReencryptionRequest(Versioned):
|
||||
"""
|
||||
A request for an Ursula to reencrypt for several capsules.
|
||||
"""
|
||||
|
||||
_splitter = (hrac_splitter +
|
||||
key_splitter +
|
||||
key_splitter +
|
||||
BytestringSplitter((MessageKit, VariableLengthBytestring)))
|
||||
|
||||
@classmethod
|
||||
def from_work_order(cls,
|
||||
work_order: RetrievalWorkOrder,
|
||||
|
@ -159,8 +155,7 @@ class ReencryptionRequest:
|
|||
alice_verifying_key=alice_verifying_key,
|
||||
bob_verifying_key=bob_verifying_key,
|
||||
encrypted_kfrag=treasure_map.destinations[work_order.ursula_address],
|
||||
capsules=work_order.capsules,
|
||||
)
|
||||
capsules=work_order.capsules)
|
||||
|
||||
def __init__(self,
|
||||
hrac: HRAC,
|
||||
|
@ -175,20 +170,6 @@ class ReencryptionRequest:
|
|||
self.encrypted_kfrag = encrypted_kfrag
|
||||
self.capsules = capsules
|
||||
|
||||
def __bytes__(self):
|
||||
return (bytes(self.hrac) +
|
||||
bytes(self._alice_verifying_key) +
|
||||
bytes(self._bob_verifying_key) +
|
||||
VariableLengthBytestring(bytes(self.encrypted_kfrag)) +
|
||||
b''.join(bytes(capsule) for capsule in self.capsules)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
hrac, alice_vk, bob_vk, ekfrag, remainder = cls._splitter(data, return_remainder=True)
|
||||
capsules = capsule_splitter.repeat(remainder)
|
||||
return cls(hrac, alice_vk, bob_vk, ekfrag, capsules)
|
||||
|
||||
def alice(self) -> 'Alice':
|
||||
from nucypher.characters.lawful import Alice
|
||||
return Alice.from_public_keys(verifying_key=self._alice_verifying_key)
|
||||
|
@ -201,8 +182,39 @@ class ReencryptionRequest:
|
|||
from nucypher.characters.lawful import Alice
|
||||
return Alice.from_public_keys(verifying_key=self.encrypted_kfrag.sender_verifying_key)
|
||||
|
||||
def _payload(self) -> bytes:
|
||||
return (bytes(self.hrac) +
|
||||
bytes(self._alice_verifying_key) +
|
||||
bytes(self._bob_verifying_key) +
|
||||
VariableLengthBytestring(bytes(self.encrypted_kfrag)) +
|
||||
b''.join(bytes(capsule) for capsule in self.capsules)
|
||||
)
|
||||
|
||||
class ReencryptionResponse:
|
||||
@classmethod
|
||||
def _brand(cls) -> bytes:
|
||||
return b'RQ'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
return 1, 0
|
||||
|
||||
@classmethod
|
||||
def _old_version_handlers(cls) -> Dict:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
splitter = (hrac_splitter +
|
||||
key_splitter +
|
||||
key_splitter +
|
||||
BytestringSplitter((MessageKit, VariableLengthBytestring)))
|
||||
|
||||
hrac, alice_vk, bob_vk, ekfrag, remainder = splitter(data, return_remainder=True)
|
||||
capsules = capsule_splitter.repeat(remainder)
|
||||
return cls(hrac, alice_vk, bob_vk, ekfrag, capsules)
|
||||
|
||||
|
||||
class ReencryptionResponse(Versioned):
|
||||
"""
|
||||
A response from Ursula with reencrypted capsule frags.
|
||||
"""
|
||||
|
@ -226,21 +238,34 @@ class ReencryptionResponse:
|
|||
self.cfrags = cfrags
|
||||
self.signature = signature
|
||||
|
||||
def _payload(self) -> bytes:
|
||||
"""Returns the unversioned bytes serialized representation of this instance."""
|
||||
return bytes(self.signature) + b''.join(bytes(cfrag) for cfrag in self.cfrags)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
def _brand(cls) -> bytes:
|
||||
return b'RR'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
return 1, 0
|
||||
|
||||
@classmethod
|
||||
def _old_version_handlers(cls) -> Dict:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
signature, cfrags_bytes = signature_splitter(data, return_remainder=True)
|
||||
|
||||
# We would never send a request with no capsules, so there should be cfrags.
|
||||
# The splitter would fail anyway, this just makes the error message more clear.
|
||||
if not cfrags_bytes:
|
||||
raise ValueError("ReencryptionResponse contains no cfrags")
|
||||
raise ValueError(f"{cls.__name__} contains no cfrags")
|
||||
|
||||
cfrags = cfrag_splitter.repeat(cfrags_bytes)
|
||||
return cls(cfrags, signature)
|
||||
|
||||
def __bytes__(self):
|
||||
return bytes(self.signature) + b''.join(bytes(cfrag) for cfrag in self.cfrags)
|
||||
|
||||
|
||||
class RetrievalClient:
|
||||
"""
|
||||
|
|
|
@ -40,7 +40,7 @@ from nucypher.network.protocols import InterfaceInfo
|
|||
from nucypher.network.retrieval import ReencryptionRequest, ReencryptionResponse
|
||||
from nucypher.policy.hrac import HRAC
|
||||
from nucypher.policy.kits import MessageKit
|
||||
from nucypher.policy.revocation import Revocation
|
||||
from nucypher.policy.revocation import RevocationOrder
|
||||
from nucypher.utilities.logging import Logger
|
||||
|
||||
HERE = BASE_DIR = Path(__file__).parent
|
||||
|
@ -272,7 +272,7 @@ def _make_rest_app(datastore: Datastore, this_node, domain: str, log: Logger) ->
|
|||
|
||||
@rest_app.route('/revoke', methods=['POST'])
|
||||
def revoke():
|
||||
revocation = Revocation.from_bytes(request.data)
|
||||
revocation = RevocationOrder.from_bytes(request.data)
|
||||
# TODO: Implement offchain revocation.
|
||||
return Response(status=200)
|
||||
|
||||
|
|
|
@ -16,30 +16,30 @@ along with nucypher. If not, see <https://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
|
||||
from typing import Dict, Optional, Iterable, Set
|
||||
from typing import Dict, Optional, Iterable, Set, Tuple
|
||||
|
||||
from bytestring_splitter import BytestringSplitter, VariableLengthBytestring
|
||||
from constant_sorrow.constants import (
|
||||
NOT_SIGNED,
|
||||
DO_NOT_SIGN,
|
||||
SIGNATURE_TO_FOLLOW,
|
||||
SIGNATURE_IS_ON_CIPHERTEXT,
|
||||
NOT_SIGNED,
|
||||
)
|
||||
)
|
||||
from eth_typing import ChecksumAddress
|
||||
from eth_utils import to_checksum_address, to_canonical_address
|
||||
|
||||
import nucypher.crypto.umbral_adapter as umbral # need it to mock `umbral.encrypt`
|
||||
from nucypher.crypto.splitters import (
|
||||
capsule_splitter,
|
||||
key_splitter,
|
||||
signature_splitter,
|
||||
checksum_address_splitter,
|
||||
)
|
||||
import nucypher.crypto.umbral_adapter as umbral # need it to mock `umbral.encrypt`
|
||||
)
|
||||
from nucypher.crypto.umbral_adapter import PublicKey, VerifiedCapsuleFrag, Capsule, Signature
|
||||
from nucypher.utilities.versioning import Versioned
|
||||
|
||||
|
||||
class MessageKit:
|
||||
class MessageKit(Versioned):
|
||||
"""
|
||||
All the components needed to transmit and verify an encrypted message.
|
||||
"""
|
||||
|
@ -106,7 +106,13 @@ class MessageKit:
|
|||
def __str__(self):
|
||||
return f"{self.__class__.__name__}({self.capsule})"
|
||||
|
||||
def __bytes__(self):
|
||||
def as_policy_kit(self, policy_key: PublicKey, threshold: int) -> 'PolicyMessageKit':
|
||||
return PolicyMessageKit.from_message_kit(self, policy_key, threshold)
|
||||
|
||||
def as_retrieval_kit(self) -> 'RetrievalKit':
|
||||
return RetrievalKit(self.capsule, set())
|
||||
|
||||
def _payload(self) -> bytes:
|
||||
# TODO (#2743): this logic may not be necessary depending on the resolution.
|
||||
# If it is, it is better moved to BytestringSplitter.
|
||||
return (bytes(self.capsule) +
|
||||
|
@ -115,7 +121,19 @@ class MessageKit:
|
|||
VariableLengthBytestring(self.ciphertext))
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
def _brand(cls) -> bytes:
|
||||
return b'MK'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
return 1, 0
|
||||
|
||||
@classmethod
|
||||
def _old_version_handlers(cls) -> Dict:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
splitter = BytestringSplitter(
|
||||
capsule_splitter,
|
||||
(bytes, 1))
|
||||
|
@ -129,8 +147,7 @@ class MessageKit:
|
|||
else:
|
||||
raise ValueError("Incorrect format for the signature flag")
|
||||
|
||||
splitter = BytestringSplitter(
|
||||
(bytes, 1))
|
||||
splitter = BytestringSplitter((bytes, 1))
|
||||
|
||||
key_flag, remainder = splitter(remainder, return_remainder=True)
|
||||
|
||||
|
@ -145,14 +162,8 @@ class MessageKit:
|
|||
|
||||
return cls(capsule, ciphertext, signature=signature, sender_verifying_key=sender_verifying_key)
|
||||
|
||||
def as_policy_kit(self, policy_key: PublicKey, threshold: int) -> 'PolicyMessageKit':
|
||||
return PolicyMessageKit.from_message_kit(self, policy_key, threshold)
|
||||
|
||||
def as_retrieval_kit(self) -> 'RetrievalKit':
|
||||
return RetrievalKit(self.capsule, set())
|
||||
|
||||
|
||||
class RetrievalKit:
|
||||
class RetrievalKit(Versioned):
|
||||
"""
|
||||
An object encapsulating the information necessary for retrieval of cfrags from Ursulas.
|
||||
Contains the capsule and the checksum addresses of Ursulas from which the requester
|
||||
|
@ -164,12 +175,24 @@ class RetrievalKit:
|
|||
# Can store cfrags too, if we're worried about Ursulas supplying duplicate ones.
|
||||
self.queried_addresses = set(queried_addresses)
|
||||
|
||||
def __bytes__(self):
|
||||
def _payload(self) -> bytes:
|
||||
return (bytes(self.capsule) +
|
||||
b''.join(to_canonical_address(address) for address in self.queried_addresses))
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
def _brand(cls) -> bytes:
|
||||
return b'RK'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
return 1, 0
|
||||
|
||||
@classmethod
|
||||
def _old_version_handlers(cls) -> Dict:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
capsule, remainder = capsule_splitter(data, return_remainder=True)
|
||||
if remainder:
|
||||
addresses_as_bytes = checksum_address_splitter.repeat(remainder)
|
||||
|
|
|
@ -16,7 +16,7 @@ along with nucypher. If not, see <https://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
|
||||
from typing import Optional, Callable, Sequence, Dict
|
||||
from typing import Optional, Callable, Sequence, Dict, Tuple
|
||||
|
||||
from bytestring_splitter import (
|
||||
BytestringSplitter,
|
||||
|
@ -36,24 +36,40 @@ from nucypher.crypto.utils import keccak_digest, verify_eip_191
|
|||
from nucypher.network.middleware import RestMiddleware
|
||||
from nucypher.policy.hrac import HRAC, hrac_splitter
|
||||
from nucypher.policy.kits import MessageKit
|
||||
from nucypher.utilities.versioning import Versioned
|
||||
|
||||
|
||||
class TreasureMap:
|
||||
class TreasureMap(Versioned):
|
||||
|
||||
class NowhereToBeFound(RestMiddleware.NotFound):
|
||||
"""
|
||||
Called when no known nodes have it.
|
||||
"""
|
||||
|
||||
main_splitter = BytestringSplitter(
|
||||
(int, 1, {'byteorder': 'big'}),
|
||||
hrac_splitter,
|
||||
)
|
||||
def __init__(self,
|
||||
threshold: int,
|
||||
hrac: HRAC,
|
||||
destinations: Dict[ChecksumAddress, MessageKit]):
|
||||
self.threshold = threshold
|
||||
self.destinations = destinations
|
||||
self.hrac = hrac
|
||||
|
||||
ursula_and_kfrag_payload_splitter = BytestringSplitter(
|
||||
(to_checksum_address, ETH_ADDRESS_BYTE_LENGTH),
|
||||
(MessageKit, VariableLengthBytestring),
|
||||
)
|
||||
# A little awkward, but saves us a key length in serialization
|
||||
self.publisher_verifying_key = list(destinations.values())[0].sender_verifying_key
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.destinations.items())
|
||||
|
||||
def __len__(self):
|
||||
return len(self.destinations)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, TreasureMap):
|
||||
return False
|
||||
|
||||
return (self.threshold == other.threshold and
|
||||
self.hrac == other.hrac and
|
||||
self.destinations == other.destinations)
|
||||
|
||||
@classmethod
|
||||
def construct_by_publisher(cls,
|
||||
|
@ -87,17 +103,42 @@ class TreasureMap:
|
|||
|
||||
return cls(threshold=threshold, hrac=hrac, destinations=destinations)
|
||||
|
||||
def __init__(self,
|
||||
threshold: int,
|
||||
hrac: HRAC,
|
||||
destinations: Dict[ChecksumAddress, MessageKit],
|
||||
):
|
||||
self.threshold = threshold
|
||||
self.destinations = destinations
|
||||
self.hrac = hrac
|
||||
@classmethod
|
||||
def _brand(cls) -> bytes:
|
||||
return b'TM'
|
||||
|
||||
# A little awkward, but saves us a key length in serialization
|
||||
self.publisher_verifying_key = list(destinations.values())[0].sender_verifying_key
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
return 1, 0
|
||||
|
||||
@classmethod
|
||||
def _old_version_handlers(cls) -> Dict:
|
||||
return {}
|
||||
|
||||
def _payload(self) -> bytes:
|
||||
"""Returns the unversioned bytes serialized representation of this instance."""
|
||||
return self.threshold.to_bytes(1, "big") + bytes(self.hrac) + self._nodes_as_bytes()
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
|
||||
main_splitter = BytestringSplitter(
|
||||
(int, 1, {'byteorder': 'big'}),
|
||||
hrac_splitter,
|
||||
)
|
||||
|
||||
ursula_and_kfrag_payload_splitter = BytestringSplitter(
|
||||
(to_checksum_address, ETH_ADDRESS_BYTE_LENGTH),
|
||||
(MessageKit, VariableLengthBytestring),
|
||||
)
|
||||
|
||||
try:
|
||||
threshold, hrac, remainder = main_splitter(data, return_remainder=True)
|
||||
ursula_and_kfrags = ursula_and_kfrag_payload_splitter.repeat(remainder)
|
||||
except BytestringSplittingError as e:
|
||||
raise ValueError('Invalid treasure map contents.') from e
|
||||
destinations = {u: k for u, k in ursula_and_kfrags}
|
||||
return cls(threshold, hrac, destinations)
|
||||
|
||||
def encrypt(self,
|
||||
publisher: 'Alice',
|
||||
|
@ -117,52 +158,22 @@ class TreasureMap:
|
|||
nodes_as_bytes += (node_id + kfrag)
|
||||
return nodes_as_bytes
|
||||
|
||||
def __bytes__(self):
|
||||
return self.threshold.to_bytes(1, "big") + bytes(self.hrac) + self._nodes_as_bytes()
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
try:
|
||||
threshold, hrac, remainder = cls.main_splitter(data, return_remainder=True)
|
||||
ursula_and_kfrags = cls.ursula_and_kfrag_payload_splitter.repeat(remainder)
|
||||
except BytestringSplittingError as e:
|
||||
raise ValueError('Invalid treasure map contents.') from e
|
||||
destinations = {u: k for u, k in ursula_and_kfrags}
|
||||
return cls(threshold, hrac, destinations)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, TreasureMap):
|
||||
return False
|
||||
|
||||
return (self.threshold == other.threshold and
|
||||
self.hrac == other.hrac and
|
||||
self.destinations == other.destinations)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.destinations.items())
|
||||
|
||||
def __len__(self):
|
||||
return len(self.destinations)
|
||||
|
||||
|
||||
class AuthorizedKeyFrag:
|
||||
class AuthorizedKeyFrag(Versioned):
|
||||
|
||||
_WRIT_CHECKSUM_SIZE = 32
|
||||
|
||||
# The size of a serialized message kit encrypting an AuthorizedKeyFrag.
|
||||
# Depends on encryption parameters in Umbral, has to be hardcoded.
|
||||
ENCRYPTED_SIZE = 621
|
||||
SERIALIZED_SIZE = Versioned._HEADER_SIZE + ENCRYPTED_SIZE
|
||||
|
||||
_splitter = BytestringSplitter(
|
||||
hrac_splitter, # HRAC
|
||||
(bytes, _WRIT_CHECKSUM_SIZE), # kfrag checksum
|
||||
signature_splitter, # Publisher's signature
|
||||
kfrag_splitter,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _kfrag_checksum(kfrag: KeyFrag) -> bytes:
|
||||
return keccak_digest(bytes(kfrag))[:AuthorizedKeyFrag._WRIT_CHECKSUM_SIZE]
|
||||
def __init__(self, hrac: HRAC, kfrag_checksum: bytes, writ_signature: Signature, kfrag: KeyFrag):
|
||||
self.hrac = hrac
|
||||
self.kfrag_checksum = kfrag_checksum
|
||||
self.writ = bytes(hrac) + kfrag_checksum
|
||||
self.writ_signature = writ_signature
|
||||
self.kfrag = kfrag
|
||||
|
||||
@classmethod
|
||||
def construct_by_publisher(cls,
|
||||
|
@ -184,20 +195,38 @@ class AuthorizedKeyFrag:
|
|||
# the material needed for Ursula to assuredly service this policy.
|
||||
return cls(hrac, kfrag_checksum, writ_signature, kfrag)
|
||||
|
||||
def __init__(self, hrac: HRAC, kfrag_checksum: bytes, writ_signature: Signature, kfrag: KeyFrag):
|
||||
self.hrac = hrac
|
||||
self.kfrag_checksum = kfrag_checksum
|
||||
self.writ = bytes(hrac) + kfrag_checksum
|
||||
self.writ_signature = writ_signature
|
||||
self.kfrag = kfrag
|
||||
@staticmethod
|
||||
def _kfrag_checksum(kfrag: KeyFrag) -> bytes:
|
||||
return keccak_digest(bytes(kfrag))[:AuthorizedKeyFrag._WRIT_CHECKSUM_SIZE]
|
||||
|
||||
def __bytes__(self):
|
||||
def _payload(self) -> bytes:
|
||||
"""Returns the unversioned bytes serialized representation of this instance."""
|
||||
return self.writ + bytes(self.writ_signature) + bytes(self.kfrag)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
def _brand(cls) -> bytes:
|
||||
return b'KF'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
return 1, 0
|
||||
|
||||
@classmethod
|
||||
def _old_version_handlers(cls) -> Dict:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
# TODO: should we check the signature right away here?
|
||||
hrac, kfrag_checksum, writ_signature, kfrag = cls._splitter(data)
|
||||
|
||||
splitter = BytestringSplitter(
|
||||
hrac_splitter, # HRAC
|
||||
(bytes, cls._WRIT_CHECKSUM_SIZE), # kfrag checksum
|
||||
signature_splitter, # Publisher's signature
|
||||
kfrag_splitter,
|
||||
)
|
||||
|
||||
hrac, kfrag_checksum, writ_signature, kfrag = splitter(data)
|
||||
|
||||
# Check integrity
|
||||
calculated_checksum = cls._kfrag_checksum(kfrag)
|
||||
|
@ -207,19 +236,24 @@ class AuthorizedKeyFrag:
|
|||
return cls(hrac, kfrag_checksum, writ_signature, kfrag)
|
||||
|
||||
|
||||
class EncryptedTreasureMap:
|
||||
|
||||
_splitter = BytestringSplitter(
|
||||
signature_splitter, # public signature
|
||||
hrac_splitter, # HRAC
|
||||
(MessageKit, VariableLengthBytestring), # encrypted TreasureMap
|
||||
(bytes, EIP712_MESSAGE_SIGNATURE_SIZE)) # blockchain signature
|
||||
class EncryptedTreasureMap(Versioned):
|
||||
|
||||
_EMPTY_BLOCKCHAIN_SIGNATURE = b'\x00' * EIP712_MESSAGE_SIGNATURE_SIZE
|
||||
|
||||
# TODO: do we really need this alias?
|
||||
from nucypher.crypto.signing import \
|
||||
InvalidSignature # Raised when the public signature (typically intended for Ursula) is not valid.
|
||||
# Raised when the public signature (typically intended for Ursula) is not valid.
|
||||
from nucypher.crypto.signing import InvalidSignature
|
||||
|
||||
def __init__(self,
|
||||
hrac: HRAC,
|
||||
public_signature: Signature,
|
||||
encrypted_tmap: MessageKit,
|
||||
blockchain_signature: Optional[bytes] = None):
|
||||
|
||||
self.hrac = hrac
|
||||
self._public_signature = public_signature
|
||||
self.publisher_verifying_key = encrypted_tmap.sender_verifying_key
|
||||
self._encrypted_tmap = encrypted_tmap
|
||||
self._blockchain_signature = blockchain_signature
|
||||
|
||||
@staticmethod
|
||||
def _sign(blockchain_signer: Callable[[bytes], bytes],
|
||||
|
@ -258,19 +292,6 @@ class EncryptedTreasureMap:
|
|||
|
||||
return cls(treasure_map.hrac, public_signature, encrypted_tmap, blockchain_signature=blockchain_signature)
|
||||
|
||||
def __init__(self,
|
||||
hrac: HRAC,
|
||||
public_signature: Signature,
|
||||
encrypted_tmap: MessageKit,
|
||||
blockchain_signature: Optional[bytes] = None,
|
||||
):
|
||||
|
||||
self.hrac = hrac
|
||||
self._public_signature = public_signature
|
||||
self.publisher_verifying_key = encrypted_tmap.sender_verifying_key
|
||||
self._encrypted_tmap = encrypted_tmap
|
||||
self._blockchain_signature = blockchain_signature
|
||||
|
||||
def decrypt(self, decryptor: Callable[[bytes], bytes]) -> TreasureMap:
|
||||
"""
|
||||
When Bob receives the TreasureMap, he'll pass a decryptor (a callable which can verify and decrypt the
|
||||
|
@ -284,17 +305,6 @@ class EncryptedTreasureMap:
|
|||
|
||||
return TreasureMap.from_bytes(map_in_the_clear)
|
||||
|
||||
def __bytes__(self):
|
||||
if self._blockchain_signature:
|
||||
signature = self._blockchain_signature
|
||||
else:
|
||||
signature = self._EMPTY_BLOCKCHAIN_SIGNATURE
|
||||
return (bytes(self._public_signature) +
|
||||
bytes(self.hrac) +
|
||||
bytes(VariableLengthBytestring(bytes(self._encrypted_tmap))) +
|
||||
signature
|
||||
)
|
||||
|
||||
def verify_blockchain_signature(self, checksum_address: ChecksumAddress) -> bool:
|
||||
if self._blockchain_signature is None:
|
||||
raise ValueError("This EncryptedTreasureMap is not blockchain-signed")
|
||||
|
@ -308,10 +318,39 @@ class EncryptedTreasureMap:
|
|||
if not self._public_signature.verify(self.publisher_verifying_key, message=message):
|
||||
raise self.InvalidSignature("This TreasureMap is not properly publicly signed by the publisher.")
|
||||
|
||||
def _payload(self) -> bytes:
|
||||
if self._blockchain_signature:
|
||||
signature = self._blockchain_signature
|
||||
else:
|
||||
signature = self._EMPTY_BLOCKCHAIN_SIGNATURE
|
||||
return (bytes(self._public_signature) +
|
||||
bytes(self.hrac) +
|
||||
bytes(VariableLengthBytestring(bytes(self._encrypted_tmap))) +
|
||||
signature)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
def _brand(cls) -> bytes:
|
||||
return b'EM'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
return 1, 0
|
||||
|
||||
@classmethod
|
||||
def _old_version_handlers(cls) -> Dict:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
|
||||
splitter = BytestringSplitter(
|
||||
signature_splitter, # public signature
|
||||
hrac_splitter, # HRAC
|
||||
(MessageKit, VariableLengthBytestring), # encrypted TreasureMap
|
||||
(bytes, EIP712_MESSAGE_SIGNATURE_SIZE)) # blockchain signature
|
||||
|
||||
try:
|
||||
public_signature, hrac, message_kit, blockchain_signature = cls._splitter(data)
|
||||
public_signature, hrac, message_kit, blockchain_signature = splitter(data)
|
||||
if blockchain_signature == cls._EMPTY_BLOCKCHAIN_SIGNATURE:
|
||||
blockchain_signature = None
|
||||
except BytestringSplittingError as e:
|
||||
|
|
|
@ -38,24 +38,18 @@ from nucypher.policy.reservoir import (
|
|||
from nucypher.policy.revocation import RevocationKit
|
||||
from nucypher.utilities.concurrency import WorkerPool
|
||||
from nucypher.utilities.logging import Logger
|
||||
from nucypher.utilities.versioning import Versioned
|
||||
|
||||
|
||||
class Arrangement:
|
||||
"""
|
||||
A contract between Alice and a single Ursula.
|
||||
"""
|
||||
|
||||
splitter = BytestringSplitter(
|
||||
key_splitter, # publisher_verifying_key
|
||||
(bytes, VariableLengthBytestring) # expiration
|
||||
)
|
||||
class Arrangement(Versioned):
|
||||
"""A contract between Alice and a single Ursula."""
|
||||
|
||||
def __init__(self, publisher_verifying_key: PublicKey, expiration: maya.MayaDT):
|
||||
self.expiration = expiration
|
||||
self.publisher_verifying_key = publisher_verifying_key
|
||||
|
||||
def __bytes__(self):
|
||||
return bytes(self.publisher_verifying_key) + bytes(VariableLengthBytestring(self.expiration.iso8601().encode()))
|
||||
def __repr__(self):
|
||||
return f"Arrangement(publisher={self.publisher_verifying_key})"
|
||||
|
||||
@classmethod
|
||||
def from_publisher(cls, publisher: 'Alice', expiration: maya.MayaDT) -> 'Arrangement':
|
||||
|
@ -63,14 +57,31 @@ class Arrangement:
|
|||
return cls(publisher_verifying_key=publisher_verifying_key, expiration=expiration)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, arrangement_as_bytes: bytes) -> 'Arrangement':
|
||||
publisher_verifying_key, expiration_bytes = cls.splitter(arrangement_as_bytes)
|
||||
def _brand(cls) -> bytes:
|
||||
return b'AR'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
return 1, 0
|
||||
|
||||
def _payload(self) -> bytes:
|
||||
"""Returns the unversioned bytes serialized representation of this instance."""
|
||||
return bytes(self.publisher_verifying_key) + bytes(VariableLengthBytestring(self.expiration.iso8601().encode()))
|
||||
|
||||
@classmethod
|
||||
def _old_version_handlers(cls) -> Dict:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data: bytes):
|
||||
splitter = BytestringSplitter(
|
||||
key_splitter, # publisher_verifying_key
|
||||
(bytes, VariableLengthBytestring) # expiration
|
||||
)
|
||||
publisher_verifying_key, expiration_bytes = splitter(data)
|
||||
expiration = maya.MayaDT.from_iso8601(iso8601_string=expiration_bytes.decode())
|
||||
return cls(publisher_verifying_key=publisher_verifying_key, expiration=expiration)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Arrangement(publisher={self.publisher_verifying_key})"
|
||||
|
||||
|
||||
class Policy(ABC):
|
||||
"""
|
||||
|
|
|
@ -16,7 +16,7 @@ along with nucypher. If not, see <https://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict, Tuple
|
||||
|
||||
from bytestring_splitter import BytestringSplitter
|
||||
from eth_typing.evm import ChecksumAddress
|
||||
|
@ -27,9 +27,10 @@ from nucypher.crypto.splitters import signature_splitter, checksum_address_split
|
|||
from nucypher.crypto.umbral_adapter import Signature, PublicKey
|
||||
from nucypher.policy.kits import MessageKit
|
||||
from nucypher.policy.maps import AuthorizedKeyFrag
|
||||
from nucypher.utilities.versioning import Versioned
|
||||
|
||||
|
||||
class Revocation:
|
||||
class RevocationOrder(Versioned):
|
||||
"""
|
||||
Represents a string used by characters to perform a revocation on a specific
|
||||
Ursula. It's a bytestring made of the following format:
|
||||
|
@ -37,16 +38,9 @@ class Revocation:
|
|||
This is sent as a payload in a DELETE method to the /KFrag/ endpoint.
|
||||
"""
|
||||
|
||||
PREFIX = b'REVOKE-'
|
||||
revocation_splitter = BytestringSplitter(
|
||||
(bytes, len(PREFIX)),
|
||||
checksum_address_splitter, # ursula canonical address
|
||||
(bytes, AuthorizedKeyFrag.ENCRYPTED_SIZE), # encrypted kfrag payload (includes writ)
|
||||
signature_splitter
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
ursula_checksum_address: ChecksumAddress, # TODO: Use staker address instead (what if the staker rebonds)?
|
||||
ursula_checksum_address: ChecksumAddress,
|
||||
# TODO: Use staker address instead (what if the staker rebonds)?
|
||||
encrypted_kfrag: MessageKit,
|
||||
signer: Optional[SignatureStamp] = None,
|
||||
signature: Optional[Signature] = None):
|
||||
|
@ -57,13 +51,10 @@ class Revocation:
|
|||
if not (bool(signer) ^ bool(signature)):
|
||||
raise ValueError("Either pass a signer or a signature; not both.")
|
||||
elif signer:
|
||||
self.signature = signer(self.payload)
|
||||
self.signature = signer(self._body())
|
||||
elif signature:
|
||||
self.signature = signature
|
||||
|
||||
def __bytes__(self):
|
||||
return self.payload + bytes(self.signature)
|
||||
|
||||
def __repr__(self):
|
||||
return bytes(self)
|
||||
|
||||
|
@ -73,37 +64,55 @@ class Revocation:
|
|||
def __eq__(self, other):
|
||||
return bytes(self) == bytes(other)
|
||||
|
||||
@property
|
||||
def payload(self):
|
||||
return self.PREFIX \
|
||||
+ to_canonical_address(self.ursula_checksum_address) \
|
||||
+ bytes(self.encrypted_kfrag) \
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, revocation_bytes):
|
||||
prefix, ursula_canonical_address, ekfrag, signature = cls.revocation_splitter(revocation_bytes)
|
||||
ursula_checksum_address = to_checksum_address(ursula_canonical_address)
|
||||
return cls(ursula_checksum_address=ursula_checksum_address,
|
||||
encrypted_kfrag=ekfrag,
|
||||
signature=signature)
|
||||
|
||||
def verify_signature(self, alice_verifying_key: PublicKey) -> bool:
|
||||
"""
|
||||
Verifies the revocation was from the provided pubkey.
|
||||
"""
|
||||
if not self.signature.verify(self.payload, alice_verifying_key):
|
||||
if not self.signature.verify(self._body(), alice_verifying_key):
|
||||
raise InvalidSignature(f"Revocation has an invalid signature: {self.signature}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _brand(cls) -> bytes:
|
||||
return b'RV'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
return 1, 0
|
||||
|
||||
@classmethod
|
||||
def _old_version_handlers(cls) -> Dict:
|
||||
return {}
|
||||
|
||||
def _body(self) -> bytes:
|
||||
return to_canonical_address(self.ursula_checksum_address) + bytes(self.encrypted_kfrag)
|
||||
|
||||
def _payload(self) -> bytes:
|
||||
return self._body() + bytes(self.signature)
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
|
||||
splitter = BytestringSplitter(
|
||||
checksum_address_splitter, # ursula canonical address
|
||||
(bytes, Versioned._HEADER_SIZE+AuthorizedKeyFrag.SERIALIZED_SIZE), # MessageKit version header + versioned ekfrag
|
||||
signature_splitter
|
||||
)
|
||||
ursula_canonical_address, ekfrag, signature = splitter(data)
|
||||
ursula_checksum_address = to_checksum_address(ursula_canonical_address)
|
||||
return cls(ursula_checksum_address=ursula_checksum_address,
|
||||
encrypted_kfrag=ekfrag,
|
||||
signature=signature)
|
||||
|
||||
|
||||
class RevocationKit:
|
||||
|
||||
def __init__(self, treasure_map, signer: SignatureStamp):
|
||||
self.revocations = dict()
|
||||
for node_id, encrypted_kfrag in treasure_map:
|
||||
self.revocations[node_id] = Revocation(ursula_checksum_address=node_id,
|
||||
encrypted_kfrag=encrypted_kfrag,
|
||||
signer=signer)
|
||||
self.revocations[node_id] = RevocationOrder(ursula_checksum_address=node_id,
|
||||
encrypted_kfrag=encrypted_kfrag,
|
||||
signer=signer)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.revocations.values())
|
||||
|
|
|
@ -0,0 +1,166 @@
|
|||
"""
|
||||
This file is part of nucypher.
|
||||
|
||||
nucypher is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
nucypher is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Dict, Tuple, Callable
|
||||
|
||||
|
||||
class Versioned(ABC):
|
||||
"""Base class for serializable entities"""
|
||||
|
||||
_PARTS = 2 # bytes
|
||||
_PART_SIZE = 2
|
||||
_BRAND_SIZE = 2
|
||||
_VERSION_SIZE = _PART_SIZE * _PARTS
|
||||
_HEADER_SIZE = _BRAND_SIZE + _VERSION_SIZE
|
||||
|
||||
class InvalidHeader(ValueError):
|
||||
"""Raised when an unexpected or invalid bytes header is encountered."""
|
||||
|
||||
class IncompatibleVersion(ValueError):
|
||||
"""Raised when attempting to deserialize incompatible bytes"""
|
||||
|
||||
class Empty(ValueError):
|
||||
"""Raised when 0 bytes are remaining after parsing the header."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _brand(cls) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
"""tuple(major, minor)"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def version_string(cls) -> str:
|
||||
major, minor = cls._version()
|
||||
return f'{major}.{minor}'
|
||||
|
||||
#
|
||||
# Serialize
|
||||
#
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self._header() + self._payload()
|
||||
|
||||
@classmethod
|
||||
def _header(cls) -> bytes:
|
||||
"""The entire bytes header to prepend to the instance payload."""
|
||||
major, minor = cls._version()
|
||||
major_bytes = major.to_bytes(cls._PART_SIZE, 'big')
|
||||
minor_bytes = minor.to_bytes(cls._PART_SIZE, 'big')
|
||||
header = cls._brand() + major_bytes + minor_bytes
|
||||
return header
|
||||
|
||||
@abstractmethod
|
||||
def _payload(self) -> bytes:
|
||||
"""The unbranded and unversioned bytes-serialized representation of this instance."""
|
||||
raise NotImplementedError
|
||||
|
||||
#
|
||||
# Deserialize
|
||||
#
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
"""The current deserializer"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _old_version_handlers(cls) -> Dict[Tuple[int, int], Callable]:
|
||||
"""Old deserializer callables keyed by version."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
""""Public deserialization API"""
|
||||
brand, version, payload = cls._parse_header(data)
|
||||
version = cls._resolve_version(version=version)
|
||||
handlers = cls._deserializers()
|
||||
return handlers[version](payload)
|
||||
|
||||
@classmethod
|
||||
def _resolve_version(cls, version: Tuple[int, int]) -> Tuple[int, int]:
|
||||
|
||||
# Unpack version metadata
|
||||
bytrestring_major, bytrestring_minor = version
|
||||
latest_major_version, latest_minor_version = cls._version()
|
||||
|
||||
# Enforce major version compatibility
|
||||
if not bytrestring_major == latest_major_version:
|
||||
message = f'Incompatible versioned bytes for {cls.__name__}. ' \
|
||||
f'Compatible version is {latest_major_version}.x, ' \
|
||||
f'Got {bytrestring_major}.{bytrestring_minor}.'
|
||||
raise cls.IncompatibleVersion(message)
|
||||
|
||||
# Enforce minor version compatibility.
|
||||
# Pass future minor versions to the latest minor handler.
|
||||
if bytrestring_minor >= latest_minor_version:
|
||||
version = cls._version()
|
||||
|
||||
return version
|
||||
|
||||
@classmethod
|
||||
def _parse_header(cls, data: bytes) -> Tuple[bytes, Tuple[int, int], bytes]:
|
||||
if len(data) < cls._HEADER_SIZE:
|
||||
# handles edge case when input is too short.
|
||||
raise ValueError(f'Invalid bytes for {cls.__name__}.')
|
||||
brand = cls._parse_brand(data)
|
||||
version = cls._parse_version(data)
|
||||
payload = cls._parse_payload(data)
|
||||
return brand, version, payload
|
||||
|
||||
@classmethod
|
||||
def _parse_brand(cls, data: bytes) -> bytes:
|
||||
brand = data[:cls._BRAND_SIZE]
|
||||
if brand != cls._brand():
|
||||
error = f"Incorrect brand. Expected {cls._brand()}, Got {brand}."
|
||||
if not brand.isalpha():
|
||||
# unversioned entities for older versions will most likely land here.
|
||||
error = f"Incompatible bytes for {cls.__name__}."
|
||||
raise cls.InvalidHeader(error)
|
||||
return brand
|
||||
|
||||
@classmethod
|
||||
def _parse_version(cls, data: bytes) -> Tuple[int, int]:
|
||||
version_data = data[cls._BRAND_SIZE:cls._HEADER_SIZE]
|
||||
major, minor = version_data[:cls._PART_SIZE], version_data[cls._PART_SIZE:]
|
||||
major, minor = int.from_bytes(major, 'big'), int.from_bytes(minor, 'big')
|
||||
version = major, minor
|
||||
return version
|
||||
|
||||
@classmethod
|
||||
def _parse_payload(cls, data: bytes) -> bytes:
|
||||
payload = data[cls._HEADER_SIZE:]
|
||||
if len(payload) == 0:
|
||||
raise ValueError(f'No content to deserialize {cls.__name__}.')
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
def _deserializers(cls) -> Dict[Tuple[int, int], Callable]:
|
||||
"""Return a dict of all known deserialization handlers for this class keyed by version"""
|
||||
return {cls._version(): cls._from_bytes_current, **cls._old_version_handlers()}
|
||||
|
||||
|
||||
# Collects the brands of every serializable entity, potentially useful for documentation.
|
||||
# SERIALIZABLE_ENTITIES = {v.__class__.__name__: v._brand() for v in Versioned.__subclasses__()}
|
|
@ -149,7 +149,7 @@ def test_retrieve_cfrags(blockchain_porter,
|
|||
cleartext_with_sig_header = blockchain_bob._crypto_power.power_ups(DecryptingPower).keypair.decrypt(policy_message_kit)
|
||||
sig_header, remainder = default_constant_splitter(cleartext_with_sig_header, return_remainder=True)
|
||||
signature_from_kit, cleartext = signature_splitter(remainder, return_remainder=True)
|
||||
assert signature_from_kit.verify(message=cleartext, verifying_key=policy_message_kit.sender_verifying_key)
|
||||
assert signature_from_kit.verify(message=cleartext, verifying_pk=policy_message_kit.sender_verifying_key)
|
||||
assert cleartext == original_message
|
||||
|
||||
#
|
||||
|
|
|
@ -24,7 +24,7 @@ import pytest
|
|||
from nucypher.characters.lawful import Enrico
|
||||
from nucypher.crypto.utils import keccak_digest
|
||||
from nucypher.policy.kits import MessageKit
|
||||
from nucypher.policy.revocation import Revocation
|
||||
from nucypher.policy.revocation import RevocationOrder
|
||||
|
||||
|
||||
def test_federated_grant(federated_alice, federated_bob, federated_ursulas):
|
||||
|
@ -113,7 +113,7 @@ def test_revocation(federated_alice, federated_bob):
|
|||
# Test Revocation deserialization
|
||||
revocation = policy.revocation_kit[node_id]
|
||||
revocation_bytes = bytes(revocation)
|
||||
deserialized_revocation = Revocation.from_bytes(revocation_bytes)
|
||||
deserialized_revocation = RevocationOrder.from_bytes(revocation_bytes)
|
||||
assert deserialized_revocation == revocation
|
||||
|
||||
# Attempt to revoke the new policy
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
"""
|
||||
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
from base64 import b64encode
|
||||
|
||||
import maya
|
||||
import pytest
|
||||
|
@ -29,6 +29,7 @@ from nucypher.control.specifications.base import BaseSchema
|
|||
from nucypher.control.specifications.exceptions import SpecificationError, InvalidInputData, InvalidArgumentCombo
|
||||
from nucypher.crypto.powers import DecryptingPower
|
||||
from nucypher.crypto.umbral_adapter import PublicKey
|
||||
from nucypher.policy.kits import MessageKit
|
||||
from nucypher.policy.kits import MessageKit as MessageKitClass
|
||||
from nucypher.policy.maps import EncryptedTreasureMap as EncryptedTreasureMapClass, TreasureMap as TreasureMapClass
|
||||
|
||||
|
@ -89,16 +90,19 @@ def test_treasure_map_validation(enacted_federated_policy,
|
|||
assert "Could not parse tmap" in str(e)
|
||||
assert "Invalid base64-encoded string" in str(e)
|
||||
|
||||
base64_header = base64.b64encode(EncryptedTreasureMapClass._header()).decode()
|
||||
|
||||
# valid base64 but invalid treasuremap
|
||||
bad_map = base64_header + "VGhpcyBpcWgb3RhbGx5IG5vdCBhIHRyZWFzdXJlbWg=="
|
||||
with pytest.raises(InvalidInputData) as e:
|
||||
EncryptedTreasureMapsOnly().load({'tmap': "VGhpcyBpcyB0b3RhbGx5IG5vdCBhIHRyZWFzdXJlbWFwLg=="})
|
||||
EncryptedTreasureMapsOnly().load({'tmap': bad_map})
|
||||
|
||||
assert "Could not convert input for tmap to an EncryptedTreasureMap" in str(e)
|
||||
assert "Invalid encrypted treasure map contents." in str(e)
|
||||
|
||||
# a valid treasuremap for once...
|
||||
tmap_bytes = bytes(enacted_federated_policy.treasure_map)
|
||||
tmap_b64 = b64encode(tmap_bytes)
|
||||
tmap_b64 = base64.b64encode(tmap_bytes)
|
||||
result = EncryptedTreasureMapsOnly().load({'tmap': tmap_b64.decode()})
|
||||
assert isinstance(result['tmap'], EncryptedTreasureMapClass)
|
||||
|
||||
|
@ -117,8 +121,10 @@ def test_treasure_map_validation(enacted_federated_policy,
|
|||
assert "Invalid base64-encoded string" in str(e)
|
||||
|
||||
# valid base64 but invalid treasuremap
|
||||
base64_header = base64.b64encode(TreasureMapClass._header()).decode()
|
||||
bad_map = base64_header + "VGhpcyBpcyB0b3RhbGx5IG5vdCBhIHRyZWFzdXJlbWFwLg=="
|
||||
with pytest.raises(InvalidInputData) as e:
|
||||
UnenncryptedTreasureMapsOnly().load({'tmap': "VGhpcyBpcyB0b3RhbGx5IG5vdCBhIHRyZWFzdXJlbWFwLg=="})
|
||||
UnenncryptedTreasureMapsOnly().load({'tmap': bad_map})
|
||||
|
||||
assert "Could not convert input for tmap to a TreasureMap" in str(e)
|
||||
assert "Invalid treasure map contents." in str(e)
|
||||
|
@ -126,7 +132,7 @@ def test_treasure_map_validation(enacted_federated_policy,
|
|||
# a valid treasuremap
|
||||
decrypted_treasure_map = federated_bob._decrypt_treasure_map(enacted_federated_policy.treasure_map)
|
||||
tmap_bytes = bytes(decrypted_treasure_map)
|
||||
tmap_b64 = b64encode(tmap_bytes).decode()
|
||||
tmap_b64 = base64.b64encode(tmap_bytes).decode()
|
||||
result = UnenncryptedTreasureMapsOnly().load({'tmap': tmap_b64})
|
||||
assert isinstance(result['tmap'], TreasureMapClass)
|
||||
|
||||
|
@ -146,9 +152,10 @@ def test_messagekit_validation(capsule_side_channel):
|
|||
assert "Could not parse mkit" in str(e)
|
||||
assert "Incorrect padding" in str(e)
|
||||
|
||||
# valid base64 but invalid treasuremap
|
||||
# valid base64 but invalid messagekit
|
||||
b64header = base64.b64encode(MessageKit._header()).decode()
|
||||
with pytest.raises(SpecificationError) as e:
|
||||
MessageKitsOnly().load({'mkit': "V3da"})
|
||||
MessageKitsOnly().load({'mkit': b64header + "V3da=="})
|
||||
|
||||
assert "Could not parse mkit" in str(e)
|
||||
assert "Not enough bytes to constitute message types" in str(e)
|
||||
|
@ -156,7 +163,7 @@ def test_messagekit_validation(capsule_side_channel):
|
|||
# test a valid messagekit
|
||||
valid_kit = capsule_side_channel.messages[0][0]
|
||||
kit_bytes = bytes(valid_kit)
|
||||
kit_b64 = b64encode(kit_bytes)
|
||||
kit_b64 = base64.b64encode(kit_bytes)
|
||||
result = MessageKitsOnly().load({'mkit': kit_b64.decode()})
|
||||
assert isinstance(result['mkit'], MessageKitClass)
|
||||
|
||||
|
|
|
@ -15,17 +15,20 @@
|
|||
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
|
||||
import datetime
|
||||
import maya
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
import maya
|
||||
import pytest
|
||||
import pytest_twisted
|
||||
import requests
|
||||
from bytestring_splitter import BytestringSplittingError
|
||||
from functools import partial
|
||||
from twisted.internet import threads
|
||||
|
||||
from nucypher.policy.policies import Policy
|
||||
from nucypher.policy.policies import Policy, Arrangement
|
||||
from nucypher.utilities.versioning import Versioned
|
||||
from tests.utils.middleware import EvilMiddleWare, NodeIsDownMiddleware
|
||||
from tests.utils.ursula import make_federated_ursulas
|
||||
|
||||
|
@ -109,8 +112,9 @@ def test_huge_treasure_maps_are_rejected(federated_alice, federated_ursulas):
|
|||
|
||||
firstula = list(federated_ursulas)[0]
|
||||
|
||||
header = Arrangement._header()
|
||||
ok_amount = 10 * 1024 # 10k
|
||||
ok_data = os.urandom(ok_amount)
|
||||
ok_data = header + os.urandom(ok_amount)
|
||||
|
||||
with pytest.raises(BytestringSplittingError):
|
||||
federated_alice.network_middleware.upload_arbitrary_data(
|
||||
|
@ -143,8 +147,10 @@ def test_hendrix_handles_content_length_validation(ursula_federated_test_config)
|
|||
node_deployer.catalogServers(node_deployer.hendrix)
|
||||
node_deployer.start()
|
||||
|
||||
header = Arrangement._header()
|
||||
|
||||
def check_node_rejects_large_posts(node):
|
||||
too_much_data = os.urandom(100 * 1024)
|
||||
too_much_data = header + os.urandom(100 * 1024)
|
||||
response = requests.post(
|
||||
"https://{}/consider_arrangement".format(node.rest_url()),
|
||||
data=too_much_data, verify=False)
|
||||
|
@ -153,7 +159,8 @@ def test_hendrix_handles_content_length_validation(ursula_federated_test_config)
|
|||
return node
|
||||
|
||||
def check_node_accepts_normal_posts(node):
|
||||
a_normal_arrangement = os.urandom(49 * 1024) # 49K, the limit is 50K
|
||||
under_limit = (49 * 1024)-Versioned._HEADER_SIZE # 49K, the limit is 50K
|
||||
a_normal_arrangement = header + os.urandom(under_limit)
|
||||
response = requests.post(
|
||||
"https://{}/consider_arrangement".format(node.rest_url()),
|
||||
data=a_normal_arrangement, verify=False)
|
||||
|
|
|
@ -147,7 +147,7 @@ def test_retrieve_cfrags(federated_porter,
|
|||
cleartext_with_sig_header = federated_bob._crypto_power.power_ups(DecryptingPower).keypair.decrypt(policy_message_kit)
|
||||
sig_header, remainder = default_constant_splitter(cleartext_with_sig_header, return_remainder=True)
|
||||
signature_from_kit, cleartext = signature_splitter(remainder, return_remainder=True)
|
||||
assert signature_from_kit.verify(message=cleartext, verifying_key=policy_message_kit.sender_verifying_key)
|
||||
assert signature_from_kit.verify(message=cleartext, verifying_pk=policy_message_kit.sender_verifying_key)
|
||||
assert cleartext == original_message
|
||||
|
||||
#
|
||||
|
|
|
@ -14,6 +14,8 @@ GNU Affero General Public License for more details.
|
|||
You should have received a copy of the GNU Affero General Public License
|
||||
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
|
||||
from base64 import b64encode
|
||||
|
||||
import pytest
|
||||
|
|
|
@ -0,0 +1,225 @@
|
|||
"""
|
||||
This file is part of nucypher.
|
||||
|
||||
nucypher is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
nucypher is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
|
||||
from typing import Tuple, Any, Type
|
||||
|
||||
import pytest
|
||||
|
||||
from nucypher.utilities.versioning import Versioned
|
||||
|
||||
|
||||
def _check_valid_version_tuple(version: Any, cls: Type):
|
||||
if not isinstance(version, tuple):
|
||||
pytest.fail(f"Old version handlers keys for {cls.__name__} must be a tuple")
|
||||
if not len(version) == Versioned._PARTS:
|
||||
pytest.fail(f"Old version handlers keys for {cls.__name__} must be a {str(Versioned._PARTS)}-tuple")
|
||||
if not all(isinstance(part, int) for part in version):
|
||||
pytest.fail(f"Old version handlers version parts {cls.__name__} must be integers")
|
||||
|
||||
|
||||
class A(Versioned):
|
||||
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
@classmethod
|
||||
def _brand(cls):
|
||||
return b"AA"
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
return 2, 1
|
||||
|
||||
def _payload(self) -> bytes:
|
||||
return bytes(self.x)
|
||||
|
||||
@classmethod
|
||||
def _old_version_handlers(cls):
|
||||
return {
|
||||
(2, 0): cls._from_bytes_v2_0,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_v2_0(cls, data):
|
||||
return cls(int(data, 16)) # then we switched to the hexadecimal
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
return cls(str(int(data, 16))) # but now we use a string representation for some reason
|
||||
|
||||
|
||||
def test_unique_branding():
|
||||
brands = tuple(v._brand() for v in Versioned.__subclasses__())
|
||||
brands_set = set(brands)
|
||||
if len(brands) != len(brands_set):
|
||||
duplicate_brands = list(brands)
|
||||
for brand in brands_set:
|
||||
duplicate_brands.remove(brand)
|
||||
pytest.fail(f"Duplicated brand(s) {duplicate_brands}.")
|
||||
|
||||
|
||||
def test_valid_branding():
|
||||
for cls in Versioned.__subclasses__():
|
||||
if len(cls._brand()) != cls._BRAND_SIZE:
|
||||
pytest.fail(f"Brand must be exactly {str(Versioned._BRAND_SIZE)} bytes.")
|
||||
if not cls._brand().isalpha():
|
||||
pytest.fail(f"Brand must be alphanumeric; Got {cls._brand()}")
|
||||
|
||||
|
||||
def test_valid_version_implementation():
|
||||
for cls in Versioned.__subclasses__():
|
||||
_check_valid_version_tuple(version=cls._version(), cls=cls)
|
||||
|
||||
|
||||
def test_valid_old_handlers_index():
|
||||
for cls in Versioned.__subclasses__():
|
||||
for version in cls._deserializers():
|
||||
_check_valid_version_tuple(version=version, cls=cls)
|
||||
|
||||
|
||||
def test_version_metadata():
|
||||
major, minor = A._version()
|
||||
assert A.version_string() == f'{major}.{minor}'
|
||||
|
||||
|
||||
def test_versioning_header_prepend():
|
||||
a = A(1) # stake sauce
|
||||
assert a.x == 1
|
||||
|
||||
serialized = bytes(a)
|
||||
assert len(serialized) > Versioned._HEADER_SIZE
|
||||
|
||||
header = serialized[:Versioned._HEADER_SIZE]
|
||||
brand = header[:Versioned._BRAND_SIZE]
|
||||
assert brand == A._brand()
|
||||
|
||||
version = header[Versioned._BRAND_SIZE:]
|
||||
major, minor = version[:Versioned._PART_SIZE], version[Versioned._PART_SIZE:]
|
||||
major_number = int.from_bytes(major, 'big')
|
||||
minor_number = int.from_bytes(minor, 'big')
|
||||
assert (major_number, minor_number) == A._version()
|
||||
|
||||
|
||||
def test_versioning_input_too_short():
|
||||
empty = b'AA\x00\x01'
|
||||
with pytest.raises(ValueError, match='Invalid bytes for A.'):
|
||||
A.from_bytes(empty)
|
||||
|
||||
|
||||
def test_versioning_empty_payload():
|
||||
empty = b'AA\x00\x02\x00\x01'
|
||||
with pytest.raises(ValueError, match='No content to deserialize A.'):
|
||||
A.from_bytes(empty)
|
||||
|
||||
|
||||
def test_versioning_invalid_brand():
|
||||
invalid = b'\x00\x03\x00\x0112'
|
||||
with pytest.raises(Versioned.InvalidHeader, match="Incompatible bytes for A."):
|
||||
A.from_bytes(invalid)
|
||||
|
||||
|
||||
def test_versioning_incorrect_brand():
|
||||
incorrect = b'AB\x00\x0112'
|
||||
with pytest.raises(Versioned.InvalidHeader, match="Incorrect brand. Expected b'AA', Got b'AB'."):
|
||||
A.from_bytes(incorrect)
|
||||
|
||||
|
||||
def test_unknown_future_major_version():
|
||||
empty = b'AA\x00\x03\x00\x0212'
|
||||
message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 3.2.'
|
||||
with pytest.raises(ValueError, match=message):
|
||||
A.from_bytes(empty)
|
||||
|
||||
|
||||
def test_incompatible_old_major_version(mocker):
|
||||
current_spy = mocker.spy(A, "_from_bytes_current")
|
||||
v1_data = b'AA\x00\x01\x00\x0012'
|
||||
message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 1.0.'
|
||||
with pytest.raises(Versioned.IncompatibleVersion, match=message):
|
||||
A.from_bytes(v1_data)
|
||||
assert not current_spy.call_count
|
||||
|
||||
|
||||
def test_incompatible_future_major_version(mocker):
|
||||
current_spy = mocker.spy(A, "_from_bytes_current")
|
||||
v1_data = b'AA\x00\x03\x00\x0012'
|
||||
message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 3.0.'
|
||||
with pytest.raises(Versioned.IncompatibleVersion, match=message):
|
||||
A.from_bytes(v1_data)
|
||||
assert not current_spy.call_count
|
||||
|
||||
|
||||
def test_resolve_version():
|
||||
# past
|
||||
v2_0 = 2, 0
|
||||
resolved_version = A._resolve_version(version=v2_0)
|
||||
assert resolved_version == v2_0
|
||||
|
||||
# present
|
||||
v2_1 = 2, 1
|
||||
resolved_version = A._resolve_version(version=v2_1)
|
||||
assert resolved_version == v2_1
|
||||
|
||||
# future minor version resolves to the latest minor version.
|
||||
v2_2 = 2, 2
|
||||
resolved_version = A._resolve_version(version=v2_2)
|
||||
assert resolved_version == v2_1
|
||||
|
||||
|
||||
def test_old_minor_version_handler_routing(mocker):
|
||||
current_spy = mocker.spy(A, "_from_bytes_current")
|
||||
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
|
||||
|
||||
# Old minor version
|
||||
v2_0_data = b'AA\x00\x02\x00\x0012'
|
||||
a = A.from_bytes(v2_0_data)
|
||||
assert a.x == 18
|
||||
|
||||
# Old minor version was correctly routed to the v2.0 handler.
|
||||
assert v2_0_spy.call_count == 1
|
||||
v2_0_spy.assert_called_with(b'12')
|
||||
assert not current_spy.call_count
|
||||
|
||||
|
||||
def test_current_minor_version_handler_routing(mocker):
|
||||
current_spy = mocker.spy(A, "_from_bytes_current")
|
||||
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
|
||||
|
||||
v2_1_data = b'AA\x00\x02\x00\x0112'
|
||||
a = A.from_bytes(v2_1_data)
|
||||
assert a.x == '18'
|
||||
|
||||
# Current version was correctly routed to the v2.1 handler.
|
||||
assert current_spy.call_count == 1
|
||||
current_spy.assert_called_with(b'12')
|
||||
assert not v2_0_spy.call_count
|
||||
|
||||
|
||||
def test_future_minor_version_handler_routing(mocker):
|
||||
current_spy = mocker.spy(A, "_from_bytes_current")
|
||||
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
|
||||
|
||||
v2_2_data = b'AA\x00\x02\x02\x0112'
|
||||
a = A.from_bytes(v2_2_data)
|
||||
assert a.x == '18'
|
||||
|
||||
# Future minor version was correctly routed to
|
||||
# the current minor version handler.
|
||||
assert current_spy.call_count == 1
|
||||
current_spy.assert_called_with(b'12')
|
||||
assert not v2_0_spy.call_count
|
Loading…
Reference in New Issue