mirror of https://github.com/nucypher/nucypher.git
Normalize the usage of VariableLengthBytestrings
Allow for every Versioned type to be able to deserialize itself from the bytestring.pull/2809/head
parent
1114891680
commit
944d3373e7
239
nucypher/core.py
239
nucypher/core.py
|
@ -54,7 +54,7 @@ signature_splitter = BytestringSplitter((Signature, Signature.serialized_size())
|
|||
capsule_splitter = BytestringSplitter((Capsule, Capsule.serialized_size()))
|
||||
cfrag_splitter = BytestringSplitter((CapsuleFrag, CapsuleFrag.serialized_size()))
|
||||
kfrag_splitter = BytestringSplitter((KeyFrag, KeyFrag.serialized_size()))
|
||||
checksum_address_splitter = BytestringSplitter((bytes, ETH_ADDRESS_BYTE_LENGTH)) # TODO: is there a pre-defined constant?
|
||||
checksum_address_splitter = BytestringSplitter((to_checksum_address, ETH_ADDRESS_BYTE_LENGTH)) # TODO: is there a pre-defined constant?
|
||||
|
||||
|
||||
class MessageKit(Versioned):
|
||||
|
@ -106,8 +106,8 @@ class MessageKit(Versioned):
|
|||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring)
|
||||
capsule, ciphertext = splitter(data)
|
||||
return cls(capsule, ciphertext)
|
||||
capsule, ciphertext, remainder = splitter(data, return_remainder=True)
|
||||
return cls(capsule, ciphertext), remainder
|
||||
|
||||
|
||||
class HRAC:
|
||||
|
@ -199,8 +199,8 @@ class AuthorizedKeyFrag(Versioned):
|
|||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
splitter = BytestringSplitter(signature_splitter, kfrag_splitter)
|
||||
signature, kfrag = splitter(data)
|
||||
return cls(signature, kfrag)
|
||||
signature, kfrag, remainder = splitter(data, return_remainder=True)
|
||||
return cls(signature, kfrag), remainder
|
||||
|
||||
def verify(self,
|
||||
hrac: HRAC,
|
||||
|
@ -223,6 +223,8 @@ class AuthorizedKeyFrag(Versioned):
|
|||
|
||||
class EncryptedKeyFrag:
|
||||
|
||||
_splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring)
|
||||
|
||||
@classmethod
|
||||
def author(cls, recipient_key: PublicKey, authorized_kfrag: AuthorizedKeyFrag):
|
||||
# TODO: using Umbral for encryption to avoid introducing more crypto primitives.
|
||||
|
@ -242,11 +244,13 @@ class EncryptedKeyFrag:
|
|||
def __bytes__(self):
|
||||
return bytes(self.capsule) + bytes(VariableLengthBytestring(self.ciphertext))
|
||||
|
||||
# Ideally we would define a splitter here that would deserialize into an EKF,
|
||||
# but due to BSS limitations it cannot be nested (since it doesn't have a definite size).
|
||||
# So we have to define this helper method and use that instead.
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring)
|
||||
capsule, ciphertext = splitter(data)
|
||||
return cls(capsule, ciphertext)
|
||||
def take(cls, data):
|
||||
capsule, ciphertext, remainder = cls._splitter(data, return_remainder=True)
|
||||
return cls(capsule, ciphertext), remainder
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.capsule == other.capsule and self.ciphertext == other.ciphertext
|
||||
|
@ -331,11 +335,16 @@ class TreasureMap(Versioned):
|
|||
|
||||
def _payload(self) -> bytes:
|
||||
"""Returns the unversioned bytes serialized representation of this instance."""
|
||||
assigned_kfrags = b''.join(
|
||||
to_canonical_address(ursula_address) + bytes(encrypted_kfrag)
|
||||
for ursula_address, encrypted_kfrag in self.destinations.items()
|
||||
)
|
||||
|
||||
return (self.threshold.to_bytes(1, "big") +
|
||||
bytes(self.hrac) +
|
||||
bytes(self.policy_encrypting_key) +
|
||||
bytes(self.publisher_verifying_key) +
|
||||
self._nodes_as_bytes())
|
||||
bytes(VariableLengthBytestring(assigned_kfrags)))
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
|
@ -345,20 +354,18 @@ class TreasureMap(Versioned):
|
|||
hrac_splitter,
|
||||
key_splitter,
|
||||
key_splitter,
|
||||
VariableLengthBytestring,
|
||||
)
|
||||
|
||||
ursula_and_kfrag_payload_splitter = BytestringSplitter(
|
||||
(to_checksum_address, ETH_ADDRESS_BYTE_LENGTH),
|
||||
(EncryptedKeyFrag, VariableLengthBytestring),
|
||||
)
|
||||
threshold, hrac, policy_encrypting_key, publisher_verifying_key, assigned_kfrags_bytes, remainder = main_splitter(data, return_remainder=True)
|
||||
|
||||
try:
|
||||
threshold, hrac, policy_encrypting_key, publisher_verifying_key, remainder = main_splitter(data, return_remainder=True)
|
||||
ursula_and_kfrags = ursula_and_kfrag_payload_splitter.repeat(remainder)
|
||||
except BytestringSplittingError as e:
|
||||
raise ValueError('Invalid treasure map contents.') from e
|
||||
destinations = {u: k for u, k in ursula_and_kfrags}
|
||||
return cls(threshold, hrac, policy_encrypting_key, publisher_verifying_key, destinations)
|
||||
destinations = {}
|
||||
while assigned_kfrags_bytes:
|
||||
ursula_address, assigned_kfrags_bytes = checksum_address_splitter(assigned_kfrags_bytes, return_remainder=True)
|
||||
ekf, assigned_kfrags_bytes = EncryptedKeyFrag.take(assigned_kfrags_bytes)
|
||||
destinations[ursula_address] = ekf
|
||||
|
||||
return cls(threshold, hrac, policy_encrypting_key, publisher_verifying_key, destinations), remainder
|
||||
|
||||
def encrypt(self,
|
||||
signer: Signer,
|
||||
|
@ -368,14 +375,6 @@ class TreasureMap(Versioned):
|
|||
recipient_key=recipient_key,
|
||||
treasure_map=self)
|
||||
|
||||
def _nodes_as_bytes(self) -> bytes:
|
||||
nodes_as_bytes = b""
|
||||
for ursula_address, encrypted_kfrag in self.destinations.items():
|
||||
node_id = to_canonical_address(ursula_address)
|
||||
kfrag = bytes(VariableLengthBytestring(bytes(encrypted_kfrag)))
|
||||
nodes_as_bytes += (node_id + kfrag)
|
||||
return nodes_as_bytes
|
||||
|
||||
|
||||
class AuthorizedTreasureMap(Versioned):
|
||||
|
||||
|
@ -408,13 +407,13 @@ class AuthorizedTreasureMap(Versioned):
|
|||
def _payload(self) -> bytes:
|
||||
"""Returns the unversioned bytes serialized representation of this instance."""
|
||||
return (bytes(self.signature) +
|
||||
VariableLengthBytestring(bytes(self.treasure_map)))
|
||||
bytes(self.treasure_map))
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
splitter = BytestringSplitter(signature_splitter, (TreasureMap, VariableLengthBytestring))
|
||||
signature, treasure_map = splitter(data)
|
||||
return cls(signature, treasure_map)
|
||||
signature, remainder = signature_splitter(data, return_remainder=True)
|
||||
treasure_map, remainder = TreasureMap.take(remainder)
|
||||
return cls(signature, treasure_map), remainder
|
||||
|
||||
def verify(self, recipient_key: PublicKey, publisher_verifying_key: PublicKey) -> TreasureMap:
|
||||
payload = bytes(recipient_key) + bytes(self.treasure_map)
|
||||
|
@ -472,8 +471,8 @@ class EncryptedTreasureMap(Versioned):
|
|||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring)
|
||||
capsule, ciphertext = splitter(data)
|
||||
return cls(capsule, ciphertext)
|
||||
capsule, ciphertext, remainder = splitter(data, return_remainder=True)
|
||||
return cls(capsule, ciphertext), remainder
|
||||
|
||||
def __eq__(self, other):
|
||||
return bytes(self) == bytes(other)
|
||||
|
@ -518,8 +517,8 @@ class ReencryptionRequest(Versioned):
|
|||
return (bytes(self.hrac) +
|
||||
bytes(self.publisher_verifying_key) +
|
||||
bytes(self.bob_verifying_key) +
|
||||
VariableLengthBytestring(bytes(self.encrypted_kfrag)) +
|
||||
b''.join(bytes(capsule) for capsule in self.capsules)
|
||||
bytes(self.encrypted_kfrag) +
|
||||
bytes(VariableLengthBytestring(b''.join(bytes(capsule) for capsule in self.capsules)))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -538,12 +537,13 @@ class ReencryptionRequest(Versioned):
|
|||
def _from_bytes_current(cls, data):
|
||||
splitter = (hrac_splitter +
|
||||
key_splitter +
|
||||
key_splitter +
|
||||
BytestringSplitter((EncryptedKeyFrag, VariableLengthBytestring)))
|
||||
key_splitter)
|
||||
|
||||
hrac, publisher_vk, bob_vk, ekfrag, remainder = splitter(data, return_remainder=True)
|
||||
capsules = capsule_splitter.repeat(remainder)
|
||||
return cls(hrac, publisher_vk, bob_vk, ekfrag, capsules)
|
||||
hrac, publisher_vk, bob_vk, remainder = splitter(data, return_remainder=True)
|
||||
ekfrag, remainder = EncryptedKeyFrag.take(remainder)
|
||||
capsule_bytes, remainder = BytestringSplitter(VariableLengthBytestring)(remainder, return_remainder=True)
|
||||
capsules = capsule_splitter.repeat(capsule_bytes)
|
||||
return cls(hrac, publisher_vk, bob_vk, ekfrag, capsules), remainder
|
||||
|
||||
|
||||
class ReencryptionResponse(Versioned):
|
||||
|
@ -572,7 +572,7 @@ class ReencryptionResponse(Versioned):
|
|||
|
||||
def _payload(self) -> bytes:
|
||||
"""Returns the unversioned bytes serialized representation of this instance."""
|
||||
return bytes(self.signature) + b''.join(bytes(cfrag) for cfrag in self.cfrags)
|
||||
return bytes(self.signature) + bytes(VariableLengthBytestring(b''.join(bytes(cfrag) for cfrag in self.cfrags)))
|
||||
|
||||
@classmethod
|
||||
def _brand(cls) -> bytes:
|
||||
|
@ -588,7 +588,8 @@ class ReencryptionResponse(Versioned):
|
|||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
signature, cfrags_bytes = signature_splitter(data, return_remainder=True)
|
||||
splitter = BytestringSplitter(signature_splitter, VariableLengthBytestring)
|
||||
signature, cfrags_bytes, remainder = splitter(data, return_remainder=True)
|
||||
|
||||
# We would never send a request with no capsules, so there should be cfrags.
|
||||
# The splitter would fail anyway, this just makes the error message more clear.
|
||||
|
@ -596,7 +597,7 @@ class ReencryptionResponse(Versioned):
|
|||
raise ValueError(f"{cls.__name__} contains no cfrags")
|
||||
|
||||
cfrags = cfrag_splitter.repeat(cfrags_bytes)
|
||||
return cls(cfrags, signature)
|
||||
return cls(cfrags, signature), remainder
|
||||
|
||||
def verify(self,
|
||||
capsules: Sequence[Capsule],
|
||||
|
@ -646,7 +647,7 @@ class RetrievalKit(Versioned):
|
|||
|
||||
def _payload(self) -> bytes:
|
||||
return (bytes(self.capsule) +
|
||||
b''.join(to_canonical_address(address) for address in self.queried_addresses))
|
||||
bytes(VariableLengthBytestring(b''.join(to_canonical_address(address) for address in self.queried_addresses))))
|
||||
|
||||
@classmethod
|
||||
def _brand(cls) -> bytes:
|
||||
|
@ -662,12 +663,13 @@ class RetrievalKit(Versioned):
|
|||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
capsule, remainder = capsule_splitter(data, return_remainder=True)
|
||||
if remainder:
|
||||
addresses_as_bytes = checksum_address_splitter.repeat(remainder)
|
||||
splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring)
|
||||
capsule, addresses_bytes, remainder = splitter(data, return_remainder=True)
|
||||
if addresses_bytes:
|
||||
addresses = checksum_address_splitter.repeat(addresses_bytes)
|
||||
else:
|
||||
addresses_as_bytes = ()
|
||||
return cls(capsule, set(to_checksum_address(address) for address in addresses_as_bytes))
|
||||
addresses = ()
|
||||
return cls(capsule, addresses), remainder
|
||||
|
||||
|
||||
class RevocationOrder(Versioned):
|
||||
|
@ -701,7 +703,7 @@ class RevocationOrder(Versioned):
|
|||
|
||||
@staticmethod
|
||||
def _signed_payload(ursula_address, encrypted_kfrag):
|
||||
return to_canonical_address(ursula_address) + bytes(VariableLengthBytestring(bytes(encrypted_kfrag)))
|
||||
return to_canonical_address(ursula_address) + bytes(encrypted_kfrag)
|
||||
|
||||
def verify_signature(self, alice_verifying_key: PublicKey) -> bool:
|
||||
"""
|
||||
|
@ -726,21 +728,20 @@ class RevocationOrder(Versioned):
|
|||
return {}
|
||||
|
||||
def _payload(self) -> bytes:
|
||||
return self._signed_payload(self.ursula_address, self.encrypted_kfrag) + bytes(self.signature)
|
||||
return bytes(self.signature) + self._signed_payload(self.ursula_address, self.encrypted_kfrag)
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
splitter = BytestringSplitter(
|
||||
signature_splitter,
|
||||
checksum_address_splitter, # ursula canonical address
|
||||
VariableLengthBytestring, # EncryptedKeyFrag
|
||||
signature_splitter
|
||||
)
|
||||
ursula_canonical_address, ekfrag_bytes, signature = splitter(data)
|
||||
ekfrag = EncryptedKeyFrag.from_bytes(ekfrag_bytes)
|
||||
ursula_address = to_checksum_address(ursula_canonical_address)
|
||||
return cls(ursula_address=ursula_address,
|
||||
encrypted_kfrag=ekfrag,
|
||||
signature=signature)
|
||||
signature, ursula_address, remainder = splitter(data, return_remainder=True)
|
||||
ekfrag, remainder = EncryptedKeyFrag.take(remainder)
|
||||
obj = cls(ursula_address=ursula_address,
|
||||
encrypted_kfrag=ekfrag,
|
||||
signature=signature)
|
||||
return obj, remainder
|
||||
|
||||
|
||||
class NodeMetadataPayload(NamedTuple):
|
||||
|
@ -755,20 +756,19 @@ class NodeMetadataPayload(NamedTuple):
|
|||
host: str
|
||||
port: int
|
||||
|
||||
_splitter = BytestringKwargifier(
|
||||
dict,
|
||||
public_address=ETH_ADDRESS_BYTE_LENGTH,
|
||||
domain_bytes=VariableLengthBytestring,
|
||||
timestamp_epoch=(int, 4, {'byteorder': 'big'}),
|
||||
_splitter = BytestringSplitter(
|
||||
(bytes, ETH_ADDRESS_BYTE_LENGTH), # public_address
|
||||
VariableLengthBytestring, # domain_bytes
|
||||
(int, 4, {'byteorder': 'big'}), # timestamp_epoch
|
||||
|
||||
# FIXME: Fixed length doesn't work with federated. It was LENGTH_ECDSA_SIGNATURE_WITH_RECOVERY,
|
||||
decentralized_identity_evidence=VariableLengthBytestring,
|
||||
VariableLengthBytestring, # decentralized_identity_evidence
|
||||
|
||||
verifying_key=key_splitter,
|
||||
encrypting_key=key_splitter,
|
||||
certificate_bytes=VariableLengthBytestring,
|
||||
host_bytes=VariableLengthBytestring,
|
||||
port=(int, 2, {'byteorder': 'big'}),
|
||||
key_splitter, # verifying_key
|
||||
key_splitter, # encrypting_key
|
||||
VariableLengthBytestring, # certificate_bytes
|
||||
VariableLengthBytestring, # host_bytes
|
||||
(int, 2, {'byteorder': 'big'}), # port
|
||||
)
|
||||
|
||||
def __bytes__(self):
|
||||
|
@ -784,19 +784,40 @@ class NodeMetadataPayload(NamedTuple):
|
|||
))
|
||||
return as_bytes
|
||||
|
||||
@classmethod
|
||||
def take(cls, data):
|
||||
*fields, remainder = cls._splitter(data, return_remainder=True)
|
||||
|
||||
(public_address,
|
||||
domain,
|
||||
timestamp_epoch,
|
||||
decentralized_identity_evidence,
|
||||
verifying_key,
|
||||
encrypting_key,
|
||||
certificate_bytes,
|
||||
host,
|
||||
port,
|
||||
) = fields
|
||||
|
||||
obj = cls(public_address=public_address,
|
||||
domain=domain.decode('utf-8'),
|
||||
timestamp_epoch=timestamp_epoch,
|
||||
decentralized_identity_evidence=decentralized_identity_evidence,
|
||||
verifying_key=verifying_key,
|
||||
encrypting_key=encrypting_key,
|
||||
certificate_bytes=certificate_bytes,
|
||||
host=host.decode('utf-8'),
|
||||
port=port,
|
||||
)
|
||||
|
||||
return obj, remainder
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
result = cls._splitter(data)
|
||||
return cls(public_address=result['public_address'],
|
||||
domain=result['domain_bytes'].decode('utf-8'),
|
||||
timestamp_epoch=result['timestamp_epoch'],
|
||||
decentralized_identity_evidence=result['decentralized_identity_evidence'],
|
||||
verifying_key=result['verifying_key'],
|
||||
encrypting_key=result['encrypting_key'],
|
||||
certificate_bytes=result['certificate_bytes'],
|
||||
host=result['host_bytes'].decode('utf-8'),
|
||||
port=result['port'],
|
||||
)
|
||||
obj, remainder = cls.take(data)
|
||||
if remainder:
|
||||
raise ValueError(f"{len(remainder)} bytes remaining after deserializing {cls}")
|
||||
return obj
|
||||
|
||||
|
||||
class NodeMetadata(Versioned):
|
||||
|
@ -838,16 +859,16 @@ class NodeMetadata(Versioned):
|
|||
@classmethod
|
||||
def _from_bytes_current(cls, data: bytes):
|
||||
signature, remainder = signature_splitter(data, return_remainder=True)
|
||||
payload = NodeMetadataPayload.from_bytes(remainder)
|
||||
return cls(signature=signature, payload=payload)
|
||||
payload, remainder = NodeMetadataPayload.take(remainder)
|
||||
return cls(signature=signature, payload=payload), remainder
|
||||
|
||||
@classmethod
|
||||
def batch_from_bytes(cls, data: bytes):
|
||||
|
||||
node_splitter = BytestringSplitter(VariableLengthBytestring)
|
||||
nodes_vbytes = node_splitter.repeat(data)
|
||||
|
||||
return [cls.from_bytes(node_data) for node_data in nodes_vbytes]
|
||||
def _batch_from_bytes(cls, data: bytes):
|
||||
nodes = []
|
||||
while data:
|
||||
node, data = cls.take(data)
|
||||
nodes.append(node)
|
||||
return nodes
|
||||
|
||||
|
||||
class MetadataRequest(Versioned):
|
||||
|
@ -876,29 +897,34 @@ class MetadataRequest(Versioned):
|
|||
|
||||
def _payload(self):
|
||||
if self.announce_nodes:
|
||||
nodes_bytes = bytes().join(bytes(VariableLengthBytestring(bytes(n))) for n in self.announce_nodes)
|
||||
nodes_bytes = b''.join(bytes(n) for n in self.announce_nodes)
|
||||
else:
|
||||
nodes_bytes = b''
|
||||
return bytes.fromhex(self.fleet_state_checksum) + nodes_bytes
|
||||
return bytes.fromhex(self.fleet_state_checksum) + bytes(VariableLengthBytestring(nodes_bytes))
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
fleet_state_checksum_bytes, nodes_bytes = cls._fleet_state_checksum_splitter(data, return_remainder=True)
|
||||
splitter = BytestringSplitter(
|
||||
(bytes, 32), # fleet state checksum
|
||||
VariableLengthBytestring,
|
||||
)
|
||||
fleet_state_checksum_bytes, nodes_bytes, remainder = splitter(data, return_remainder=True)
|
||||
if nodes_bytes:
|
||||
nodes = NodeMetadata.batch_from_bytes(nodes_bytes)
|
||||
nodes = NodeMetadata._batch_from_bytes(nodes_bytes)
|
||||
else:
|
||||
nodes = None
|
||||
return cls(fleet_state_checksum=fleet_state_checksum_bytes.hex(),
|
||||
announce_nodes=nodes)
|
||||
obj = cls(fleet_state_checksum=fleet_state_checksum_bytes.hex(),
|
||||
announce_nodes=nodes)
|
||||
return obj, remainder
|
||||
|
||||
|
||||
class MetadataResponse(Versioned):
|
||||
|
||||
_splitter = BytestringSplitter(
|
||||
signature_splitter,
|
||||
(int, 4, {'byteorder': 'big'}),
|
||||
VariableLengthBytestring,
|
||||
VariableLengthBytestring,
|
||||
signature_splitter,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -918,7 +944,7 @@ class MetadataResponse(Versioned):
|
|||
@staticmethod
|
||||
def _signed_payload(timestamp_epoch, this_node, other_nodes):
|
||||
timestamp = timestamp_epoch.to_bytes(4, byteorder="big")
|
||||
nodes_payload = b''.join(bytes(VariableLengthBytestring(bytes(node))) for node in other_nodes) if other_nodes else b''
|
||||
nodes_payload = b''.join(bytes(node) for node in other_nodes) if other_nodes else b''
|
||||
return (
|
||||
timestamp +
|
||||
bytes(VariableLengthBytestring(bytes(this_node) if this_node else b'')) +
|
||||
|
@ -955,14 +981,15 @@ class MetadataResponse(Versioned):
|
|||
|
||||
def _payload(self):
|
||||
payload = self._signed_payload(self.timestamp_epoch, self.this_node, self.other_nodes)
|
||||
return payload + bytes(self.signature)
|
||||
return bytes(self.signature) + payload
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data: bytes):
|
||||
timestamp_epoch, maybe_this_node, maybe_other_nodes, signature = cls._splitter(data)
|
||||
signature, timestamp_epoch, maybe_this_node, maybe_other_nodes, remainder = cls._splitter(data, return_remainder=True)
|
||||
this_node = NodeMetadata.from_bytes(maybe_this_node) if maybe_this_node else None
|
||||
other_nodes = NodeMetadata.batch_from_bytes(maybe_other_nodes) if maybe_other_nodes else None
|
||||
return cls(signature=signature,
|
||||
timestamp_epoch=timestamp_epoch,
|
||||
this_node=this_node,
|
||||
other_nodes=other_nodes)
|
||||
other_nodes = NodeMetadata._batch_from_bytes(maybe_other_nodes) if maybe_other_nodes else None
|
||||
obj = cls(signature=signature,
|
||||
timestamp_epoch=timestamp_epoch,
|
||||
this_node=this_node,
|
||||
other_nodes=other_nodes)
|
||||
return obj, remainder
|
||||
|
|
|
@ -93,12 +93,24 @@ class Versioned(ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
""""Public deserialization API"""
|
||||
def take(cls, data: bytes):
|
||||
"""
|
||||
Deserializes the object from the given bytestring
|
||||
and returns the object and the remainder of the bytestring.
|
||||
"""
|
||||
brand, version, payload = cls._parse_header(data)
|
||||
version = cls._resolve_version(version=version)
|
||||
handlers = cls._deserializers()
|
||||
return handlers[version](payload)
|
||||
obj, remainder = handlers[version](payload)
|
||||
return obj, remainder
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
""""Public deserialization API"""
|
||||
obj, remainder = cls.take(data)
|
||||
if remainder:
|
||||
raise ValueError(f"{len(remainder)} bytes remaining after deserializing {cls}")
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def _resolve_version(cls, version: Tuple[int, int]) -> Tuple[int, int]:
|
||||
|
|
|
@ -131,7 +131,7 @@ def test_treasure_map_validation(enacted_federated_policy,
|
|||
UnenncryptedTreasureMapsOnly().load({'tmap': bad_map_b64})
|
||||
|
||||
assert "Could not convert input for tmap to a TreasureMap" in str(e)
|
||||
assert "Invalid treasure map contents." in str(e)
|
||||
assert "Can't split a message with more bytes than the original splittable." in str(e)
|
||||
|
||||
# a valid treasuremap
|
||||
decrypted_treasure_map = federated_bob._decrypt_treasure_map(enacted_federated_policy.treasure_map,
|
||||
|
|
|
@ -35,7 +35,7 @@ def _check_valid_version_tuple(version: Any, cls: Type):
|
|||
|
||||
class A(Versioned):
|
||||
|
||||
def __init__(self, x):
|
||||
def __init__(self, x: int):
|
||||
self.x = x
|
||||
|
||||
@classmethod
|
||||
|
@ -47,7 +47,7 @@ class A(Versioned):
|
|||
return 2, 1
|
||||
|
||||
def _payload(self) -> bytes:
|
||||
return bytes(self.x)
|
||||
return self.x.to_bytes(1, 'big')
|
||||
|
||||
@classmethod
|
||||
def _old_version_handlers(cls):
|
||||
|
@ -57,11 +57,16 @@ class A(Versioned):
|
|||
|
||||
@classmethod
|
||||
def _from_bytes_v2_0(cls, data):
|
||||
return cls(int(data, 16)) # then we switched to the hexadecimal
|
||||
# v2.0 saved a 4 byte integer in hex format
|
||||
int_hex, remainder = data[:2], data[2:]
|
||||
int_bytes = bytes.fromhex(int_hex.decode())
|
||||
return cls(int.from_bytes(int_bytes, 'big')), remainder
|
||||
|
||||
@classmethod
|
||||
def _from_bytes_current(cls, data):
|
||||
return cls(str(int(data, 16))) # but now we use a string representation for some reason
|
||||
# v2.1 saves a 4 byte integer as 4 bytes
|
||||
int_bytes, remainder = data[:1], data[1:]
|
||||
return cls(int.from_bytes(int_bytes, 'big')), remainder
|
||||
|
||||
|
||||
def test_unique_branding():
|
||||
|
@ -206,13 +211,13 @@ def test_current_minor_version_handler_routing(mocker):
|
|||
current_spy = mocker.spy(A, "_from_bytes_current")
|
||||
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
|
||||
|
||||
v2_1_data = b'ABCD\x00\x02\x00\x0112'
|
||||
v2_1_data = b'ABCD\x00\x02\x00\x01\x12'
|
||||
a = A.from_bytes(v2_1_data)
|
||||
assert a.x == '18'
|
||||
assert a.x == 18
|
||||
|
||||
# Current version was correctly routed to the v2.1 handler.
|
||||
assert current_spy.call_count == 1
|
||||
current_spy.assert_called_with(b'12')
|
||||
current_spy.assert_called_with(b'\x12')
|
||||
assert not v2_0_spy.call_count
|
||||
|
||||
|
||||
|
@ -220,12 +225,12 @@ def test_future_minor_version_handler_routing(mocker):
|
|||
current_spy = mocker.spy(A, "_from_bytes_current")
|
||||
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
|
||||
|
||||
v2_2_data = b'ABCD\x00\x02\x02\x0112'
|
||||
v2_2_data = b'ABCD\x00\x02\x02\x01\x12'
|
||||
a = A.from_bytes(v2_2_data)
|
||||
assert a.x == '18'
|
||||
assert a.x == 18
|
||||
|
||||
# Future minor version was correctly routed to
|
||||
# the current minor version handler.
|
||||
assert current_spy.call_count == 1
|
||||
current_spy.assert_called_with(b'12')
|
||||
current_spy.assert_called_with(b'\x12')
|
||||
assert not v2_0_spy.call_count
|
||||
|
|
Loading…
Reference in New Issue