diff --git a/nucypher/core.py b/nucypher/core.py index 7f879a29c..93ac1c701 100644 --- a/nucypher/core.py +++ b/nucypher/core.py @@ -54,7 +54,7 @@ signature_splitter = BytestringSplitter((Signature, Signature.serialized_size()) capsule_splitter = BytestringSplitter((Capsule, Capsule.serialized_size())) cfrag_splitter = BytestringSplitter((CapsuleFrag, CapsuleFrag.serialized_size())) kfrag_splitter = BytestringSplitter((KeyFrag, KeyFrag.serialized_size())) -checksum_address_splitter = BytestringSplitter((bytes, ETH_ADDRESS_BYTE_LENGTH)) # TODO: is there a pre-defined constant? +checksum_address_splitter = BytestringSplitter((to_checksum_address, ETH_ADDRESS_BYTE_LENGTH)) # TODO: is there a pre-defined constant? class MessageKit(Versioned): @@ -106,8 +106,8 @@ class MessageKit(Versioned): @classmethod def _from_bytes_current(cls, data): splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring) - capsule, ciphertext = splitter(data) - return cls(capsule, ciphertext) + capsule, ciphertext, remainder = splitter(data, return_remainder=True) + return cls(capsule, ciphertext), remainder class HRAC: @@ -199,8 +199,8 @@ class AuthorizedKeyFrag(Versioned): @classmethod def _from_bytes_current(cls, data): splitter = BytestringSplitter(signature_splitter, kfrag_splitter) - signature, kfrag = splitter(data) - return cls(signature, kfrag) + signature, kfrag, remainder = splitter(data, return_remainder=True) + return cls(signature, kfrag), remainder def verify(self, hrac: HRAC, @@ -223,6 +223,8 @@ class AuthorizedKeyFrag(Versioned): class EncryptedKeyFrag: + _splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring) + @classmethod def author(cls, recipient_key: PublicKey, authorized_kfrag: AuthorizedKeyFrag): # TODO: using Umbral for encryption to avoid introducing more crypto primitives. @@ -242,11 +244,13 @@ class EncryptedKeyFrag: def __bytes__(self): return bytes(self.capsule) + bytes(VariableLengthBytestring(self.ciphertext)) + # Ideally we would define a splitter here that would deserialize into an EKF, + # but due to BSS limitations it cannot be nested (since it doesn't have a definite size). + # So we have to define this helper method and use that instead. @classmethod - def from_bytes(cls, data): - splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring) - capsule, ciphertext = splitter(data) - return cls(capsule, ciphertext) + def take(cls, data): + capsule, ciphertext, remainder = cls._splitter(data, return_remainder=True) + return cls(capsule, ciphertext), remainder def __eq__(self, other): return self.capsule == other.capsule and self.ciphertext == other.ciphertext @@ -331,11 +335,16 @@ class TreasureMap(Versioned): def _payload(self) -> bytes: """Returns the unversioned bytes serialized representation of this instance.""" + assigned_kfrags = b''.join( + to_canonical_address(ursula_address) + bytes(encrypted_kfrag) + for ursula_address, encrypted_kfrag in self.destinations.items() + ) + return (self.threshold.to_bytes(1, "big") + bytes(self.hrac) + bytes(self.policy_encrypting_key) + bytes(self.publisher_verifying_key) + - self._nodes_as_bytes()) + bytes(VariableLengthBytestring(assigned_kfrags))) @classmethod def _from_bytes_current(cls, data): @@ -345,20 +354,18 @@ class TreasureMap(Versioned): hrac_splitter, key_splitter, key_splitter, + VariableLengthBytestring, ) - ursula_and_kfrag_payload_splitter = BytestringSplitter( - (to_checksum_address, ETH_ADDRESS_BYTE_LENGTH), - (EncryptedKeyFrag, VariableLengthBytestring), - ) + threshold, hrac, policy_encrypting_key, publisher_verifying_key, assigned_kfrags_bytes, remainder = main_splitter(data, return_remainder=True) - try: - threshold, hrac, policy_encrypting_key, publisher_verifying_key, remainder = main_splitter(data, return_remainder=True) - ursula_and_kfrags = ursula_and_kfrag_payload_splitter.repeat(remainder) - except BytestringSplittingError as e: - raise ValueError('Invalid treasure map contents.') from e - destinations = {u: k for u, k in ursula_and_kfrags} - return cls(threshold, hrac, policy_encrypting_key, publisher_verifying_key, destinations) + destinations = {} + while assigned_kfrags_bytes: + ursula_address, assigned_kfrags_bytes = checksum_address_splitter(assigned_kfrags_bytes, return_remainder=True) + ekf, assigned_kfrags_bytes = EncryptedKeyFrag.take(assigned_kfrags_bytes) + destinations[ursula_address] = ekf + + return cls(threshold, hrac, policy_encrypting_key, publisher_verifying_key, destinations), remainder def encrypt(self, signer: Signer, @@ -368,14 +375,6 @@ class TreasureMap(Versioned): recipient_key=recipient_key, treasure_map=self) - def _nodes_as_bytes(self) -> bytes: - nodes_as_bytes = b"" - for ursula_address, encrypted_kfrag in self.destinations.items(): - node_id = to_canonical_address(ursula_address) - kfrag = bytes(VariableLengthBytestring(bytes(encrypted_kfrag))) - nodes_as_bytes += (node_id + kfrag) - return nodes_as_bytes - class AuthorizedTreasureMap(Versioned): @@ -408,13 +407,13 @@ class AuthorizedTreasureMap(Versioned): def _payload(self) -> bytes: """Returns the unversioned bytes serialized representation of this instance.""" return (bytes(self.signature) + - VariableLengthBytestring(bytes(self.treasure_map))) + bytes(self.treasure_map)) @classmethod def _from_bytes_current(cls, data): - splitter = BytestringSplitter(signature_splitter, (TreasureMap, VariableLengthBytestring)) - signature, treasure_map = splitter(data) - return cls(signature, treasure_map) + signature, remainder = signature_splitter(data, return_remainder=True) + treasure_map, remainder = TreasureMap.take(remainder) + return cls(signature, treasure_map), remainder def verify(self, recipient_key: PublicKey, publisher_verifying_key: PublicKey) -> TreasureMap: payload = bytes(recipient_key) + bytes(self.treasure_map) @@ -472,8 +471,8 @@ class EncryptedTreasureMap(Versioned): @classmethod def _from_bytes_current(cls, data): splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring) - capsule, ciphertext = splitter(data) - return cls(capsule, ciphertext) + capsule, ciphertext, remainder = splitter(data, return_remainder=True) + return cls(capsule, ciphertext), remainder def __eq__(self, other): return bytes(self) == bytes(other) @@ -518,8 +517,8 @@ class ReencryptionRequest(Versioned): return (bytes(self.hrac) + bytes(self.publisher_verifying_key) + bytes(self.bob_verifying_key) + - VariableLengthBytestring(bytes(self.encrypted_kfrag)) + - b''.join(bytes(capsule) for capsule in self.capsules) + bytes(self.encrypted_kfrag) + + bytes(VariableLengthBytestring(b''.join(bytes(capsule) for capsule in self.capsules))) ) @classmethod @@ -538,12 +537,13 @@ class ReencryptionRequest(Versioned): def _from_bytes_current(cls, data): splitter = (hrac_splitter + key_splitter + - key_splitter + - BytestringSplitter((EncryptedKeyFrag, VariableLengthBytestring))) + key_splitter) - hrac, publisher_vk, bob_vk, ekfrag, remainder = splitter(data, return_remainder=True) - capsules = capsule_splitter.repeat(remainder) - return cls(hrac, publisher_vk, bob_vk, ekfrag, capsules) + hrac, publisher_vk, bob_vk, remainder = splitter(data, return_remainder=True) + ekfrag, remainder = EncryptedKeyFrag.take(remainder) + capsule_bytes, remainder = BytestringSplitter(VariableLengthBytestring)(remainder, return_remainder=True) + capsules = capsule_splitter.repeat(capsule_bytes) + return cls(hrac, publisher_vk, bob_vk, ekfrag, capsules), remainder class ReencryptionResponse(Versioned): @@ -572,7 +572,7 @@ class ReencryptionResponse(Versioned): def _payload(self) -> bytes: """Returns the unversioned bytes serialized representation of this instance.""" - return bytes(self.signature) + b''.join(bytes(cfrag) for cfrag in self.cfrags) + return bytes(self.signature) + bytes(VariableLengthBytestring(b''.join(bytes(cfrag) for cfrag in self.cfrags))) @classmethod def _brand(cls) -> bytes: @@ -588,7 +588,8 @@ class ReencryptionResponse(Versioned): @classmethod def _from_bytes_current(cls, data): - signature, cfrags_bytes = signature_splitter(data, return_remainder=True) + splitter = BytestringSplitter(signature_splitter, VariableLengthBytestring) + signature, cfrags_bytes, remainder = splitter(data, return_remainder=True) # We would never send a request with no capsules, so there should be cfrags. # The splitter would fail anyway, this just makes the error message more clear. @@ -596,7 +597,7 @@ class ReencryptionResponse(Versioned): raise ValueError(f"{cls.__name__} contains no cfrags") cfrags = cfrag_splitter.repeat(cfrags_bytes) - return cls(cfrags, signature) + return cls(cfrags, signature), remainder def verify(self, capsules: Sequence[Capsule], @@ -646,7 +647,7 @@ class RetrievalKit(Versioned): def _payload(self) -> bytes: return (bytes(self.capsule) + - b''.join(to_canonical_address(address) for address in self.queried_addresses)) + bytes(VariableLengthBytestring(b''.join(to_canonical_address(address) for address in self.queried_addresses)))) @classmethod def _brand(cls) -> bytes: @@ -662,12 +663,13 @@ class RetrievalKit(Versioned): @classmethod def _from_bytes_current(cls, data): - capsule, remainder = capsule_splitter(data, return_remainder=True) - if remainder: - addresses_as_bytes = checksum_address_splitter.repeat(remainder) + splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring) + capsule, addresses_bytes, remainder = splitter(data, return_remainder=True) + if addresses_bytes: + addresses = checksum_address_splitter.repeat(addresses_bytes) else: - addresses_as_bytes = () - return cls(capsule, set(to_checksum_address(address) for address in addresses_as_bytes)) + addresses = () + return cls(capsule, addresses), remainder class RevocationOrder(Versioned): @@ -701,7 +703,7 @@ class RevocationOrder(Versioned): @staticmethod def _signed_payload(ursula_address, encrypted_kfrag): - return to_canonical_address(ursula_address) + bytes(VariableLengthBytestring(bytes(encrypted_kfrag))) + return to_canonical_address(ursula_address) + bytes(encrypted_kfrag) def verify_signature(self, alice_verifying_key: PublicKey) -> bool: """ @@ -726,21 +728,20 @@ class RevocationOrder(Versioned): return {} def _payload(self) -> bytes: - return self._signed_payload(self.ursula_address, self.encrypted_kfrag) + bytes(self.signature) + return bytes(self.signature) + self._signed_payload(self.ursula_address, self.encrypted_kfrag) @classmethod def _from_bytes_current(cls, data): splitter = BytestringSplitter( + signature_splitter, checksum_address_splitter, # ursula canonical address - VariableLengthBytestring, # EncryptedKeyFrag - signature_splitter ) - ursula_canonical_address, ekfrag_bytes, signature = splitter(data) - ekfrag = EncryptedKeyFrag.from_bytes(ekfrag_bytes) - ursula_address = to_checksum_address(ursula_canonical_address) - return cls(ursula_address=ursula_address, - encrypted_kfrag=ekfrag, - signature=signature) + signature, ursula_address, remainder = splitter(data, return_remainder=True) + ekfrag, remainder = EncryptedKeyFrag.take(remainder) + obj = cls(ursula_address=ursula_address, + encrypted_kfrag=ekfrag, + signature=signature) + return obj, remainder class NodeMetadataPayload(NamedTuple): @@ -755,20 +756,19 @@ class NodeMetadataPayload(NamedTuple): host: str port: int - _splitter = BytestringKwargifier( - dict, - public_address=ETH_ADDRESS_BYTE_LENGTH, - domain_bytes=VariableLengthBytestring, - timestamp_epoch=(int, 4, {'byteorder': 'big'}), + _splitter = BytestringSplitter( + (bytes, ETH_ADDRESS_BYTE_LENGTH), # public_address + VariableLengthBytestring, # domain_bytes + (int, 4, {'byteorder': 'big'}), # timestamp_epoch # FIXME: Fixed length doesn't work with federated. It was LENGTH_ECDSA_SIGNATURE_WITH_RECOVERY, - decentralized_identity_evidence=VariableLengthBytestring, + VariableLengthBytestring, # decentralized_identity_evidence - verifying_key=key_splitter, - encrypting_key=key_splitter, - certificate_bytes=VariableLengthBytestring, - host_bytes=VariableLengthBytestring, - port=(int, 2, {'byteorder': 'big'}), + key_splitter, # verifying_key + key_splitter, # encrypting_key + VariableLengthBytestring, # certificate_bytes + VariableLengthBytestring, # host_bytes + (int, 2, {'byteorder': 'big'}), # port ) def __bytes__(self): @@ -784,19 +784,40 @@ class NodeMetadataPayload(NamedTuple): )) return as_bytes + @classmethod + def take(cls, data): + *fields, remainder = cls._splitter(data, return_remainder=True) + + (public_address, + domain, + timestamp_epoch, + decentralized_identity_evidence, + verifying_key, + encrypting_key, + certificate_bytes, + host, + port, + ) = fields + + obj = cls(public_address=public_address, + domain=domain.decode('utf-8'), + timestamp_epoch=timestamp_epoch, + decentralized_identity_evidence=decentralized_identity_evidence, + verifying_key=verifying_key, + encrypting_key=encrypting_key, + certificate_bytes=certificate_bytes, + host=host.decode('utf-8'), + port=port, + ) + + return obj, remainder + @classmethod def from_bytes(cls, data): - result = cls._splitter(data) - return cls(public_address=result['public_address'], - domain=result['domain_bytes'].decode('utf-8'), - timestamp_epoch=result['timestamp_epoch'], - decentralized_identity_evidence=result['decentralized_identity_evidence'], - verifying_key=result['verifying_key'], - encrypting_key=result['encrypting_key'], - certificate_bytes=result['certificate_bytes'], - host=result['host_bytes'].decode('utf-8'), - port=result['port'], - ) + obj, remainder = cls.take(data) + if remainder: + raise ValueError(f"{len(remainder)} bytes remaining after deserializing {cls}") + return obj class NodeMetadata(Versioned): @@ -838,16 +859,16 @@ class NodeMetadata(Versioned): @classmethod def _from_bytes_current(cls, data: bytes): signature, remainder = signature_splitter(data, return_remainder=True) - payload = NodeMetadataPayload.from_bytes(remainder) - return cls(signature=signature, payload=payload) + payload, remainder = NodeMetadataPayload.take(remainder) + return cls(signature=signature, payload=payload), remainder @classmethod - def batch_from_bytes(cls, data: bytes): - - node_splitter = BytestringSplitter(VariableLengthBytestring) - nodes_vbytes = node_splitter.repeat(data) - - return [cls.from_bytes(node_data) for node_data in nodes_vbytes] + def _batch_from_bytes(cls, data: bytes): + nodes = [] + while data: + node, data = cls.take(data) + nodes.append(node) + return nodes class MetadataRequest(Versioned): @@ -876,29 +897,34 @@ class MetadataRequest(Versioned): def _payload(self): if self.announce_nodes: - nodes_bytes = bytes().join(bytes(VariableLengthBytestring(bytes(n))) for n in self.announce_nodes) + nodes_bytes = b''.join(bytes(n) for n in self.announce_nodes) else: nodes_bytes = b'' - return bytes.fromhex(self.fleet_state_checksum) + nodes_bytes + return bytes.fromhex(self.fleet_state_checksum) + bytes(VariableLengthBytestring(nodes_bytes)) @classmethod def _from_bytes_current(cls, data): - fleet_state_checksum_bytes, nodes_bytes = cls._fleet_state_checksum_splitter(data, return_remainder=True) + splitter = BytestringSplitter( + (bytes, 32), # fleet state checksum + VariableLengthBytestring, + ) + fleet_state_checksum_bytes, nodes_bytes, remainder = splitter(data, return_remainder=True) if nodes_bytes: - nodes = NodeMetadata.batch_from_bytes(nodes_bytes) + nodes = NodeMetadata._batch_from_bytes(nodes_bytes) else: nodes = None - return cls(fleet_state_checksum=fleet_state_checksum_bytes.hex(), - announce_nodes=nodes) + obj = cls(fleet_state_checksum=fleet_state_checksum_bytes.hex(), + announce_nodes=nodes) + return obj, remainder class MetadataResponse(Versioned): _splitter = BytestringSplitter( + signature_splitter, (int, 4, {'byteorder': 'big'}), VariableLengthBytestring, VariableLengthBytestring, - signature_splitter, ) @classmethod @@ -918,7 +944,7 @@ class MetadataResponse(Versioned): @staticmethod def _signed_payload(timestamp_epoch, this_node, other_nodes): timestamp = timestamp_epoch.to_bytes(4, byteorder="big") - nodes_payload = b''.join(bytes(VariableLengthBytestring(bytes(node))) for node in other_nodes) if other_nodes else b'' + nodes_payload = b''.join(bytes(node) for node in other_nodes) if other_nodes else b'' return ( timestamp + bytes(VariableLengthBytestring(bytes(this_node) if this_node else b'')) + @@ -955,14 +981,15 @@ class MetadataResponse(Versioned): def _payload(self): payload = self._signed_payload(self.timestamp_epoch, self.this_node, self.other_nodes) - return payload + bytes(self.signature) + return bytes(self.signature) + payload @classmethod def _from_bytes_current(cls, data: bytes): - timestamp_epoch, maybe_this_node, maybe_other_nodes, signature = cls._splitter(data) + signature, timestamp_epoch, maybe_this_node, maybe_other_nodes, remainder = cls._splitter(data, return_remainder=True) this_node = NodeMetadata.from_bytes(maybe_this_node) if maybe_this_node else None - other_nodes = NodeMetadata.batch_from_bytes(maybe_other_nodes) if maybe_other_nodes else None - return cls(signature=signature, - timestamp_epoch=timestamp_epoch, - this_node=this_node, - other_nodes=other_nodes) + other_nodes = NodeMetadata._batch_from_bytes(maybe_other_nodes) if maybe_other_nodes else None + obj = cls(signature=signature, + timestamp_epoch=timestamp_epoch, + this_node=this_node, + other_nodes=other_nodes) + return obj, remainder diff --git a/nucypher/utilities/versioning.py b/nucypher/utilities/versioning.py index 6914b11cd..af65beaad 100644 --- a/nucypher/utilities/versioning.py +++ b/nucypher/utilities/versioning.py @@ -93,12 +93,24 @@ class Versioned(ABC): raise NotImplementedError @classmethod - def from_bytes(cls, data: bytes): - """"Public deserialization API""" + def take(cls, data: bytes): + """ + Deserializes the object from the given bytestring + and returns the object and the remainder of the bytestring. + """ brand, version, payload = cls._parse_header(data) version = cls._resolve_version(version=version) handlers = cls._deserializers() - return handlers[version](payload) + obj, remainder = handlers[version](payload) + return obj, remainder + + @classmethod + def from_bytes(cls, data: bytes): + """"Public deserialization API""" + obj, remainder = cls.take(data) + if remainder: + raise ValueError(f"{len(remainder)} bytes remaining after deserializing {cls}") + return obj @classmethod def _resolve_version(cls, version: Tuple[int, int]) -> Tuple[int, int]: diff --git a/tests/integration/characters/test_specifications.py b/tests/integration/characters/test_specifications.py index 332f90caa..16d0eb443 100644 --- a/tests/integration/characters/test_specifications.py +++ b/tests/integration/characters/test_specifications.py @@ -131,7 +131,7 @@ def test_treasure_map_validation(enacted_federated_policy, UnenncryptedTreasureMapsOnly().load({'tmap': bad_map_b64}) assert "Could not convert input for tmap to a TreasureMap" in str(e) - assert "Invalid treasure map contents." in str(e) + assert "Can't split a message with more bytes than the original splittable." in str(e) # a valid treasuremap decrypted_treasure_map = federated_bob._decrypt_treasure_map(enacted_federated_policy.treasure_map, diff --git a/tests/unit/test_versioning.py b/tests/unit/test_versioning.py index 87e815e1e..354ec6092 100644 --- a/tests/unit/test_versioning.py +++ b/tests/unit/test_versioning.py @@ -35,7 +35,7 @@ def _check_valid_version_tuple(version: Any, cls: Type): class A(Versioned): - def __init__(self, x): + def __init__(self, x: int): self.x = x @classmethod @@ -47,7 +47,7 @@ class A(Versioned): return 2, 1 def _payload(self) -> bytes: - return bytes(self.x) + return self.x.to_bytes(1, 'big') @classmethod def _old_version_handlers(cls): @@ -57,11 +57,16 @@ class A(Versioned): @classmethod def _from_bytes_v2_0(cls, data): - return cls(int(data, 16)) # then we switched to the hexadecimal + # v2.0 saved a 4 byte integer in hex format + int_hex, remainder = data[:2], data[2:] + int_bytes = bytes.fromhex(int_hex.decode()) + return cls(int.from_bytes(int_bytes, 'big')), remainder @classmethod def _from_bytes_current(cls, data): - return cls(str(int(data, 16))) # but now we use a string representation for some reason + # v2.1 saves a 4 byte integer as 4 bytes + int_bytes, remainder = data[:1], data[1:] + return cls(int.from_bytes(int_bytes, 'big')), remainder def test_unique_branding(): @@ -206,13 +211,13 @@ def test_current_minor_version_handler_routing(mocker): current_spy = mocker.spy(A, "_from_bytes_current") v2_0_spy = mocker.spy(A, "_from_bytes_v2_0") - v2_1_data = b'ABCD\x00\x02\x00\x0112' + v2_1_data = b'ABCD\x00\x02\x00\x01\x12' a = A.from_bytes(v2_1_data) - assert a.x == '18' + assert a.x == 18 # Current version was correctly routed to the v2.1 handler. assert current_spy.call_count == 1 - current_spy.assert_called_with(b'12') + current_spy.assert_called_with(b'\x12') assert not v2_0_spy.call_count @@ -220,12 +225,12 @@ def test_future_minor_version_handler_routing(mocker): current_spy = mocker.spy(A, "_from_bytes_current") v2_0_spy = mocker.spy(A, "_from_bytes_v2_0") - v2_2_data = b'ABCD\x00\x02\x02\x0112' + v2_2_data = b'ABCD\x00\x02\x02\x01\x12' a = A.from_bytes(v2_2_data) - assert a.x == '18' + assert a.x == 18 # Future minor version was correctly routed to # the current minor version handler. assert current_spy.call_count == 1 - current_spy.assert_called_with(b'12') + current_spy.assert_called_with(b'\x12') assert not v2_0_spy.call_count