Normalize the usage of VariableLengthBytestrings

Allow for every Versioned type to be able to deserialize itself from the bytestring.
pull/2809/head
Bogdan Opanchuk 2021-10-30 12:10:59 -07:00
parent 1114891680
commit 944d3373e7
4 changed files with 164 additions and 120 deletions

View File

@ -54,7 +54,7 @@ signature_splitter = BytestringSplitter((Signature, Signature.serialized_size())
capsule_splitter = BytestringSplitter((Capsule, Capsule.serialized_size())) capsule_splitter = BytestringSplitter((Capsule, Capsule.serialized_size()))
cfrag_splitter = BytestringSplitter((CapsuleFrag, CapsuleFrag.serialized_size())) cfrag_splitter = BytestringSplitter((CapsuleFrag, CapsuleFrag.serialized_size()))
kfrag_splitter = BytestringSplitter((KeyFrag, KeyFrag.serialized_size())) kfrag_splitter = BytestringSplitter((KeyFrag, KeyFrag.serialized_size()))
checksum_address_splitter = BytestringSplitter((bytes, ETH_ADDRESS_BYTE_LENGTH)) # TODO: is there a pre-defined constant? checksum_address_splitter = BytestringSplitter((to_checksum_address, ETH_ADDRESS_BYTE_LENGTH)) # TODO: is there a pre-defined constant?
class MessageKit(Versioned): class MessageKit(Versioned):
@ -106,8 +106,8 @@ class MessageKit(Versioned):
@classmethod @classmethod
def _from_bytes_current(cls, data): def _from_bytes_current(cls, data):
splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring) splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring)
capsule, ciphertext = splitter(data) capsule, ciphertext, remainder = splitter(data, return_remainder=True)
return cls(capsule, ciphertext) return cls(capsule, ciphertext), remainder
class HRAC: class HRAC:
@ -199,8 +199,8 @@ class AuthorizedKeyFrag(Versioned):
@classmethod @classmethod
def _from_bytes_current(cls, data): def _from_bytes_current(cls, data):
splitter = BytestringSplitter(signature_splitter, kfrag_splitter) splitter = BytestringSplitter(signature_splitter, kfrag_splitter)
signature, kfrag = splitter(data) signature, kfrag, remainder = splitter(data, return_remainder=True)
return cls(signature, kfrag) return cls(signature, kfrag), remainder
def verify(self, def verify(self,
hrac: HRAC, hrac: HRAC,
@ -223,6 +223,8 @@ class AuthorizedKeyFrag(Versioned):
class EncryptedKeyFrag: class EncryptedKeyFrag:
_splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring)
@classmethod @classmethod
def author(cls, recipient_key: PublicKey, authorized_kfrag: AuthorizedKeyFrag): def author(cls, recipient_key: PublicKey, authorized_kfrag: AuthorizedKeyFrag):
# TODO: using Umbral for encryption to avoid introducing more crypto primitives. # TODO: using Umbral for encryption to avoid introducing more crypto primitives.
@ -242,11 +244,13 @@ class EncryptedKeyFrag:
def __bytes__(self): def __bytes__(self):
return bytes(self.capsule) + bytes(VariableLengthBytestring(self.ciphertext)) return bytes(self.capsule) + bytes(VariableLengthBytestring(self.ciphertext))
# Ideally we would define a splitter here that would deserialize into an EKF,
# but due to BSS limitations it cannot be nested (since it doesn't have a definite size).
# So we have to define this helper method and use that instead.
@classmethod @classmethod
def from_bytes(cls, data): def take(cls, data):
splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring) capsule, ciphertext, remainder = cls._splitter(data, return_remainder=True)
capsule, ciphertext = splitter(data) return cls(capsule, ciphertext), remainder
return cls(capsule, ciphertext)
def __eq__(self, other): def __eq__(self, other):
return self.capsule == other.capsule and self.ciphertext == other.ciphertext return self.capsule == other.capsule and self.ciphertext == other.ciphertext
@ -331,11 +335,16 @@ class TreasureMap(Versioned):
def _payload(self) -> bytes: def _payload(self) -> bytes:
"""Returns the unversioned bytes serialized representation of this instance.""" """Returns the unversioned bytes serialized representation of this instance."""
assigned_kfrags = b''.join(
to_canonical_address(ursula_address) + bytes(encrypted_kfrag)
for ursula_address, encrypted_kfrag in self.destinations.items()
)
return (self.threshold.to_bytes(1, "big") + return (self.threshold.to_bytes(1, "big") +
bytes(self.hrac) + bytes(self.hrac) +
bytes(self.policy_encrypting_key) + bytes(self.policy_encrypting_key) +
bytes(self.publisher_verifying_key) + bytes(self.publisher_verifying_key) +
self._nodes_as_bytes()) bytes(VariableLengthBytestring(assigned_kfrags)))
@classmethod @classmethod
def _from_bytes_current(cls, data): def _from_bytes_current(cls, data):
@ -345,20 +354,18 @@ class TreasureMap(Versioned):
hrac_splitter, hrac_splitter,
key_splitter, key_splitter,
key_splitter, key_splitter,
VariableLengthBytestring,
) )
ursula_and_kfrag_payload_splitter = BytestringSplitter( threshold, hrac, policy_encrypting_key, publisher_verifying_key, assigned_kfrags_bytes, remainder = main_splitter(data, return_remainder=True)
(to_checksum_address, ETH_ADDRESS_BYTE_LENGTH),
(EncryptedKeyFrag, VariableLengthBytestring),
)
try: destinations = {}
threshold, hrac, policy_encrypting_key, publisher_verifying_key, remainder = main_splitter(data, return_remainder=True) while assigned_kfrags_bytes:
ursula_and_kfrags = ursula_and_kfrag_payload_splitter.repeat(remainder) ursula_address, assigned_kfrags_bytes = checksum_address_splitter(assigned_kfrags_bytes, return_remainder=True)
except BytestringSplittingError as e: ekf, assigned_kfrags_bytes = EncryptedKeyFrag.take(assigned_kfrags_bytes)
raise ValueError('Invalid treasure map contents.') from e destinations[ursula_address] = ekf
destinations = {u: k for u, k in ursula_and_kfrags}
return cls(threshold, hrac, policy_encrypting_key, publisher_verifying_key, destinations) return cls(threshold, hrac, policy_encrypting_key, publisher_verifying_key, destinations), remainder
def encrypt(self, def encrypt(self,
signer: Signer, signer: Signer,
@ -368,14 +375,6 @@ class TreasureMap(Versioned):
recipient_key=recipient_key, recipient_key=recipient_key,
treasure_map=self) treasure_map=self)
def _nodes_as_bytes(self) -> bytes:
nodes_as_bytes = b""
for ursula_address, encrypted_kfrag in self.destinations.items():
node_id = to_canonical_address(ursula_address)
kfrag = bytes(VariableLengthBytestring(bytes(encrypted_kfrag)))
nodes_as_bytes += (node_id + kfrag)
return nodes_as_bytes
class AuthorizedTreasureMap(Versioned): class AuthorizedTreasureMap(Versioned):
@ -408,13 +407,13 @@ class AuthorizedTreasureMap(Versioned):
def _payload(self) -> bytes: def _payload(self) -> bytes:
"""Returns the unversioned bytes serialized representation of this instance.""" """Returns the unversioned bytes serialized representation of this instance."""
return (bytes(self.signature) + return (bytes(self.signature) +
VariableLengthBytestring(bytes(self.treasure_map))) bytes(self.treasure_map))
@classmethod @classmethod
def _from_bytes_current(cls, data): def _from_bytes_current(cls, data):
splitter = BytestringSplitter(signature_splitter, (TreasureMap, VariableLengthBytestring)) signature, remainder = signature_splitter(data, return_remainder=True)
signature, treasure_map = splitter(data) treasure_map, remainder = TreasureMap.take(remainder)
return cls(signature, treasure_map) return cls(signature, treasure_map), remainder
def verify(self, recipient_key: PublicKey, publisher_verifying_key: PublicKey) -> TreasureMap: def verify(self, recipient_key: PublicKey, publisher_verifying_key: PublicKey) -> TreasureMap:
payload = bytes(recipient_key) + bytes(self.treasure_map) payload = bytes(recipient_key) + bytes(self.treasure_map)
@ -472,8 +471,8 @@ class EncryptedTreasureMap(Versioned):
@classmethod @classmethod
def _from_bytes_current(cls, data): def _from_bytes_current(cls, data):
splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring) splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring)
capsule, ciphertext = splitter(data) capsule, ciphertext, remainder = splitter(data, return_remainder=True)
return cls(capsule, ciphertext) return cls(capsule, ciphertext), remainder
def __eq__(self, other): def __eq__(self, other):
return bytes(self) == bytes(other) return bytes(self) == bytes(other)
@ -518,8 +517,8 @@ class ReencryptionRequest(Versioned):
return (bytes(self.hrac) + return (bytes(self.hrac) +
bytes(self.publisher_verifying_key) + bytes(self.publisher_verifying_key) +
bytes(self.bob_verifying_key) + bytes(self.bob_verifying_key) +
VariableLengthBytestring(bytes(self.encrypted_kfrag)) + bytes(self.encrypted_kfrag) +
b''.join(bytes(capsule) for capsule in self.capsules) bytes(VariableLengthBytestring(b''.join(bytes(capsule) for capsule in self.capsules)))
) )
@classmethod @classmethod
@ -538,12 +537,13 @@ class ReencryptionRequest(Versioned):
def _from_bytes_current(cls, data): def _from_bytes_current(cls, data):
splitter = (hrac_splitter + splitter = (hrac_splitter +
key_splitter + key_splitter +
key_splitter + key_splitter)
BytestringSplitter((EncryptedKeyFrag, VariableLengthBytestring)))
hrac, publisher_vk, bob_vk, ekfrag, remainder = splitter(data, return_remainder=True) hrac, publisher_vk, bob_vk, remainder = splitter(data, return_remainder=True)
capsules = capsule_splitter.repeat(remainder) ekfrag, remainder = EncryptedKeyFrag.take(remainder)
return cls(hrac, publisher_vk, bob_vk, ekfrag, capsules) capsule_bytes, remainder = BytestringSplitter(VariableLengthBytestring)(remainder, return_remainder=True)
capsules = capsule_splitter.repeat(capsule_bytes)
return cls(hrac, publisher_vk, bob_vk, ekfrag, capsules), remainder
class ReencryptionResponse(Versioned): class ReencryptionResponse(Versioned):
@ -572,7 +572,7 @@ class ReencryptionResponse(Versioned):
def _payload(self) -> bytes: def _payload(self) -> bytes:
"""Returns the unversioned bytes serialized representation of this instance.""" """Returns the unversioned bytes serialized representation of this instance."""
return bytes(self.signature) + b''.join(bytes(cfrag) for cfrag in self.cfrags) return bytes(self.signature) + bytes(VariableLengthBytestring(b''.join(bytes(cfrag) for cfrag in self.cfrags)))
@classmethod @classmethod
def _brand(cls) -> bytes: def _brand(cls) -> bytes:
@ -588,7 +588,8 @@ class ReencryptionResponse(Versioned):
@classmethod @classmethod
def _from_bytes_current(cls, data): def _from_bytes_current(cls, data):
signature, cfrags_bytes = signature_splitter(data, return_remainder=True) splitter = BytestringSplitter(signature_splitter, VariableLengthBytestring)
signature, cfrags_bytes, remainder = splitter(data, return_remainder=True)
# We would never send a request with no capsules, so there should be cfrags. # We would never send a request with no capsules, so there should be cfrags.
# The splitter would fail anyway, this just makes the error message more clear. # The splitter would fail anyway, this just makes the error message more clear.
@ -596,7 +597,7 @@ class ReencryptionResponse(Versioned):
raise ValueError(f"{cls.__name__} contains no cfrags") raise ValueError(f"{cls.__name__} contains no cfrags")
cfrags = cfrag_splitter.repeat(cfrags_bytes) cfrags = cfrag_splitter.repeat(cfrags_bytes)
return cls(cfrags, signature) return cls(cfrags, signature), remainder
def verify(self, def verify(self,
capsules: Sequence[Capsule], capsules: Sequence[Capsule],
@ -646,7 +647,7 @@ class RetrievalKit(Versioned):
def _payload(self) -> bytes: def _payload(self) -> bytes:
return (bytes(self.capsule) + return (bytes(self.capsule) +
b''.join(to_canonical_address(address) for address in self.queried_addresses)) bytes(VariableLengthBytestring(b''.join(to_canonical_address(address) for address in self.queried_addresses))))
@classmethod @classmethod
def _brand(cls) -> bytes: def _brand(cls) -> bytes:
@ -662,12 +663,13 @@ class RetrievalKit(Versioned):
@classmethod @classmethod
def _from_bytes_current(cls, data): def _from_bytes_current(cls, data):
capsule, remainder = capsule_splitter(data, return_remainder=True) splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring)
if remainder: capsule, addresses_bytes, remainder = splitter(data, return_remainder=True)
addresses_as_bytes = checksum_address_splitter.repeat(remainder) if addresses_bytes:
addresses = checksum_address_splitter.repeat(addresses_bytes)
else: else:
addresses_as_bytes = () addresses = ()
return cls(capsule, set(to_checksum_address(address) for address in addresses_as_bytes)) return cls(capsule, addresses), remainder
class RevocationOrder(Versioned): class RevocationOrder(Versioned):
@ -701,7 +703,7 @@ class RevocationOrder(Versioned):
@staticmethod @staticmethod
def _signed_payload(ursula_address, encrypted_kfrag): def _signed_payload(ursula_address, encrypted_kfrag):
return to_canonical_address(ursula_address) + bytes(VariableLengthBytestring(bytes(encrypted_kfrag))) return to_canonical_address(ursula_address) + bytes(encrypted_kfrag)
def verify_signature(self, alice_verifying_key: PublicKey) -> bool: def verify_signature(self, alice_verifying_key: PublicKey) -> bool:
""" """
@ -726,21 +728,20 @@ class RevocationOrder(Versioned):
return {} return {}
def _payload(self) -> bytes: def _payload(self) -> bytes:
return self._signed_payload(self.ursula_address, self.encrypted_kfrag) + bytes(self.signature) return bytes(self.signature) + self._signed_payload(self.ursula_address, self.encrypted_kfrag)
@classmethod @classmethod
def _from_bytes_current(cls, data): def _from_bytes_current(cls, data):
splitter = BytestringSplitter( splitter = BytestringSplitter(
signature_splitter,
checksum_address_splitter, # ursula canonical address checksum_address_splitter, # ursula canonical address
VariableLengthBytestring, # EncryptedKeyFrag
signature_splitter
) )
ursula_canonical_address, ekfrag_bytes, signature = splitter(data) signature, ursula_address, remainder = splitter(data, return_remainder=True)
ekfrag = EncryptedKeyFrag.from_bytes(ekfrag_bytes) ekfrag, remainder = EncryptedKeyFrag.take(remainder)
ursula_address = to_checksum_address(ursula_canonical_address) obj = cls(ursula_address=ursula_address,
return cls(ursula_address=ursula_address, encrypted_kfrag=ekfrag,
encrypted_kfrag=ekfrag, signature=signature)
signature=signature) return obj, remainder
class NodeMetadataPayload(NamedTuple): class NodeMetadataPayload(NamedTuple):
@ -755,20 +756,19 @@ class NodeMetadataPayload(NamedTuple):
host: str host: str
port: int port: int
_splitter = BytestringKwargifier( _splitter = BytestringSplitter(
dict, (bytes, ETH_ADDRESS_BYTE_LENGTH), # public_address
public_address=ETH_ADDRESS_BYTE_LENGTH, VariableLengthBytestring, # domain_bytes
domain_bytes=VariableLengthBytestring, (int, 4, {'byteorder': 'big'}), # timestamp_epoch
timestamp_epoch=(int, 4, {'byteorder': 'big'}),
# FIXME: Fixed length doesn't work with federated. It was LENGTH_ECDSA_SIGNATURE_WITH_RECOVERY, # FIXME: Fixed length doesn't work with federated. It was LENGTH_ECDSA_SIGNATURE_WITH_RECOVERY,
decentralized_identity_evidence=VariableLengthBytestring, VariableLengthBytestring, # decentralized_identity_evidence
verifying_key=key_splitter, key_splitter, # verifying_key
encrypting_key=key_splitter, key_splitter, # encrypting_key
certificate_bytes=VariableLengthBytestring, VariableLengthBytestring, # certificate_bytes
host_bytes=VariableLengthBytestring, VariableLengthBytestring, # host_bytes
port=(int, 2, {'byteorder': 'big'}), (int, 2, {'byteorder': 'big'}), # port
) )
def __bytes__(self): def __bytes__(self):
@ -784,19 +784,40 @@ class NodeMetadataPayload(NamedTuple):
)) ))
return as_bytes return as_bytes
@classmethod
def take(cls, data):
*fields, remainder = cls._splitter(data, return_remainder=True)
(public_address,
domain,
timestamp_epoch,
decentralized_identity_evidence,
verifying_key,
encrypting_key,
certificate_bytes,
host,
port,
) = fields
obj = cls(public_address=public_address,
domain=domain.decode('utf-8'),
timestamp_epoch=timestamp_epoch,
decentralized_identity_evidence=decentralized_identity_evidence,
verifying_key=verifying_key,
encrypting_key=encrypting_key,
certificate_bytes=certificate_bytes,
host=host.decode('utf-8'),
port=port,
)
return obj, remainder
@classmethod @classmethod
def from_bytes(cls, data): def from_bytes(cls, data):
result = cls._splitter(data) obj, remainder = cls.take(data)
return cls(public_address=result['public_address'], if remainder:
domain=result['domain_bytes'].decode('utf-8'), raise ValueError(f"{len(remainder)} bytes remaining after deserializing {cls}")
timestamp_epoch=result['timestamp_epoch'], return obj
decentralized_identity_evidence=result['decentralized_identity_evidence'],
verifying_key=result['verifying_key'],
encrypting_key=result['encrypting_key'],
certificate_bytes=result['certificate_bytes'],
host=result['host_bytes'].decode('utf-8'),
port=result['port'],
)
class NodeMetadata(Versioned): class NodeMetadata(Versioned):
@ -838,16 +859,16 @@ class NodeMetadata(Versioned):
@classmethod @classmethod
def _from_bytes_current(cls, data: bytes): def _from_bytes_current(cls, data: bytes):
signature, remainder = signature_splitter(data, return_remainder=True) signature, remainder = signature_splitter(data, return_remainder=True)
payload = NodeMetadataPayload.from_bytes(remainder) payload, remainder = NodeMetadataPayload.take(remainder)
return cls(signature=signature, payload=payload) return cls(signature=signature, payload=payload), remainder
@classmethod @classmethod
def batch_from_bytes(cls, data: bytes): def _batch_from_bytes(cls, data: bytes):
nodes = []
node_splitter = BytestringSplitter(VariableLengthBytestring) while data:
nodes_vbytes = node_splitter.repeat(data) node, data = cls.take(data)
nodes.append(node)
return [cls.from_bytes(node_data) for node_data in nodes_vbytes] return nodes
class MetadataRequest(Versioned): class MetadataRequest(Versioned):
@ -876,29 +897,34 @@ class MetadataRequest(Versioned):
def _payload(self): def _payload(self):
if self.announce_nodes: if self.announce_nodes:
nodes_bytes = bytes().join(bytes(VariableLengthBytestring(bytes(n))) for n in self.announce_nodes) nodes_bytes = b''.join(bytes(n) for n in self.announce_nodes)
else: else:
nodes_bytes = b'' nodes_bytes = b''
return bytes.fromhex(self.fleet_state_checksum) + nodes_bytes return bytes.fromhex(self.fleet_state_checksum) + bytes(VariableLengthBytestring(nodes_bytes))
@classmethod @classmethod
def _from_bytes_current(cls, data): def _from_bytes_current(cls, data):
fleet_state_checksum_bytes, nodes_bytes = cls._fleet_state_checksum_splitter(data, return_remainder=True) splitter = BytestringSplitter(
(bytes, 32), # fleet state checksum
VariableLengthBytestring,
)
fleet_state_checksum_bytes, nodes_bytes, remainder = splitter(data, return_remainder=True)
if nodes_bytes: if nodes_bytes:
nodes = NodeMetadata.batch_from_bytes(nodes_bytes) nodes = NodeMetadata._batch_from_bytes(nodes_bytes)
else: else:
nodes = None nodes = None
return cls(fleet_state_checksum=fleet_state_checksum_bytes.hex(), obj = cls(fleet_state_checksum=fleet_state_checksum_bytes.hex(),
announce_nodes=nodes) announce_nodes=nodes)
return obj, remainder
class MetadataResponse(Versioned): class MetadataResponse(Versioned):
_splitter = BytestringSplitter( _splitter = BytestringSplitter(
signature_splitter,
(int, 4, {'byteorder': 'big'}), (int, 4, {'byteorder': 'big'}),
VariableLengthBytestring, VariableLengthBytestring,
VariableLengthBytestring, VariableLengthBytestring,
signature_splitter,
) )
@classmethod @classmethod
@ -918,7 +944,7 @@ class MetadataResponse(Versioned):
@staticmethod @staticmethod
def _signed_payload(timestamp_epoch, this_node, other_nodes): def _signed_payload(timestamp_epoch, this_node, other_nodes):
timestamp = timestamp_epoch.to_bytes(4, byteorder="big") timestamp = timestamp_epoch.to_bytes(4, byteorder="big")
nodes_payload = b''.join(bytes(VariableLengthBytestring(bytes(node))) for node in other_nodes) if other_nodes else b'' nodes_payload = b''.join(bytes(node) for node in other_nodes) if other_nodes else b''
return ( return (
timestamp + timestamp +
bytes(VariableLengthBytestring(bytes(this_node) if this_node else b'')) + bytes(VariableLengthBytestring(bytes(this_node) if this_node else b'')) +
@ -955,14 +981,15 @@ class MetadataResponse(Versioned):
def _payload(self): def _payload(self):
payload = self._signed_payload(self.timestamp_epoch, self.this_node, self.other_nodes) payload = self._signed_payload(self.timestamp_epoch, self.this_node, self.other_nodes)
return payload + bytes(self.signature) return bytes(self.signature) + payload
@classmethod @classmethod
def _from_bytes_current(cls, data: bytes): def _from_bytes_current(cls, data: bytes):
timestamp_epoch, maybe_this_node, maybe_other_nodes, signature = cls._splitter(data) signature, timestamp_epoch, maybe_this_node, maybe_other_nodes, remainder = cls._splitter(data, return_remainder=True)
this_node = NodeMetadata.from_bytes(maybe_this_node) if maybe_this_node else None this_node = NodeMetadata.from_bytes(maybe_this_node) if maybe_this_node else None
other_nodes = NodeMetadata.batch_from_bytes(maybe_other_nodes) if maybe_other_nodes else None other_nodes = NodeMetadata._batch_from_bytes(maybe_other_nodes) if maybe_other_nodes else None
return cls(signature=signature, obj = cls(signature=signature,
timestamp_epoch=timestamp_epoch, timestamp_epoch=timestamp_epoch,
this_node=this_node, this_node=this_node,
other_nodes=other_nodes) other_nodes=other_nodes)
return obj, remainder

View File

@ -93,12 +93,24 @@ class Versioned(ABC):
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def from_bytes(cls, data: bytes): def take(cls, data: bytes):
""""Public deserialization API""" """
Deserializes the object from the given bytestring
and returns the object and the remainder of the bytestring.
"""
brand, version, payload = cls._parse_header(data) brand, version, payload = cls._parse_header(data)
version = cls._resolve_version(version=version) version = cls._resolve_version(version=version)
handlers = cls._deserializers() handlers = cls._deserializers()
return handlers[version](payload) obj, remainder = handlers[version](payload)
return obj, remainder
@classmethod
def from_bytes(cls, data: bytes):
""""Public deserialization API"""
obj, remainder = cls.take(data)
if remainder:
raise ValueError(f"{len(remainder)} bytes remaining after deserializing {cls}")
return obj
@classmethod @classmethod
def _resolve_version(cls, version: Tuple[int, int]) -> Tuple[int, int]: def _resolve_version(cls, version: Tuple[int, int]) -> Tuple[int, int]:

View File

@ -131,7 +131,7 @@ def test_treasure_map_validation(enacted_federated_policy,
UnenncryptedTreasureMapsOnly().load({'tmap': bad_map_b64}) UnenncryptedTreasureMapsOnly().load({'tmap': bad_map_b64})
assert "Could not convert input for tmap to a TreasureMap" in str(e) assert "Could not convert input for tmap to a TreasureMap" in str(e)
assert "Invalid treasure map contents." in str(e) assert "Can't split a message with more bytes than the original splittable." in str(e)
# a valid treasuremap # a valid treasuremap
decrypted_treasure_map = federated_bob._decrypt_treasure_map(enacted_federated_policy.treasure_map, decrypted_treasure_map = federated_bob._decrypt_treasure_map(enacted_federated_policy.treasure_map,

View File

@ -35,7 +35,7 @@ def _check_valid_version_tuple(version: Any, cls: Type):
class A(Versioned): class A(Versioned):
def __init__(self, x): def __init__(self, x: int):
self.x = x self.x = x
@classmethod @classmethod
@ -47,7 +47,7 @@ class A(Versioned):
return 2, 1 return 2, 1
def _payload(self) -> bytes: def _payload(self) -> bytes:
return bytes(self.x) return self.x.to_bytes(1, 'big')
@classmethod @classmethod
def _old_version_handlers(cls): def _old_version_handlers(cls):
@ -57,11 +57,16 @@ class A(Versioned):
@classmethod @classmethod
def _from_bytes_v2_0(cls, data): def _from_bytes_v2_0(cls, data):
return cls(int(data, 16)) # then we switched to the hexadecimal # v2.0 saved a 4 byte integer in hex format
int_hex, remainder = data[:2], data[2:]
int_bytes = bytes.fromhex(int_hex.decode())
return cls(int.from_bytes(int_bytes, 'big')), remainder
@classmethod @classmethod
def _from_bytes_current(cls, data): def _from_bytes_current(cls, data):
return cls(str(int(data, 16))) # but now we use a string representation for some reason # v2.1 saves a 4 byte integer as 4 bytes
int_bytes, remainder = data[:1], data[1:]
return cls(int.from_bytes(int_bytes, 'big')), remainder
def test_unique_branding(): def test_unique_branding():
@ -206,13 +211,13 @@ def test_current_minor_version_handler_routing(mocker):
current_spy = mocker.spy(A, "_from_bytes_current") current_spy = mocker.spy(A, "_from_bytes_current")
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0") v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
v2_1_data = b'ABCD\x00\x02\x00\x0112' v2_1_data = b'ABCD\x00\x02\x00\x01\x12'
a = A.from_bytes(v2_1_data) a = A.from_bytes(v2_1_data)
assert a.x == '18' assert a.x == 18
# Current version was correctly routed to the v2.1 handler. # Current version was correctly routed to the v2.1 handler.
assert current_spy.call_count == 1 assert current_spy.call_count == 1
current_spy.assert_called_with(b'12') current_spy.assert_called_with(b'\x12')
assert not v2_0_spy.call_count assert not v2_0_spy.call_count
@ -220,12 +225,12 @@ def test_future_minor_version_handler_routing(mocker):
current_spy = mocker.spy(A, "_from_bytes_current") current_spy = mocker.spy(A, "_from_bytes_current")
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0") v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
v2_2_data = b'ABCD\x00\x02\x02\x0112' v2_2_data = b'ABCD\x00\x02\x02\x01\x12'
a = A.from_bytes(v2_2_data) a = A.from_bytes(v2_2_data)
assert a.x == '18' assert a.x == 18
# Future minor version was correctly routed to # Future minor version was correctly routed to
# the current minor version handler. # the current minor version handler.
assert current_spy.call_count == 1 assert current_spy.call_count == 1
current_spy.assert_called_with(b'12') current_spy.assert_called_with(b'\x12')
assert not v2_0_spy.call_count assert not v2_0_spy.call_count