Merge pull request #2767 from KPrasch/versioning

Versioning of bytes serializable protocol entities
pull/2804/head
KPrasch 2021-09-28 19:13:51 -07:00 committed by GitHub
commit a9b2a8d412
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 735 additions and 220 deletions

View File

@ -0,0 +1 @@
Uniform versioning of bytes serializable protocol entities.

View File

@ -17,7 +17,7 @@ along with nucypher. If not, see <https://www.gnu.org/licenses/>.
from collections import defaultdict
import random
from typing import Dict, Sequence, List
from typing import Dict, Sequence, List, Tuple
from bytestring_splitter import BytestringSplitter, VariableLengthBytestring
from eth_typing.evm import ChecksumAddress
@ -43,6 +43,7 @@ from nucypher.network.nodes import Learner
from nucypher.policy.hrac import HRAC, hrac_splitter
from nucypher.policy.kits import MessageKit, RetrievalKit, RetrievalResult
from nucypher.policy.maps import TreasureMap
from nucypher.utilities.versioning import Versioned
class RetrievalPlan:
@ -138,16 +139,11 @@ class RetrievalWorkOrder:
self.capsules = capsules
class ReencryptionRequest:
class ReencryptionRequest(Versioned):
"""
A request for an Ursula to reencrypt for several capsules.
"""
_splitter = (hrac_splitter +
key_splitter +
key_splitter +
BytestringSplitter((MessageKit, VariableLengthBytestring)))
@classmethod
def from_work_order(cls,
work_order: RetrievalWorkOrder,
@ -159,8 +155,7 @@ class ReencryptionRequest:
alice_verifying_key=alice_verifying_key,
bob_verifying_key=bob_verifying_key,
encrypted_kfrag=treasure_map.destinations[work_order.ursula_address],
capsules=work_order.capsules,
)
capsules=work_order.capsules)
def __init__(self,
hrac: HRAC,
@ -175,20 +170,6 @@ class ReencryptionRequest:
self.encrypted_kfrag = encrypted_kfrag
self.capsules = capsules
def __bytes__(self):
return (bytes(self.hrac) +
bytes(self._alice_verifying_key) +
bytes(self._bob_verifying_key) +
VariableLengthBytestring(bytes(self.encrypted_kfrag)) +
b''.join(bytes(capsule) for capsule in self.capsules)
)
@classmethod
def from_bytes(cls, data: bytes):
hrac, alice_vk, bob_vk, ekfrag, remainder = cls._splitter(data, return_remainder=True)
capsules = capsule_splitter.repeat(remainder)
return cls(hrac, alice_vk, bob_vk, ekfrag, capsules)
def alice(self) -> 'Alice':
from nucypher.characters.lawful import Alice
return Alice.from_public_keys(verifying_key=self._alice_verifying_key)
@ -201,8 +182,39 @@ class ReencryptionRequest:
from nucypher.characters.lawful import Alice
return Alice.from_public_keys(verifying_key=self.encrypted_kfrag.sender_verifying_key)
def _payload(self) -> bytes:
return (bytes(self.hrac) +
bytes(self._alice_verifying_key) +
bytes(self._bob_verifying_key) +
VariableLengthBytestring(bytes(self.encrypted_kfrag)) +
b''.join(bytes(capsule) for capsule in self.capsules)
)
class ReencryptionResponse:
@classmethod
def _brand(cls) -> bytes:
return b'RQ'
@classmethod
def _version(cls) -> Tuple[int, int]:
return 1, 0
@classmethod
def _old_version_handlers(cls) -> Dict:
return {}
@classmethod
def _from_bytes_current(cls, data):
splitter = (hrac_splitter +
key_splitter +
key_splitter +
BytestringSplitter((MessageKit, VariableLengthBytestring)))
hrac, alice_vk, bob_vk, ekfrag, remainder = splitter(data, return_remainder=True)
capsules = capsule_splitter.repeat(remainder)
return cls(hrac, alice_vk, bob_vk, ekfrag, capsules)
class ReencryptionResponse(Versioned):
"""
A response from Ursula with reencrypted capsule frags.
"""
@ -226,21 +238,34 @@ class ReencryptionResponse:
self.cfrags = cfrags
self.signature = signature
def _payload(self) -> bytes:
"""Returns the unversioned bytes serialized representation of this instance."""
return bytes(self.signature) + b''.join(bytes(cfrag) for cfrag in self.cfrags)
@classmethod
def from_bytes(cls, data: bytes):
def _brand(cls) -> bytes:
return b'RR'
@classmethod
def _version(cls) -> Tuple[int, int]:
return 1, 0
@classmethod
def _old_version_handlers(cls) -> Dict:
return {}
@classmethod
def _from_bytes_current(cls, data):
signature, cfrags_bytes = signature_splitter(data, return_remainder=True)
# We would never send a request with no capsules, so there should be cfrags.
# The splitter would fail anyway, this just makes the error message more clear.
if not cfrags_bytes:
raise ValueError("ReencryptionResponse contains no cfrags")
raise ValueError(f"{cls.__name__} contains no cfrags")
cfrags = cfrag_splitter.repeat(cfrags_bytes)
return cls(cfrags, signature)
def __bytes__(self):
return bytes(self.signature) + b''.join(bytes(cfrag) for cfrag in self.cfrags)
class RetrievalClient:
"""

View File

@ -40,7 +40,7 @@ from nucypher.network.protocols import InterfaceInfo
from nucypher.network.retrieval import ReencryptionRequest, ReencryptionResponse
from nucypher.policy.hrac import HRAC
from nucypher.policy.kits import MessageKit
from nucypher.policy.revocation import Revocation
from nucypher.policy.revocation import RevocationOrder
from nucypher.utilities.logging import Logger
HERE = BASE_DIR = Path(__file__).parent
@ -272,7 +272,7 @@ def _make_rest_app(datastore: Datastore, this_node, domain: str, log: Logger) ->
@rest_app.route('/revoke', methods=['POST'])
def revoke():
revocation = Revocation.from_bytes(request.data)
revocation = RevocationOrder.from_bytes(request.data)
# TODO: Implement offchain revocation.
return Response(status=200)

View File

@ -16,30 +16,30 @@ along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
from typing import Dict, Optional, Iterable, Set
from typing import Dict, Optional, Iterable, Set, Tuple
from bytestring_splitter import BytestringSplitter, VariableLengthBytestring
from constant_sorrow.constants import (
NOT_SIGNED,
DO_NOT_SIGN,
SIGNATURE_TO_FOLLOW,
SIGNATURE_IS_ON_CIPHERTEXT,
NOT_SIGNED,
)
)
from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address, to_canonical_address
import nucypher.crypto.umbral_adapter as umbral # need it to mock `umbral.encrypt`
from nucypher.crypto.splitters import (
capsule_splitter,
key_splitter,
signature_splitter,
checksum_address_splitter,
)
import nucypher.crypto.umbral_adapter as umbral # need it to mock `umbral.encrypt`
)
from nucypher.crypto.umbral_adapter import PublicKey, VerifiedCapsuleFrag, Capsule, Signature
from nucypher.utilities.versioning import Versioned
class MessageKit:
class MessageKit(Versioned):
"""
All the components needed to transmit and verify an encrypted message.
"""
@ -106,7 +106,13 @@ class MessageKit:
def __str__(self):
return f"{self.__class__.__name__}({self.capsule})"
def __bytes__(self):
def as_policy_kit(self, policy_key: PublicKey, threshold: int) -> 'PolicyMessageKit':
return PolicyMessageKit.from_message_kit(self, policy_key, threshold)
def as_retrieval_kit(self) -> 'RetrievalKit':
return RetrievalKit(self.capsule, set())
def _payload(self) -> bytes:
# TODO (#2743): this logic may not be necessary depending on the resolution.
# If it is, it is better moved to BytestringSplitter.
return (bytes(self.capsule) +
@ -115,7 +121,19 @@ class MessageKit:
VariableLengthBytestring(self.ciphertext))
@classmethod
def from_bytes(cls, data: bytes):
def _brand(cls) -> bytes:
return b'MK'
@classmethod
def _version(cls) -> Tuple[int, int]:
return 1, 0
@classmethod
def _old_version_handlers(cls) -> Dict:
return {}
@classmethod
def _from_bytes_current(cls, data):
splitter = BytestringSplitter(
capsule_splitter,
(bytes, 1))
@ -129,8 +147,7 @@ class MessageKit:
else:
raise ValueError("Incorrect format for the signature flag")
splitter = BytestringSplitter(
(bytes, 1))
splitter = BytestringSplitter((bytes, 1))
key_flag, remainder = splitter(remainder, return_remainder=True)
@ -145,14 +162,8 @@ class MessageKit:
return cls(capsule, ciphertext, signature=signature, sender_verifying_key=sender_verifying_key)
def as_policy_kit(self, policy_key: PublicKey, threshold: int) -> 'PolicyMessageKit':
return PolicyMessageKit.from_message_kit(self, policy_key, threshold)
def as_retrieval_kit(self) -> 'RetrievalKit':
return RetrievalKit(self.capsule, set())
class RetrievalKit:
class RetrievalKit(Versioned):
"""
An object encapsulating the information necessary for retrieval of cfrags from Ursulas.
Contains the capsule and the checksum addresses of Ursulas from which the requester
@ -164,12 +175,24 @@ class RetrievalKit:
# Can store cfrags too, if we're worried about Ursulas supplying duplicate ones.
self.queried_addresses = set(queried_addresses)
def __bytes__(self):
def _payload(self) -> bytes:
return (bytes(self.capsule) +
b''.join(to_canonical_address(address) for address in self.queried_addresses))
@classmethod
def from_bytes(cls, data: bytes):
def _brand(cls) -> bytes:
return b'RK'
@classmethod
def _version(cls) -> Tuple[int, int]:
return 1, 0
@classmethod
def _old_version_handlers(cls) -> Dict:
return {}
@classmethod
def _from_bytes_current(cls, data):
capsule, remainder = capsule_splitter(data, return_remainder=True)
if remainder:
addresses_as_bytes = checksum_address_splitter.repeat(remainder)

View File

@ -16,7 +16,7 @@ along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
from typing import Optional, Callable, Sequence, Dict
from typing import Optional, Callable, Sequence, Dict, Tuple
from bytestring_splitter import (
BytestringSplitter,
@ -36,24 +36,40 @@ from nucypher.crypto.utils import keccak_digest, verify_eip_191
from nucypher.network.middleware import RestMiddleware
from nucypher.policy.hrac import HRAC, hrac_splitter
from nucypher.policy.kits import MessageKit
from nucypher.utilities.versioning import Versioned
class TreasureMap:
class TreasureMap(Versioned):
class NowhereToBeFound(RestMiddleware.NotFound):
"""
Called when no known nodes have it.
"""
main_splitter = BytestringSplitter(
(int, 1, {'byteorder': 'big'}),
hrac_splitter,
)
def __init__(self,
threshold: int,
hrac: HRAC,
destinations: Dict[ChecksumAddress, MessageKit]):
self.threshold = threshold
self.destinations = destinations
self.hrac = hrac
ursula_and_kfrag_payload_splitter = BytestringSplitter(
(to_checksum_address, ETH_ADDRESS_BYTE_LENGTH),
(MessageKit, VariableLengthBytestring),
)
# A little awkward, but saves us a key length in serialization
self.publisher_verifying_key = list(destinations.values())[0].sender_verifying_key
def __iter__(self):
return iter(self.destinations.items())
def __len__(self):
return len(self.destinations)
def __eq__(self, other):
if not isinstance(other, TreasureMap):
return False
return (self.threshold == other.threshold and
self.hrac == other.hrac and
self.destinations == other.destinations)
@classmethod
def construct_by_publisher(cls,
@ -87,17 +103,42 @@ class TreasureMap:
return cls(threshold=threshold, hrac=hrac, destinations=destinations)
def __init__(self,
threshold: int,
hrac: HRAC,
destinations: Dict[ChecksumAddress, MessageKit],
):
self.threshold = threshold
self.destinations = destinations
self.hrac = hrac
@classmethod
def _brand(cls) -> bytes:
return b'TM'
# A little awkward, but saves us a key length in serialization
self.publisher_verifying_key = list(destinations.values())[0].sender_verifying_key
@classmethod
def _version(cls) -> Tuple[int, int]:
return 1, 0
@classmethod
def _old_version_handlers(cls) -> Dict:
return {}
def _payload(self) -> bytes:
"""Returns the unversioned bytes serialized representation of this instance."""
return self.threshold.to_bytes(1, "big") + bytes(self.hrac) + self._nodes_as_bytes()
@classmethod
def _from_bytes_current(cls, data):
main_splitter = BytestringSplitter(
(int, 1, {'byteorder': 'big'}),
hrac_splitter,
)
ursula_and_kfrag_payload_splitter = BytestringSplitter(
(to_checksum_address, ETH_ADDRESS_BYTE_LENGTH),
(MessageKit, VariableLengthBytestring),
)
try:
threshold, hrac, remainder = main_splitter(data, return_remainder=True)
ursula_and_kfrags = ursula_and_kfrag_payload_splitter.repeat(remainder)
except BytestringSplittingError as e:
raise ValueError('Invalid treasure map contents.') from e
destinations = {u: k for u, k in ursula_and_kfrags}
return cls(threshold, hrac, destinations)
def encrypt(self,
publisher: 'Alice',
@ -117,52 +158,22 @@ class TreasureMap:
nodes_as_bytes += (node_id + kfrag)
return nodes_as_bytes
def __bytes__(self):
return self.threshold.to_bytes(1, "big") + bytes(self.hrac) + self._nodes_as_bytes()
@classmethod
def from_bytes(cls, data: bytes):
try:
threshold, hrac, remainder = cls.main_splitter(data, return_remainder=True)
ursula_and_kfrags = cls.ursula_and_kfrag_payload_splitter.repeat(remainder)
except BytestringSplittingError as e:
raise ValueError('Invalid treasure map contents.') from e
destinations = {u: k for u, k in ursula_and_kfrags}
return cls(threshold, hrac, destinations)
def __eq__(self, other):
if not isinstance(other, TreasureMap):
return False
return (self.threshold == other.threshold and
self.hrac == other.hrac and
self.destinations == other.destinations)
def __iter__(self):
return iter(self.destinations.items())
def __len__(self):
return len(self.destinations)
class AuthorizedKeyFrag:
class AuthorizedKeyFrag(Versioned):
_WRIT_CHECKSUM_SIZE = 32
# The size of a serialized message kit encrypting an AuthorizedKeyFrag.
# Depends on encryption parameters in Umbral, has to be hardcoded.
ENCRYPTED_SIZE = 621
SERIALIZED_SIZE = Versioned._HEADER_SIZE + ENCRYPTED_SIZE
_splitter = BytestringSplitter(
hrac_splitter, # HRAC
(bytes, _WRIT_CHECKSUM_SIZE), # kfrag checksum
signature_splitter, # Publisher's signature
kfrag_splitter,
)
@staticmethod
def _kfrag_checksum(kfrag: KeyFrag) -> bytes:
return keccak_digest(bytes(kfrag))[:AuthorizedKeyFrag._WRIT_CHECKSUM_SIZE]
def __init__(self, hrac: HRAC, kfrag_checksum: bytes, writ_signature: Signature, kfrag: KeyFrag):
self.hrac = hrac
self.kfrag_checksum = kfrag_checksum
self.writ = bytes(hrac) + kfrag_checksum
self.writ_signature = writ_signature
self.kfrag = kfrag
@classmethod
def construct_by_publisher(cls,
@ -184,20 +195,38 @@ class AuthorizedKeyFrag:
# the material needed for Ursula to assuredly service this policy.
return cls(hrac, kfrag_checksum, writ_signature, kfrag)
def __init__(self, hrac: HRAC, kfrag_checksum: bytes, writ_signature: Signature, kfrag: KeyFrag):
self.hrac = hrac
self.kfrag_checksum = kfrag_checksum
self.writ = bytes(hrac) + kfrag_checksum
self.writ_signature = writ_signature
self.kfrag = kfrag
@staticmethod
def _kfrag_checksum(kfrag: KeyFrag) -> bytes:
return keccak_digest(bytes(kfrag))[:AuthorizedKeyFrag._WRIT_CHECKSUM_SIZE]
def __bytes__(self):
def _payload(self) -> bytes:
"""Returns the unversioned bytes serialized representation of this instance."""
return self.writ + bytes(self.writ_signature) + bytes(self.kfrag)
@classmethod
def from_bytes(cls, data: bytes):
def _brand(cls) -> bytes:
return b'KF'
@classmethod
def _version(cls) -> Tuple[int, int]:
return 1, 0
@classmethod
def _old_version_handlers(cls) -> Dict:
return {}
@classmethod
def _from_bytes_current(cls, data):
# TODO: should we check the signature right away here?
hrac, kfrag_checksum, writ_signature, kfrag = cls._splitter(data)
splitter = BytestringSplitter(
hrac_splitter, # HRAC
(bytes, cls._WRIT_CHECKSUM_SIZE), # kfrag checksum
signature_splitter, # Publisher's signature
kfrag_splitter,
)
hrac, kfrag_checksum, writ_signature, kfrag = splitter(data)
# Check integrity
calculated_checksum = cls._kfrag_checksum(kfrag)
@ -207,19 +236,24 @@ class AuthorizedKeyFrag:
return cls(hrac, kfrag_checksum, writ_signature, kfrag)
class EncryptedTreasureMap:
_splitter = BytestringSplitter(
signature_splitter, # public signature
hrac_splitter, # HRAC
(MessageKit, VariableLengthBytestring), # encrypted TreasureMap
(bytes, EIP712_MESSAGE_SIGNATURE_SIZE)) # blockchain signature
class EncryptedTreasureMap(Versioned):
_EMPTY_BLOCKCHAIN_SIGNATURE = b'\x00' * EIP712_MESSAGE_SIGNATURE_SIZE
# TODO: do we really need this alias?
from nucypher.crypto.signing import \
InvalidSignature # Raised when the public signature (typically intended for Ursula) is not valid.
# Raised when the public signature (typically intended for Ursula) is not valid.
from nucypher.crypto.signing import InvalidSignature
def __init__(self,
hrac: HRAC,
public_signature: Signature,
encrypted_tmap: MessageKit,
blockchain_signature: Optional[bytes] = None):
self.hrac = hrac
self._public_signature = public_signature
self.publisher_verifying_key = encrypted_tmap.sender_verifying_key
self._encrypted_tmap = encrypted_tmap
self._blockchain_signature = blockchain_signature
@staticmethod
def _sign(blockchain_signer: Callable[[bytes], bytes],
@ -258,19 +292,6 @@ class EncryptedTreasureMap:
return cls(treasure_map.hrac, public_signature, encrypted_tmap, blockchain_signature=blockchain_signature)
def __init__(self,
hrac: HRAC,
public_signature: Signature,
encrypted_tmap: MessageKit,
blockchain_signature: Optional[bytes] = None,
):
self.hrac = hrac
self._public_signature = public_signature
self.publisher_verifying_key = encrypted_tmap.sender_verifying_key
self._encrypted_tmap = encrypted_tmap
self._blockchain_signature = blockchain_signature
def decrypt(self, decryptor: Callable[[bytes], bytes]) -> TreasureMap:
"""
When Bob receives the TreasureMap, he'll pass a decryptor (a callable which can verify and decrypt the
@ -284,17 +305,6 @@ class EncryptedTreasureMap:
return TreasureMap.from_bytes(map_in_the_clear)
def __bytes__(self):
if self._blockchain_signature:
signature = self._blockchain_signature
else:
signature = self._EMPTY_BLOCKCHAIN_SIGNATURE
return (bytes(self._public_signature) +
bytes(self.hrac) +
bytes(VariableLengthBytestring(bytes(self._encrypted_tmap))) +
signature
)
def verify_blockchain_signature(self, checksum_address: ChecksumAddress) -> bool:
if self._blockchain_signature is None:
raise ValueError("This EncryptedTreasureMap is not blockchain-signed")
@ -308,10 +318,39 @@ class EncryptedTreasureMap:
if not self._public_signature.verify(self.publisher_verifying_key, message=message):
raise self.InvalidSignature("This TreasureMap is not properly publicly signed by the publisher.")
def _payload(self) -> bytes:
if self._blockchain_signature:
signature = self._blockchain_signature
else:
signature = self._EMPTY_BLOCKCHAIN_SIGNATURE
return (bytes(self._public_signature) +
bytes(self.hrac) +
bytes(VariableLengthBytestring(bytes(self._encrypted_tmap))) +
signature)
@classmethod
def from_bytes(cls, data: bytes):
def _brand(cls) -> bytes:
return b'EM'
@classmethod
def _version(cls) -> Tuple[int, int]:
return 1, 0
@classmethod
def _old_version_handlers(cls) -> Dict:
return {}
@classmethod
def _from_bytes_current(cls, data):
splitter = BytestringSplitter(
signature_splitter, # public signature
hrac_splitter, # HRAC
(MessageKit, VariableLengthBytestring), # encrypted TreasureMap
(bytes, EIP712_MESSAGE_SIGNATURE_SIZE)) # blockchain signature
try:
public_signature, hrac, message_kit, blockchain_signature = cls._splitter(data)
public_signature, hrac, message_kit, blockchain_signature = splitter(data)
if blockchain_signature == cls._EMPTY_BLOCKCHAIN_SIGNATURE:
blockchain_signature = None
except BytestringSplittingError as e:

View File

@ -38,24 +38,18 @@ from nucypher.policy.reservoir import (
from nucypher.policy.revocation import RevocationKit
from nucypher.utilities.concurrency import WorkerPool
from nucypher.utilities.logging import Logger
from nucypher.utilities.versioning import Versioned
class Arrangement:
"""
A contract between Alice and a single Ursula.
"""
splitter = BytestringSplitter(
key_splitter, # publisher_verifying_key
(bytes, VariableLengthBytestring) # expiration
)
class Arrangement(Versioned):
"""A contract between Alice and a single Ursula."""
def __init__(self, publisher_verifying_key: PublicKey, expiration: maya.MayaDT):
self.expiration = expiration
self.publisher_verifying_key = publisher_verifying_key
def __bytes__(self):
return bytes(self.publisher_verifying_key) + bytes(VariableLengthBytestring(self.expiration.iso8601().encode()))
def __repr__(self):
return f"Arrangement(publisher={self.publisher_verifying_key})"
@classmethod
def from_publisher(cls, publisher: 'Alice', expiration: maya.MayaDT) -> 'Arrangement':
@ -63,14 +57,31 @@ class Arrangement:
return cls(publisher_verifying_key=publisher_verifying_key, expiration=expiration)
@classmethod
def from_bytes(cls, arrangement_as_bytes: bytes) -> 'Arrangement':
publisher_verifying_key, expiration_bytes = cls.splitter(arrangement_as_bytes)
def _brand(cls) -> bytes:
return b'AR'
@classmethod
def _version(cls) -> Tuple[int, int]:
return 1, 0
def _payload(self) -> bytes:
"""Returns the unversioned bytes serialized representation of this instance."""
return bytes(self.publisher_verifying_key) + bytes(VariableLengthBytestring(self.expiration.iso8601().encode()))
@classmethod
def _old_version_handlers(cls) -> Dict:
return {}
@classmethod
def _from_bytes_current(cls, data: bytes):
splitter = BytestringSplitter(
key_splitter, # publisher_verifying_key
(bytes, VariableLengthBytestring) # expiration
)
publisher_verifying_key, expiration_bytes = splitter(data)
expiration = maya.MayaDT.from_iso8601(iso8601_string=expiration_bytes.decode())
return cls(publisher_verifying_key=publisher_verifying_key, expiration=expiration)
def __repr__(self):
return f"Arrangement(publisher={self.publisher_verifying_key})"
class Policy(ABC):
"""

View File

@ -16,7 +16,7 @@ along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
from typing import Optional
from typing import Optional, Dict, Tuple
from bytestring_splitter import BytestringSplitter
from eth_typing.evm import ChecksumAddress
@ -27,9 +27,10 @@ from nucypher.crypto.splitters import signature_splitter, checksum_address_split
from nucypher.crypto.umbral_adapter import Signature, PublicKey
from nucypher.policy.kits import MessageKit
from nucypher.policy.maps import AuthorizedKeyFrag
from nucypher.utilities.versioning import Versioned
class Revocation:
class RevocationOrder(Versioned):
"""
Represents a string used by characters to perform a revocation on a specific
Ursula. It's a bytestring made of the following format:
@ -37,16 +38,9 @@ class Revocation:
This is sent as a payload in a DELETE method to the /KFrag/ endpoint.
"""
PREFIX = b'REVOKE-'
revocation_splitter = BytestringSplitter(
(bytes, len(PREFIX)),
checksum_address_splitter, # ursula canonical address
(bytes, AuthorizedKeyFrag.ENCRYPTED_SIZE), # encrypted kfrag payload (includes writ)
signature_splitter
)
def __init__(self,
ursula_checksum_address: ChecksumAddress, # TODO: Use staker address instead (what if the staker rebonds)?
ursula_checksum_address: ChecksumAddress,
# TODO: Use staker address instead (what if the staker rebonds)?
encrypted_kfrag: MessageKit,
signer: Optional[SignatureStamp] = None,
signature: Optional[Signature] = None):
@ -57,13 +51,10 @@ class Revocation:
if not (bool(signer) ^ bool(signature)):
raise ValueError("Either pass a signer or a signature; not both.")
elif signer:
self.signature = signer(self.payload)
self.signature = signer(self._body())
elif signature:
self.signature = signature
def __bytes__(self):
return self.payload + bytes(self.signature)
def __repr__(self):
return bytes(self)
@ -73,37 +64,55 @@ class Revocation:
def __eq__(self, other):
return bytes(self) == bytes(other)
@property
def payload(self):
return self.PREFIX \
+ to_canonical_address(self.ursula_checksum_address) \
+ bytes(self.encrypted_kfrag) \
@classmethod
def from_bytes(cls, revocation_bytes):
prefix, ursula_canonical_address, ekfrag, signature = cls.revocation_splitter(revocation_bytes)
ursula_checksum_address = to_checksum_address(ursula_canonical_address)
return cls(ursula_checksum_address=ursula_checksum_address,
encrypted_kfrag=ekfrag,
signature=signature)
def verify_signature(self, alice_verifying_key: PublicKey) -> bool:
"""
Verifies the revocation was from the provided pubkey.
"""
if not self.signature.verify(self.payload, alice_verifying_key):
if not self.signature.verify(self._body(), alice_verifying_key):
raise InvalidSignature(f"Revocation has an invalid signature: {self.signature}")
return True
@classmethod
def _brand(cls) -> bytes:
return b'RV'
@classmethod
def _version(cls) -> Tuple[int, int]:
return 1, 0
@classmethod
def _old_version_handlers(cls) -> Dict:
return {}
def _body(self) -> bytes:
return to_canonical_address(self.ursula_checksum_address) + bytes(self.encrypted_kfrag)
def _payload(self) -> bytes:
return self._body() + bytes(self.signature)
@classmethod
def _from_bytes_current(cls, data):
splitter = BytestringSplitter(
checksum_address_splitter, # ursula canonical address
(bytes, Versioned._HEADER_SIZE+AuthorizedKeyFrag.SERIALIZED_SIZE), # MessageKit version header + versioned ekfrag
signature_splitter
)
ursula_canonical_address, ekfrag, signature = splitter(data)
ursula_checksum_address = to_checksum_address(ursula_canonical_address)
return cls(ursula_checksum_address=ursula_checksum_address,
encrypted_kfrag=ekfrag,
signature=signature)
class RevocationKit:
def __init__(self, treasure_map, signer: SignatureStamp):
self.revocations = dict()
for node_id, encrypted_kfrag in treasure_map:
self.revocations[node_id] = Revocation(ursula_checksum_address=node_id,
encrypted_kfrag=encrypted_kfrag,
signer=signer)
self.revocations[node_id] = RevocationOrder(ursula_checksum_address=node_id,
encrypted_kfrag=encrypted_kfrag,
signer=signer)
def __iter__(self):
return iter(self.revocations.values())

View File

@ -0,0 +1,166 @@
"""
This file is part of nucypher.
nucypher is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
nucypher is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
from abc import abstractmethod, ABC
from typing import Dict, Tuple, Callable
class Versioned(ABC):
"""Base class for serializable entities"""
_PARTS = 2 # bytes
_PART_SIZE = 2
_BRAND_SIZE = 2
_VERSION_SIZE = _PART_SIZE * _PARTS
_HEADER_SIZE = _BRAND_SIZE + _VERSION_SIZE
class InvalidHeader(ValueError):
"""Raised when an unexpected or invalid bytes header is encountered."""
class IncompatibleVersion(ValueError):
"""Raised when attempting to deserialize incompatible bytes"""
class Empty(ValueError):
"""Raised when 0 bytes are remaining after parsing the header."""
@classmethod
@abstractmethod
def _brand(cls) -> bytes:
raise NotImplementedError
@classmethod
@abstractmethod
def _version(cls) -> Tuple[int, int]:
"""tuple(major, minor)"""
raise NotImplementedError
@classmethod
def version_string(cls) -> str:
major, minor = cls._version()
return f'{major}.{minor}'
#
# Serialize
#
def __bytes__(self) -> bytes:
return self._header() + self._payload()
@classmethod
def _header(cls) -> bytes:
"""The entire bytes header to prepend to the instance payload."""
major, minor = cls._version()
major_bytes = major.to_bytes(cls._PART_SIZE, 'big')
minor_bytes = minor.to_bytes(cls._PART_SIZE, 'big')
header = cls._brand() + major_bytes + minor_bytes
return header
@abstractmethod
def _payload(self) -> bytes:
"""The unbranded and unversioned bytes-serialized representation of this instance."""
raise NotImplementedError
#
# Deserialize
#
@classmethod
@abstractmethod
def _from_bytes_current(cls, data):
"""The current deserializer"""
raise NotImplementedError
@classmethod
@abstractmethod
def _old_version_handlers(cls) -> Dict[Tuple[int, int], Callable]:
"""Old deserializer callables keyed by version."""
raise NotImplementedError
@classmethod
def from_bytes(cls, data: bytes):
""""Public deserialization API"""
brand, version, payload = cls._parse_header(data)
version = cls._resolve_version(version=version)
handlers = cls._deserializers()
return handlers[version](payload)
@classmethod
def _resolve_version(cls, version: Tuple[int, int]) -> Tuple[int, int]:
# Unpack version metadata
bytrestring_major, bytrestring_minor = version
latest_major_version, latest_minor_version = cls._version()
# Enforce major version compatibility
if not bytrestring_major == latest_major_version:
message = f'Incompatible versioned bytes for {cls.__name__}. ' \
f'Compatible version is {latest_major_version}.x, ' \
f'Got {bytrestring_major}.{bytrestring_minor}.'
raise cls.IncompatibleVersion(message)
# Enforce minor version compatibility.
# Pass future minor versions to the latest minor handler.
if bytrestring_minor >= latest_minor_version:
version = cls._version()
return version
@classmethod
def _parse_header(cls, data: bytes) -> Tuple[bytes, Tuple[int, int], bytes]:
if len(data) < cls._HEADER_SIZE:
# handles edge case when input is too short.
raise ValueError(f'Invalid bytes for {cls.__name__}.')
brand = cls._parse_brand(data)
version = cls._parse_version(data)
payload = cls._parse_payload(data)
return brand, version, payload
@classmethod
def _parse_brand(cls, data: bytes) -> bytes:
brand = data[:cls._BRAND_SIZE]
if brand != cls._brand():
error = f"Incorrect brand. Expected {cls._brand()}, Got {brand}."
if not brand.isalpha():
# unversioned entities for older versions will most likely land here.
error = f"Incompatible bytes for {cls.__name__}."
raise cls.InvalidHeader(error)
return brand
@classmethod
def _parse_version(cls, data: bytes) -> Tuple[int, int]:
version_data = data[cls._BRAND_SIZE:cls._HEADER_SIZE]
major, minor = version_data[:cls._PART_SIZE], version_data[cls._PART_SIZE:]
major, minor = int.from_bytes(major, 'big'), int.from_bytes(minor, 'big')
version = major, minor
return version
@classmethod
def _parse_payload(cls, data: bytes) -> bytes:
payload = data[cls._HEADER_SIZE:]
if len(payload) == 0:
raise ValueError(f'No content to deserialize {cls.__name__}.')
return payload
@classmethod
def _deserializers(cls) -> Dict[Tuple[int, int], Callable]:
"""Return a dict of all known deserialization handlers for this class keyed by version"""
return {cls._version(): cls._from_bytes_current, **cls._old_version_handlers()}
# Collects the brands of every serializable entity, potentially useful for documentation.
# SERIALIZABLE_ENTITIES = {v.__class__.__name__: v._brand() for v in Versioned.__subclasses__()}

View File

@ -149,7 +149,7 @@ def test_retrieve_cfrags(blockchain_porter,
cleartext_with_sig_header = blockchain_bob._crypto_power.power_ups(DecryptingPower).keypair.decrypt(policy_message_kit)
sig_header, remainder = default_constant_splitter(cleartext_with_sig_header, return_remainder=True)
signature_from_kit, cleartext = signature_splitter(remainder, return_remainder=True)
assert signature_from_kit.verify(message=cleartext, verifying_key=policy_message_kit.sender_verifying_key)
assert signature_from_kit.verify(message=cleartext, verifying_pk=policy_message_kit.sender_verifying_key)
assert cleartext == original_message
#

View File

@ -24,7 +24,7 @@ import pytest
from nucypher.characters.lawful import Enrico
from nucypher.crypto.utils import keccak_digest
from nucypher.policy.kits import MessageKit
from nucypher.policy.revocation import Revocation
from nucypher.policy.revocation import RevocationOrder
def test_federated_grant(federated_alice, federated_bob, federated_ursulas):
@ -113,7 +113,7 @@ def test_revocation(federated_alice, federated_bob):
# Test Revocation deserialization
revocation = policy.revocation_kit[node_id]
revocation_bytes = bytes(revocation)
deserialized_revocation = Revocation.from_bytes(revocation_bytes)
deserialized_revocation = RevocationOrder.from_bytes(revocation_bytes)
assert deserialized_revocation == revocation
# Attempt to revoke the new policy

View File

@ -16,8 +16,8 @@
"""
import base64
import datetime
from base64 import b64encode
import maya
import pytest
@ -29,6 +29,7 @@ from nucypher.control.specifications.base import BaseSchema
from nucypher.control.specifications.exceptions import SpecificationError, InvalidInputData, InvalidArgumentCombo
from nucypher.crypto.powers import DecryptingPower
from nucypher.crypto.umbral_adapter import PublicKey
from nucypher.policy.kits import MessageKit
from nucypher.policy.kits import MessageKit as MessageKitClass
from nucypher.policy.maps import EncryptedTreasureMap as EncryptedTreasureMapClass, TreasureMap as TreasureMapClass
@ -89,16 +90,19 @@ def test_treasure_map_validation(enacted_federated_policy,
assert "Could not parse tmap" in str(e)
assert "Invalid base64-encoded string" in str(e)
base64_header = base64.b64encode(EncryptedTreasureMapClass._header()).decode()
# valid base64 but invalid treasuremap
bad_map = base64_header + "VGhpcyBpcWgb3RhbGx5IG5vdCBhIHRyZWFzdXJlbWg=="
with pytest.raises(InvalidInputData) as e:
EncryptedTreasureMapsOnly().load({'tmap': "VGhpcyBpcyB0b3RhbGx5IG5vdCBhIHRyZWFzdXJlbWFwLg=="})
EncryptedTreasureMapsOnly().load({'tmap': bad_map})
assert "Could not convert input for tmap to an EncryptedTreasureMap" in str(e)
assert "Invalid encrypted treasure map contents." in str(e)
# a valid treasuremap for once...
tmap_bytes = bytes(enacted_federated_policy.treasure_map)
tmap_b64 = b64encode(tmap_bytes)
tmap_b64 = base64.b64encode(tmap_bytes)
result = EncryptedTreasureMapsOnly().load({'tmap': tmap_b64.decode()})
assert isinstance(result['tmap'], EncryptedTreasureMapClass)
@ -117,8 +121,10 @@ def test_treasure_map_validation(enacted_federated_policy,
assert "Invalid base64-encoded string" in str(e)
# valid base64 but invalid treasuremap
base64_header = base64.b64encode(TreasureMapClass._header()).decode()
bad_map = base64_header + "VGhpcyBpcyB0b3RhbGx5IG5vdCBhIHRyZWFzdXJlbWFwLg=="
with pytest.raises(InvalidInputData) as e:
UnenncryptedTreasureMapsOnly().load({'tmap': "VGhpcyBpcyB0b3RhbGx5IG5vdCBhIHRyZWFzdXJlbWFwLg=="})
UnenncryptedTreasureMapsOnly().load({'tmap': bad_map})
assert "Could not convert input for tmap to a TreasureMap" in str(e)
assert "Invalid treasure map contents." in str(e)
@ -126,7 +132,7 @@ def test_treasure_map_validation(enacted_federated_policy,
# a valid treasuremap
decrypted_treasure_map = federated_bob._decrypt_treasure_map(enacted_federated_policy.treasure_map)
tmap_bytes = bytes(decrypted_treasure_map)
tmap_b64 = b64encode(tmap_bytes).decode()
tmap_b64 = base64.b64encode(tmap_bytes).decode()
result = UnenncryptedTreasureMapsOnly().load({'tmap': tmap_b64})
assert isinstance(result['tmap'], TreasureMapClass)
@ -146,9 +152,10 @@ def test_messagekit_validation(capsule_side_channel):
assert "Could not parse mkit" in str(e)
assert "Incorrect padding" in str(e)
# valid base64 but invalid treasuremap
# valid base64 but invalid messagekit
b64header = base64.b64encode(MessageKit._header()).decode()
with pytest.raises(SpecificationError) as e:
MessageKitsOnly().load({'mkit': "V3da"})
MessageKitsOnly().load({'mkit': b64header + "V3da=="})
assert "Could not parse mkit" in str(e)
assert "Not enough bytes to constitute message types" in str(e)
@ -156,7 +163,7 @@ def test_messagekit_validation(capsule_side_channel):
# test a valid messagekit
valid_kit = capsule_side_channel.messages[0][0]
kit_bytes = bytes(valid_kit)
kit_b64 = b64encode(kit_bytes)
kit_b64 = base64.b64encode(kit_bytes)
result = MessageKitsOnly().load({'mkit': kit_b64.decode()})
assert isinstance(result['mkit'], MessageKitClass)

View File

@ -15,17 +15,20 @@
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
import datetime
import maya
import os
from functools import partial
import maya
import pytest
import pytest_twisted
import requests
from bytestring_splitter import BytestringSplittingError
from functools import partial
from twisted.internet import threads
from nucypher.policy.policies import Policy
from nucypher.policy.policies import Policy, Arrangement
from nucypher.utilities.versioning import Versioned
from tests.utils.middleware import EvilMiddleWare, NodeIsDownMiddleware
from tests.utils.ursula import make_federated_ursulas
@ -109,8 +112,9 @@ def test_huge_treasure_maps_are_rejected(federated_alice, federated_ursulas):
firstula = list(federated_ursulas)[0]
header = Arrangement._header()
ok_amount = 10 * 1024 # 10k
ok_data = os.urandom(ok_amount)
ok_data = header + os.urandom(ok_amount)
with pytest.raises(BytestringSplittingError):
federated_alice.network_middleware.upload_arbitrary_data(
@ -143,8 +147,10 @@ def test_hendrix_handles_content_length_validation(ursula_federated_test_config)
node_deployer.catalogServers(node_deployer.hendrix)
node_deployer.start()
header = Arrangement._header()
def check_node_rejects_large_posts(node):
too_much_data = os.urandom(100 * 1024)
too_much_data = header + os.urandom(100 * 1024)
response = requests.post(
"https://{}/consider_arrangement".format(node.rest_url()),
data=too_much_data, verify=False)
@ -153,7 +159,8 @@ def test_hendrix_handles_content_length_validation(ursula_federated_test_config)
return node
def check_node_accepts_normal_posts(node):
a_normal_arrangement = os.urandom(49 * 1024) # 49K, the limit is 50K
under_limit = (49 * 1024)-Versioned._HEADER_SIZE # 49K, the limit is 50K
a_normal_arrangement = header + os.urandom(under_limit)
response = requests.post(
"https://{}/consider_arrangement".format(node.rest_url()),
data=a_normal_arrangement, verify=False)

View File

@ -147,7 +147,7 @@ def test_retrieve_cfrags(federated_porter,
cleartext_with_sig_header = federated_bob._crypto_power.power_ups(DecryptingPower).keypair.decrypt(policy_message_kit)
sig_header, remainder = default_constant_splitter(cleartext_with_sig_header, return_remainder=True)
signature_from_kit, cleartext = signature_splitter(remainder, return_remainder=True)
assert signature_from_kit.verify(message=cleartext, verifying_key=policy_message_kit.sender_verifying_key)
assert signature_from_kit.verify(message=cleartext, verifying_pk=policy_message_kit.sender_verifying_key)
assert cleartext == original_message
#

View File

@ -14,6 +14,8 @@ GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
from base64 import b64encode
import pytest

View File

@ -0,0 +1,225 @@
"""
This file is part of nucypher.
nucypher is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
nucypher is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
from typing import Tuple, Any, Type
import pytest
from nucypher.utilities.versioning import Versioned
def _check_valid_version_tuple(version: Any, cls: Type):
if not isinstance(version, tuple):
pytest.fail(f"Old version handlers keys for {cls.__name__} must be a tuple")
if not len(version) == Versioned._PARTS:
pytest.fail(f"Old version handlers keys for {cls.__name__} must be a {str(Versioned._PARTS)}-tuple")
if not all(isinstance(part, int) for part in version):
pytest.fail(f"Old version handlers version parts {cls.__name__} must be integers")
class A(Versioned):
def __init__(self, x):
self.x = x
@classmethod
def _brand(cls):
return b"AA"
@classmethod
def _version(cls) -> Tuple[int, int]:
return 2, 1
def _payload(self) -> bytes:
return bytes(self.x)
@classmethod
def _old_version_handlers(cls):
return {
(2, 0): cls._from_bytes_v2_0,
}
@classmethod
def _from_bytes_v2_0(cls, data):
return cls(int(data, 16)) # then we switched to the hexadecimal
@classmethod
def _from_bytes_current(cls, data):
return cls(str(int(data, 16))) # but now we use a string representation for some reason
def test_unique_branding():
brands = tuple(v._brand() for v in Versioned.__subclasses__())
brands_set = set(brands)
if len(brands) != len(brands_set):
duplicate_brands = list(brands)
for brand in brands_set:
duplicate_brands.remove(brand)
pytest.fail(f"Duplicated brand(s) {duplicate_brands}.")
def test_valid_branding():
for cls in Versioned.__subclasses__():
if len(cls._brand()) != cls._BRAND_SIZE:
pytest.fail(f"Brand must be exactly {str(Versioned._BRAND_SIZE)} bytes.")
if not cls._brand().isalpha():
pytest.fail(f"Brand must be alphanumeric; Got {cls._brand()}")
def test_valid_version_implementation():
for cls in Versioned.__subclasses__():
_check_valid_version_tuple(version=cls._version(), cls=cls)
def test_valid_old_handlers_index():
for cls in Versioned.__subclasses__():
for version in cls._deserializers():
_check_valid_version_tuple(version=version, cls=cls)
def test_version_metadata():
major, minor = A._version()
assert A.version_string() == f'{major}.{minor}'
def test_versioning_header_prepend():
a = A(1) # stake sauce
assert a.x == 1
serialized = bytes(a)
assert len(serialized) > Versioned._HEADER_SIZE
header = serialized[:Versioned._HEADER_SIZE]
brand = header[:Versioned._BRAND_SIZE]
assert brand == A._brand()
version = header[Versioned._BRAND_SIZE:]
major, minor = version[:Versioned._PART_SIZE], version[Versioned._PART_SIZE:]
major_number = int.from_bytes(major, 'big')
minor_number = int.from_bytes(minor, 'big')
assert (major_number, minor_number) == A._version()
def test_versioning_input_too_short():
empty = b'AA\x00\x01'
with pytest.raises(ValueError, match='Invalid bytes for A.'):
A.from_bytes(empty)
def test_versioning_empty_payload():
empty = b'AA\x00\x02\x00\x01'
with pytest.raises(ValueError, match='No content to deserialize A.'):
A.from_bytes(empty)
def test_versioning_invalid_brand():
invalid = b'\x00\x03\x00\x0112'
with pytest.raises(Versioned.InvalidHeader, match="Incompatible bytes for A."):
A.from_bytes(invalid)
def test_versioning_incorrect_brand():
incorrect = b'AB\x00\x0112'
with pytest.raises(Versioned.InvalidHeader, match="Incorrect brand. Expected b'AA', Got b'AB'."):
A.from_bytes(incorrect)
def test_unknown_future_major_version():
empty = b'AA\x00\x03\x00\x0212'
message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 3.2.'
with pytest.raises(ValueError, match=message):
A.from_bytes(empty)
def test_incompatible_old_major_version(mocker):
current_spy = mocker.spy(A, "_from_bytes_current")
v1_data = b'AA\x00\x01\x00\x0012'
message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 1.0.'
with pytest.raises(Versioned.IncompatibleVersion, match=message):
A.from_bytes(v1_data)
assert not current_spy.call_count
def test_incompatible_future_major_version(mocker):
current_spy = mocker.spy(A, "_from_bytes_current")
v1_data = b'AA\x00\x03\x00\x0012'
message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 3.0.'
with pytest.raises(Versioned.IncompatibleVersion, match=message):
A.from_bytes(v1_data)
assert not current_spy.call_count
def test_resolve_version():
# past
v2_0 = 2, 0
resolved_version = A._resolve_version(version=v2_0)
assert resolved_version == v2_0
# present
v2_1 = 2, 1
resolved_version = A._resolve_version(version=v2_1)
assert resolved_version == v2_1
# future minor version resolves to the latest minor version.
v2_2 = 2, 2
resolved_version = A._resolve_version(version=v2_2)
assert resolved_version == v2_1
def test_old_minor_version_handler_routing(mocker):
current_spy = mocker.spy(A, "_from_bytes_current")
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
# Old minor version
v2_0_data = b'AA\x00\x02\x00\x0012'
a = A.from_bytes(v2_0_data)
assert a.x == 18
# Old minor version was correctly routed to the v2.0 handler.
assert v2_0_spy.call_count == 1
v2_0_spy.assert_called_with(b'12')
assert not current_spy.call_count
def test_current_minor_version_handler_routing(mocker):
current_spy = mocker.spy(A, "_from_bytes_current")
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
v2_1_data = b'AA\x00\x02\x00\x0112'
a = A.from_bytes(v2_1_data)
assert a.x == '18'
# Current version was correctly routed to the v2.1 handler.
assert current_spy.call_count == 1
current_spy.assert_called_with(b'12')
assert not v2_0_spy.call_count
def test_future_minor_version_handler_routing(mocker):
current_spy = mocker.spy(A, "_from_bytes_current")
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
v2_2_data = b'AA\x00\x02\x02\x0112'
a = A.from_bytes(v2_2_data)
assert a.x == '18'
# Future minor version was correctly routed to
# the current minor version handler.
assert current_spy.call_count == 1
current_spy.assert_called_with(b'12')
assert not v2_0_spy.call_count