Normalize the usage of VariableLengthBytestrings

Allow every Versioned type to deserialize itself from a bytestring.
pull/2809/head
Bogdan Opanchuk 2021-10-30 12:10:59 -07:00
parent 1114891680
commit 944d3373e7
4 changed files with 164 additions and 120 deletions

View File

@ -54,7 +54,7 @@ signature_splitter = BytestringSplitter((Signature, Signature.serialized_size())
capsule_splitter = BytestringSplitter((Capsule, Capsule.serialized_size()))
cfrag_splitter = BytestringSplitter((CapsuleFrag, CapsuleFrag.serialized_size()))
kfrag_splitter = BytestringSplitter((KeyFrag, KeyFrag.serialized_size()))
checksum_address_splitter = BytestringSplitter((bytes, ETH_ADDRESS_BYTE_LENGTH)) # TODO: is there a pre-defined constant?
checksum_address_splitter = BytestringSplitter((to_checksum_address, ETH_ADDRESS_BYTE_LENGTH)) # TODO: is there a pre-defined constant?
class MessageKit(Versioned):
@ -106,8 +106,8 @@ class MessageKit(Versioned):
@classmethod
def _from_bytes_current(cls, data):
splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring)
capsule, ciphertext = splitter(data)
return cls(capsule, ciphertext)
capsule, ciphertext, remainder = splitter(data, return_remainder=True)
return cls(capsule, ciphertext), remainder
class HRAC:
@ -199,8 +199,8 @@ class AuthorizedKeyFrag(Versioned):
@classmethod
def _from_bytes_current(cls, data):
splitter = BytestringSplitter(signature_splitter, kfrag_splitter)
signature, kfrag = splitter(data)
return cls(signature, kfrag)
signature, kfrag, remainder = splitter(data, return_remainder=True)
return cls(signature, kfrag), remainder
def verify(self,
hrac: HRAC,
@ -223,6 +223,8 @@ class AuthorizedKeyFrag(Versioned):
class EncryptedKeyFrag:
_splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring)
@classmethod
def author(cls, recipient_key: PublicKey, authorized_kfrag: AuthorizedKeyFrag):
# TODO: using Umbral for encryption to avoid introducing more crypto primitives.
@ -242,11 +244,13 @@ class EncryptedKeyFrag:
def __bytes__(self):
return bytes(self.capsule) + bytes(VariableLengthBytestring(self.ciphertext))
# Ideally we would define a splitter here that would deserialize into an EKF,
# but due to BSS limitations it cannot be nested (since it doesn't have a definite size).
# So we have to define this helper method and use that instead.
@classmethod
def from_bytes(cls, data):
splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring)
capsule, ciphertext = splitter(data)
return cls(capsule, ciphertext)
def take(cls, data):
capsule, ciphertext, remainder = cls._splitter(data, return_remainder=True)
return cls(capsule, ciphertext), remainder
def __eq__(self, other):
return self.capsule == other.capsule and self.ciphertext == other.ciphertext
@ -331,11 +335,16 @@ class TreasureMap(Versioned):
def _payload(self) -> bytes:
"""Returns the unversioned bytes serialized representation of this instance."""
assigned_kfrags = b''.join(
to_canonical_address(ursula_address) + bytes(encrypted_kfrag)
for ursula_address, encrypted_kfrag in self.destinations.items()
)
return (self.threshold.to_bytes(1, "big") +
bytes(self.hrac) +
bytes(self.policy_encrypting_key) +
bytes(self.publisher_verifying_key) +
self._nodes_as_bytes())
bytes(VariableLengthBytestring(assigned_kfrags)))
@classmethod
def _from_bytes_current(cls, data):
@ -345,20 +354,18 @@ class TreasureMap(Versioned):
hrac_splitter,
key_splitter,
key_splitter,
VariableLengthBytestring,
)
ursula_and_kfrag_payload_splitter = BytestringSplitter(
(to_checksum_address, ETH_ADDRESS_BYTE_LENGTH),
(EncryptedKeyFrag, VariableLengthBytestring),
)
threshold, hrac, policy_encrypting_key, publisher_verifying_key, assigned_kfrags_bytes, remainder = main_splitter(data, return_remainder=True)
try:
threshold, hrac, policy_encrypting_key, publisher_verifying_key, remainder = main_splitter(data, return_remainder=True)
ursula_and_kfrags = ursula_and_kfrag_payload_splitter.repeat(remainder)
except BytestringSplittingError as e:
raise ValueError('Invalid treasure map contents.') from e
destinations = {u: k for u, k in ursula_and_kfrags}
return cls(threshold, hrac, policy_encrypting_key, publisher_verifying_key, destinations)
destinations = {}
while assigned_kfrags_bytes:
ursula_address, assigned_kfrags_bytes = checksum_address_splitter(assigned_kfrags_bytes, return_remainder=True)
ekf, assigned_kfrags_bytes = EncryptedKeyFrag.take(assigned_kfrags_bytes)
destinations[ursula_address] = ekf
return cls(threshold, hrac, policy_encrypting_key, publisher_verifying_key, destinations), remainder
def encrypt(self,
signer: Signer,
@ -368,14 +375,6 @@ class TreasureMap(Versioned):
recipient_key=recipient_key,
treasure_map=self)
def _nodes_as_bytes(self) -> bytes:
nodes_as_bytes = b""
for ursula_address, encrypted_kfrag in self.destinations.items():
node_id = to_canonical_address(ursula_address)
kfrag = bytes(VariableLengthBytestring(bytes(encrypted_kfrag)))
nodes_as_bytes += (node_id + kfrag)
return nodes_as_bytes
class AuthorizedTreasureMap(Versioned):
@ -408,13 +407,13 @@ class AuthorizedTreasureMap(Versioned):
def _payload(self) -> bytes:
"""Returns the unversioned bytes serialized representation of this instance."""
return (bytes(self.signature) +
VariableLengthBytestring(bytes(self.treasure_map)))
bytes(self.treasure_map))
@classmethod
def _from_bytes_current(cls, data):
splitter = BytestringSplitter(signature_splitter, (TreasureMap, VariableLengthBytestring))
signature, treasure_map = splitter(data)
return cls(signature, treasure_map)
signature, remainder = signature_splitter(data, return_remainder=True)
treasure_map, remainder = TreasureMap.take(remainder)
return cls(signature, treasure_map), remainder
def verify(self, recipient_key: PublicKey, publisher_verifying_key: PublicKey) -> TreasureMap:
payload = bytes(recipient_key) + bytes(self.treasure_map)
@ -472,8 +471,8 @@ class EncryptedTreasureMap(Versioned):
@classmethod
def _from_bytes_current(cls, data):
splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring)
capsule, ciphertext = splitter(data)
return cls(capsule, ciphertext)
capsule, ciphertext, remainder = splitter(data, return_remainder=True)
return cls(capsule, ciphertext), remainder
def __eq__(self, other):
return bytes(self) == bytes(other)
@ -518,8 +517,8 @@ class ReencryptionRequest(Versioned):
return (bytes(self.hrac) +
bytes(self.publisher_verifying_key) +
bytes(self.bob_verifying_key) +
VariableLengthBytestring(bytes(self.encrypted_kfrag)) +
b''.join(bytes(capsule) for capsule in self.capsules)
bytes(self.encrypted_kfrag) +
bytes(VariableLengthBytestring(b''.join(bytes(capsule) for capsule in self.capsules)))
)
@classmethod
@ -538,12 +537,13 @@ class ReencryptionRequest(Versioned):
def _from_bytes_current(cls, data):
splitter = (hrac_splitter +
key_splitter +
key_splitter +
BytestringSplitter((EncryptedKeyFrag, VariableLengthBytestring)))
key_splitter)
hrac, publisher_vk, bob_vk, ekfrag, remainder = splitter(data, return_remainder=True)
capsules = capsule_splitter.repeat(remainder)
return cls(hrac, publisher_vk, bob_vk, ekfrag, capsules)
hrac, publisher_vk, bob_vk, remainder = splitter(data, return_remainder=True)
ekfrag, remainder = EncryptedKeyFrag.take(remainder)
capsule_bytes, remainder = BytestringSplitter(VariableLengthBytestring)(remainder, return_remainder=True)
capsules = capsule_splitter.repeat(capsule_bytes)
return cls(hrac, publisher_vk, bob_vk, ekfrag, capsules), remainder
class ReencryptionResponse(Versioned):
@ -572,7 +572,7 @@ class ReencryptionResponse(Versioned):
def _payload(self) -> bytes:
"""Returns the unversioned bytes serialized representation of this instance."""
return bytes(self.signature) + b''.join(bytes(cfrag) for cfrag in self.cfrags)
return bytes(self.signature) + bytes(VariableLengthBytestring(b''.join(bytes(cfrag) for cfrag in self.cfrags)))
@classmethod
def _brand(cls) -> bytes:
@ -588,7 +588,8 @@ class ReencryptionResponse(Versioned):
@classmethod
def _from_bytes_current(cls, data):
signature, cfrags_bytes = signature_splitter(data, return_remainder=True)
splitter = BytestringSplitter(signature_splitter, VariableLengthBytestring)
signature, cfrags_bytes, remainder = splitter(data, return_remainder=True)
# We would never send a request with no capsules, so there should be cfrags.
# The splitter would fail anyway, this just makes the error message more clear.
@ -596,7 +597,7 @@ class ReencryptionResponse(Versioned):
raise ValueError(f"{cls.__name__} contains no cfrags")
cfrags = cfrag_splitter.repeat(cfrags_bytes)
return cls(cfrags, signature)
return cls(cfrags, signature), remainder
def verify(self,
capsules: Sequence[Capsule],
@ -646,7 +647,7 @@ class RetrievalKit(Versioned):
def _payload(self) -> bytes:
return (bytes(self.capsule) +
b''.join(to_canonical_address(address) for address in self.queried_addresses))
bytes(VariableLengthBytestring(b''.join(to_canonical_address(address) for address in self.queried_addresses))))
@classmethod
def _brand(cls) -> bytes:
@ -662,12 +663,13 @@ class RetrievalKit(Versioned):
@classmethod
def _from_bytes_current(cls, data):
capsule, remainder = capsule_splitter(data, return_remainder=True)
if remainder:
addresses_as_bytes = checksum_address_splitter.repeat(remainder)
splitter = BytestringSplitter(capsule_splitter, VariableLengthBytestring)
capsule, addresses_bytes, remainder = splitter(data, return_remainder=True)
if addresses_bytes:
addresses = checksum_address_splitter.repeat(addresses_bytes)
else:
addresses_as_bytes = ()
return cls(capsule, set(to_checksum_address(address) for address in addresses_as_bytes))
addresses = ()
return cls(capsule, addresses), remainder
class RevocationOrder(Versioned):
@ -701,7 +703,7 @@ class RevocationOrder(Versioned):
@staticmethod
def _signed_payload(ursula_address, encrypted_kfrag):
return to_canonical_address(ursula_address) + bytes(VariableLengthBytestring(bytes(encrypted_kfrag)))
return to_canonical_address(ursula_address) + bytes(encrypted_kfrag)
def verify_signature(self, alice_verifying_key: PublicKey) -> bool:
"""
@ -726,21 +728,20 @@ class RevocationOrder(Versioned):
return {}
def _payload(self) -> bytes:
return self._signed_payload(self.ursula_address, self.encrypted_kfrag) + bytes(self.signature)
return bytes(self.signature) + self._signed_payload(self.ursula_address, self.encrypted_kfrag)
@classmethod
def _from_bytes_current(cls, data):
splitter = BytestringSplitter(
signature_splitter,
checksum_address_splitter, # ursula canonical address
VariableLengthBytestring, # EncryptedKeyFrag
signature_splitter
)
ursula_canonical_address, ekfrag_bytes, signature = splitter(data)
ekfrag = EncryptedKeyFrag.from_bytes(ekfrag_bytes)
ursula_address = to_checksum_address(ursula_canonical_address)
return cls(ursula_address=ursula_address,
encrypted_kfrag=ekfrag,
signature=signature)
signature, ursula_address, remainder = splitter(data, return_remainder=True)
ekfrag, remainder = EncryptedKeyFrag.take(remainder)
obj = cls(ursula_address=ursula_address,
encrypted_kfrag=ekfrag,
signature=signature)
return obj, remainder
class NodeMetadataPayload(NamedTuple):
@ -755,20 +756,19 @@ class NodeMetadataPayload(NamedTuple):
host: str
port: int
_splitter = BytestringKwargifier(
dict,
public_address=ETH_ADDRESS_BYTE_LENGTH,
domain_bytes=VariableLengthBytestring,
timestamp_epoch=(int, 4, {'byteorder': 'big'}),
_splitter = BytestringSplitter(
(bytes, ETH_ADDRESS_BYTE_LENGTH), # public_address
VariableLengthBytestring, # domain_bytes
(int, 4, {'byteorder': 'big'}), # timestamp_epoch
# FIXME: Fixed length doesn't work with federated. It was LENGTH_ECDSA_SIGNATURE_WITH_RECOVERY,
decentralized_identity_evidence=VariableLengthBytestring,
VariableLengthBytestring, # decentralized_identity_evidence
verifying_key=key_splitter,
encrypting_key=key_splitter,
certificate_bytes=VariableLengthBytestring,
host_bytes=VariableLengthBytestring,
port=(int, 2, {'byteorder': 'big'}),
key_splitter, # verifying_key
key_splitter, # encrypting_key
VariableLengthBytestring, # certificate_bytes
VariableLengthBytestring, # host_bytes
(int, 2, {'byteorder': 'big'}), # port
)
def __bytes__(self):
@ -784,19 +784,40 @@ class NodeMetadataPayload(NamedTuple):
))
return as_bytes
@classmethod
def take(cls, data):
*fields, remainder = cls._splitter(data, return_remainder=True)
(public_address,
domain,
timestamp_epoch,
decentralized_identity_evidence,
verifying_key,
encrypting_key,
certificate_bytes,
host,
port,
) = fields
obj = cls(public_address=public_address,
domain=domain.decode('utf-8'),
timestamp_epoch=timestamp_epoch,
decentralized_identity_evidence=decentralized_identity_evidence,
verifying_key=verifying_key,
encrypting_key=encrypting_key,
certificate_bytes=certificate_bytes,
host=host.decode('utf-8'),
port=port,
)
return obj, remainder
@classmethod
def from_bytes(cls, data):
result = cls._splitter(data)
return cls(public_address=result['public_address'],
domain=result['domain_bytes'].decode('utf-8'),
timestamp_epoch=result['timestamp_epoch'],
decentralized_identity_evidence=result['decentralized_identity_evidence'],
verifying_key=result['verifying_key'],
encrypting_key=result['encrypting_key'],
certificate_bytes=result['certificate_bytes'],
host=result['host_bytes'].decode('utf-8'),
port=result['port'],
)
obj, remainder = cls.take(data)
if remainder:
raise ValueError(f"{len(remainder)} bytes remaining after deserializing {cls}")
return obj
class NodeMetadata(Versioned):
@ -838,16 +859,16 @@ class NodeMetadata(Versioned):
@classmethod
def _from_bytes_current(cls, data: bytes):
signature, remainder = signature_splitter(data, return_remainder=True)
payload = NodeMetadataPayload.from_bytes(remainder)
return cls(signature=signature, payload=payload)
payload, remainder = NodeMetadataPayload.take(remainder)
return cls(signature=signature, payload=payload), remainder
@classmethod
def batch_from_bytes(cls, data: bytes):
node_splitter = BytestringSplitter(VariableLengthBytestring)
nodes_vbytes = node_splitter.repeat(data)
return [cls.from_bytes(node_data) for node_data in nodes_vbytes]
def _batch_from_bytes(cls, data: bytes):
nodes = []
while data:
node, data = cls.take(data)
nodes.append(node)
return nodes
class MetadataRequest(Versioned):
@ -876,29 +897,34 @@ class MetadataRequest(Versioned):
def _payload(self):
if self.announce_nodes:
nodes_bytes = bytes().join(bytes(VariableLengthBytestring(bytes(n))) for n in self.announce_nodes)
nodes_bytes = b''.join(bytes(n) for n in self.announce_nodes)
else:
nodes_bytes = b''
return bytes.fromhex(self.fleet_state_checksum) + nodes_bytes
return bytes.fromhex(self.fleet_state_checksum) + bytes(VariableLengthBytestring(nodes_bytes))
@classmethod
def _from_bytes_current(cls, data):
fleet_state_checksum_bytes, nodes_bytes = cls._fleet_state_checksum_splitter(data, return_remainder=True)
splitter = BytestringSplitter(
(bytes, 32), # fleet state checksum
VariableLengthBytestring,
)
fleet_state_checksum_bytes, nodes_bytes, remainder = splitter(data, return_remainder=True)
if nodes_bytes:
nodes = NodeMetadata.batch_from_bytes(nodes_bytes)
nodes = NodeMetadata._batch_from_bytes(nodes_bytes)
else:
nodes = None
return cls(fleet_state_checksum=fleet_state_checksum_bytes.hex(),
announce_nodes=nodes)
obj = cls(fleet_state_checksum=fleet_state_checksum_bytes.hex(),
announce_nodes=nodes)
return obj, remainder
class MetadataResponse(Versioned):
_splitter = BytestringSplitter(
signature_splitter,
(int, 4, {'byteorder': 'big'}),
VariableLengthBytestring,
VariableLengthBytestring,
signature_splitter,
)
@classmethod
@ -918,7 +944,7 @@ class MetadataResponse(Versioned):
@staticmethod
def _signed_payload(timestamp_epoch, this_node, other_nodes):
timestamp = timestamp_epoch.to_bytes(4, byteorder="big")
nodes_payload = b''.join(bytes(VariableLengthBytestring(bytes(node))) for node in other_nodes) if other_nodes else b''
nodes_payload = b''.join(bytes(node) for node in other_nodes) if other_nodes else b''
return (
timestamp +
bytes(VariableLengthBytestring(bytes(this_node) if this_node else b'')) +
@ -955,14 +981,15 @@ class MetadataResponse(Versioned):
def _payload(self):
payload = self._signed_payload(self.timestamp_epoch, self.this_node, self.other_nodes)
return payload + bytes(self.signature)
return bytes(self.signature) + payload
@classmethod
def _from_bytes_current(cls, data: bytes):
timestamp_epoch, maybe_this_node, maybe_other_nodes, signature = cls._splitter(data)
signature, timestamp_epoch, maybe_this_node, maybe_other_nodes, remainder = cls._splitter(data, return_remainder=True)
this_node = NodeMetadata.from_bytes(maybe_this_node) if maybe_this_node else None
other_nodes = NodeMetadata.batch_from_bytes(maybe_other_nodes) if maybe_other_nodes else None
return cls(signature=signature,
timestamp_epoch=timestamp_epoch,
this_node=this_node,
other_nodes=other_nodes)
other_nodes = NodeMetadata._batch_from_bytes(maybe_other_nodes) if maybe_other_nodes else None
obj = cls(signature=signature,
timestamp_epoch=timestamp_epoch,
this_node=this_node,
other_nodes=other_nodes)
return obj, remainder

View File

@ -93,12 +93,24 @@ class Versioned(ABC):
raise NotImplementedError
@classmethod
def from_bytes(cls, data: bytes):
""""Public deserialization API"""
def take(cls, data: bytes):
"""
Deserializes the object from the given bytestring
and returns the object and the remainder of the bytestring.
"""
brand, version, payload = cls._parse_header(data)
version = cls._resolve_version(version=version)
handlers = cls._deserializers()
return handlers[version](payload)
obj, remainder = handlers[version](payload)
return obj, remainder
@classmethod
def from_bytes(cls, data: bytes):
""""Public deserialization API"""
obj, remainder = cls.take(data)
if remainder:
raise ValueError(f"{len(remainder)} bytes remaining after deserializing {cls}")
return obj
@classmethod
def _resolve_version(cls, version: Tuple[int, int]) -> Tuple[int, int]:

View File

@ -131,7 +131,7 @@ def test_treasure_map_validation(enacted_federated_policy,
UnenncryptedTreasureMapsOnly().load({'tmap': bad_map_b64})
assert "Could not convert input for tmap to a TreasureMap" in str(e)
assert "Invalid treasure map contents." in str(e)
assert "Can't split a message with more bytes than the original splittable." in str(e)
# a valid treasuremap
decrypted_treasure_map = federated_bob._decrypt_treasure_map(enacted_federated_policy.treasure_map,

View File

@ -35,7 +35,7 @@ def _check_valid_version_tuple(version: Any, cls: Type):
class A(Versioned):
def __init__(self, x):
def __init__(self, x: int):
self.x = x
@classmethod
@ -47,7 +47,7 @@ class A(Versioned):
return 2, 1
def _payload(self) -> bytes:
return bytes(self.x)
return self.x.to_bytes(1, 'big')
@classmethod
def _old_version_handlers(cls):
@ -57,11 +57,16 @@ class A(Versioned):
@classmethod
def _from_bytes_v2_0(cls, data):
return cls(int(data, 16)) # then we switched to the hexadecimal
# v2.0 saved a 1-byte integer as two hex characters
int_hex, remainder = data[:2], data[2:]
int_bytes = bytes.fromhex(int_hex.decode())
return cls(int.from_bytes(int_bytes, 'big')), remainder
@classmethod
def _from_bytes_current(cls, data):
return cls(str(int(data, 16))) # but now we use a string representation for some reason
# v2.1 saves a 1-byte integer as a single raw byte
int_bytes, remainder = data[:1], data[1:]
return cls(int.from_bytes(int_bytes, 'big')), remainder
def test_unique_branding():
@ -206,13 +211,13 @@ def test_current_minor_version_handler_routing(mocker):
current_spy = mocker.spy(A, "_from_bytes_current")
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
v2_1_data = b'ABCD\x00\x02\x00\x0112'
v2_1_data = b'ABCD\x00\x02\x00\x01\x12'
a = A.from_bytes(v2_1_data)
assert a.x == '18'
assert a.x == 18
# Current version was correctly routed to the v2.1 handler.
assert current_spy.call_count == 1
current_spy.assert_called_with(b'12')
current_spy.assert_called_with(b'\x12')
assert not v2_0_spy.call_count
@ -220,12 +225,12 @@ def test_future_minor_version_handler_routing(mocker):
current_spy = mocker.spy(A, "_from_bytes_current")
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
v2_2_data = b'ABCD\x00\x02\x02\x0112'
v2_2_data = b'ABCD\x00\x02\x02\x01\x12'
a = A.from_bytes(v2_2_data)
assert a.x == '18'
assert a.x == 18
# Future minor version was correctly routed to
# the current minor version handler.
assert current_spy.call_count == 1
current_spy.assert_called_with(b'12')
current_spy.assert_called_with(b'\x12')
assert not v2_0_spy.call_count