Expose serialized sizes of objects and overhaul Serializable accordingly

We assume here that all our objects have constant serialized size.
pull/270/head
Bogdan Opanchuk 2021-06-02 23:46:26 -07:00
parent 16def46564
commit b8175a3247
10 changed files with 193 additions and 125 deletions

View File

@ -2,7 +2,7 @@ import re
import pytest
from umbral.serializable import Serializable, serialize_bool, take_bool
from umbral.serializable import Serializable, bool_bytes, bool_from_exact_bytes
class A(Serializable):
@ -12,12 +12,15 @@ class A(Serializable):
self.val = val
@classmethod
def __take__(cls, data):
val_bytes, data = cls.__take_bytes__(data, 4)
return cls(int.from_bytes(val_bytes, byteorder='big')), data
def serialized_size(cls):
return 4
@classmethod
def _from_exact_bytes(cls, data):
return cls(int.from_bytes(data, byteorder='big'))
def __bytes__(self):
return self.val.to_bytes(4, byteorder='big')
return self.val.to_bytes(self.serialized_size(), byteorder='big')
def __eq__(self, other):
return isinstance(other, A) and self.val == other.val
@ -30,12 +33,15 @@ class B(Serializable):
self.val = val
@classmethod
def __take__(cls, data):
val_bytes, data = cls.__take_bytes__(data, 2)
return cls(int.from_bytes(val_bytes, byteorder='big')), data
def serialized_size(cls):
return 2
@classmethod
def _from_exact_bytes(cls, data):
return cls(int.from_bytes(data, byteorder='big'))
def __bytes__(self):
return self.val.to_bytes(2, byteorder='big')
return self.val.to_bytes(self.serialized_size(), byteorder='big')
def __eq__(self, other):
return isinstance(other, B) and self.val == other.val
@ -48,9 +54,13 @@ class C(Serializable):
self.b = b
@classmethod
def __take__(cls, data):
components, data = cls.__take_types__(data, A, B)
return cls(*components), data
def serialized_size(cls):
return A.serialized_size() + B.serialized_size()
@classmethod
def _from_exact_bytes(cls, data):
components = cls._split(data, A, B)
return cls(*components)
def __bytes__(self):
return bytes(self.a) + bytes(self.b)
@ -71,7 +81,7 @@ def test_too_many_bytes():
a = A(2**32 - 123)
b = B(2**16 - 456)
c = C(a, b)
with pytest.raises(ValueError, match="1 bytes remaining after deserializing"):
with pytest.raises(ValueError, match="Expected 6 bytes, got 7"):
C.from_bytes(bytes(c) + b'\x00')
@ -80,13 +90,22 @@ def test_not_enough_bytes():
b = B(2**16 - 456)
c = C(a, b)
# Will happen on deserialization of B - 1 byte missing
with pytest.raises(ValueError, match="cannot take 2 bytes from a bytestring of size 1"):
with pytest.raises(ValueError, match="Expected 6 bytes, got 5"):
C.from_bytes(bytes(c)[:-1])
def test_serialize_bool():
assert take_bool(serialize_bool(True) + b'1234') == (True, b'1234')
assert take_bool(serialize_bool(False) + b'12') == (False, b'12')
def test_bool_bytes():
assert bool_from_exact_bytes(bool_bytes(True)) == True
assert bool_from_exact_bytes(bool_bytes(False)) == False
error_msg = re.escape("Incorrectly serialized boolean; expected b'\\x00' or b'\\x01', got b'z'")
with pytest.raises(ValueError, match=error_msg):
take_bool(b'z1234')
bool_from_exact_bytes(b'z')
def test_split_bool():
a = A(2**32 - 123)
b = True
data = bytes(a) + bool_bytes(b)
a_back, b_back = Serializable._split(data, A, bool)
assert a_back == a
assert b_back == b

View File

@ -29,17 +29,21 @@ class Capsule(Serializable):
self.point_v = point_v
self.signature = signature
@classmethod
def __take__(cls, data: bytes) -> Tuple['Capsule', bytes]:
(e, v, sig), data = cls.__take_types__(data, CurvePoint, CurvePoint, CurveScalar)
_COMPONENT_TYPES = CurvePoint, CurvePoint, CurveScalar
_SERIALIZED_SIZE = sum(tp.serialized_size() for tp in _COMPONENT_TYPES)
capsule = cls(e, v, sig)
@classmethod
def serialized_size(cls):
return cls._SERIALIZED_SIZE
@classmethod
def _from_exact_bytes(cls, data: bytes):
capsule = cls(*cls._split(data, *cls._COMPONENT_TYPES))
if not capsule._verify():
raise GenericError("Capsule self-verification failed. Serialized data may be damaged.")
return capsule
return capsule, data
def __bytes__(self) -> bytes:
def __bytes__(self):
return bytes(self.point_e) + bytes(self.point_v) + bytes(self.signature)
@classmethod

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import Optional, Tuple, Type
from .capsule import Capsule
from .curve_point import CurvePoint
@ -34,23 +34,23 @@ class CapsuleFragProof(Serializable):
return (self.point_e2, self.point_v2, self.kfrag_commitment,
self.kfrag_pok, self.signature, self.kfrag_signature)
_COMPONENT_TYPES: Tuple[Type[Serializable], ...] = (
CurvePoint, CurvePoint, CurvePoint, CurvePoint, CurveScalar, Signature)
_SERIALIZED_SIZE = sum(tp.serialized_size() for tp in _COMPONENT_TYPES)
def __eq__(self, other):
return self._components() == other._components()
@classmethod
def __take__(cls, data):
types = [CurvePoint, CurvePoint, CurvePoint, CurvePoint, CurveScalar, Signature]
components, data = cls.__take_types__(data, *types)
return cls(*components), data
def serialized_size(cls):
return cls._SERIALIZED_SIZE
@classmethod
def _from_exact_bytes(cls, data):
return cls(*cls._split(data, *cls._COMPONENT_TYPES))
def __bytes__(self):
return (bytes(self.point_e2) +
bytes(self.point_v2) +
bytes(self.kfrag_commitment) +
bytes(self.kfrag_pok) +
bytes(self.signature) +
bytes(self.kfrag_signature)
)
return b''.join(bytes(comp) for comp in self._components())
@classmethod
def from_kfrag_and_cfrag(cls,
@ -117,6 +117,10 @@ class CapsuleFrag(Serializable):
def _components(self):
return (self.point_e1, self.point_v1, self.kfrag_id, self.precursor, self.proof)
_COMPONENT_TYPES: Tuple[Type[Serializable], ...] = (
CurvePoint, CurvePoint, KeyFragID, CurvePoint, CapsuleFragProof)
_SERIALIZED_SIZE = sum(tp.serialized_size() for tp in _COMPONENT_TYPES)
def __eq__(self, other):
return self._components() == other._components()
@ -127,17 +131,15 @@ class CapsuleFrag(Serializable):
return f"{self.__class__.__name__}:{bytes(self).hex()[:16]}"
@classmethod
def __take__(cls, data):
types = CurvePoint, CurvePoint, KeyFragID, CurvePoint, CapsuleFragProof
components, data = cls.__take_types__(data, *types)
return cls(*components), data
def serialized_size(cls):
return cls._SERIALIZED_SIZE
@classmethod
def _from_exact_bytes(cls, data):
return cls(*cls._split(data, *cls._COMPONENT_TYPES))
def __bytes__(self):
return (bytes(self.point_e1) +
bytes(self.point_v1) +
bytes(self.kfrag_id) +
bytes(self.precursor) +
bytes(self.proof))
return b''.join(bytes(comp) for comp in self._components())
@classmethod
def reencrypted(cls, capsule: Capsule, kfrag: KeyFrag) -> 'CapsuleFrag':

View File

@ -34,14 +34,15 @@ class CurvePoint(Serializable):
return openssl.point_to_affine_coords(CURVE, self._backend_point)
@classmethod
def __take__(cls, data: bytes) -> Tuple['CurvePoint', bytes]:
def serialized_size(cls):
return CURVE.field_element_size + 1 # compressed point size
@classmethod
def _from_exact_bytes(cls, data: bytes):
"""
Returns a CurvePoint object from the given byte data on the curve provided.
"""
size = CURVE.field_element_size + 1 # compressed point size
point_data, data = cls.__take_bytes__(data, size)
point = openssl.point_from_bytes(CURVE, point_data)
return cls(point), data
return cls(openssl.point_from_bytes(CURVE, data))
def __bytes__(self) -> bytes:
"""

View File

@ -45,10 +45,12 @@ class CurveScalar(Serializable):
return cls(openssl.bn_from_bytes(digest.finalize(), apply_modulus=CURVE.bn_order))
@classmethod
def __take__(cls, data: bytes) -> Tuple['CurveScalar', bytes]:
scalar_data, data = cls.__take_bytes__(data, CURVE.scalar_size)
bignum = openssl.bn_from_bytes(scalar_data, check_modulus=CURVE.bn_order)
return cls(bignum), data
def serialized_size(cls):
return CURVE.scalar_size
@classmethod
def _from_exact_bytes(cls, data: bytes):
return cls(openssl.bn_from_bytes(data, check_modulus=CURVE.bn_order))
def __bytes__(self) -> bytes:
"""

View File

@ -6,7 +6,7 @@ from .openssl import backend, ErrorInvalidCompressedPoint
from .curve import CURVE
from .curve_scalar import CurveScalar
from .curve_point import CurvePoint
from .serializable import Serializable, serialize_bool
from .serializable import Serializable, bool_bytes
if TYPE_CHECKING: # pragma: no cover
from .key_frag import KeyFragID
@ -79,14 +79,14 @@ def kfrag_signature_message(kfrag_id: 'KeyFragID',
# Have to convert to bytes manually because `mypy` is not smart enough to resolve types.
delegating_part = ([serialize_bool(True), bytes(maybe_delegating_pk)]
delegating_part = ([bool_bytes(True), bytes(maybe_delegating_pk)]
if maybe_delegating_pk
else [serialize_bool(False)])
else [bool_bytes(False)])
cast(List[Serializable], delegating_part)
receiving_part = ([serialize_bool(True), bytes(maybe_receiving_pk)]
receiving_part = ([bool_bytes(True), bytes(maybe_receiving_pk)]
if maybe_receiving_pk
else [serialize_bool(False)])
else [bool_bytes(False)])
components = ([bytes(kfrag_id), bytes(commitment), bytes(precursor)] +
delegating_part +

View File

@ -1,5 +1,5 @@
import os
from typing import List, Optional
from typing import List, Optional, Tuple, Type
from .curve_point import CurvePoint
from .curve_scalar import CurveScalar
@ -7,7 +7,7 @@ from .errors import VerificationError
from .hashing import hash_to_shared_secret, kfrag_signature_message, hash_to_polynomial_arg
from .keys import PublicKey, SecretKey
from .params import PARAMETERS
from .serializable import Serializable, serialize_bool, take_bool
from .serializable import Serializable, bool_bytes, bool_serialized_size
from .signing import Signature, Signer
@ -26,9 +26,12 @@ class KeyFragID(Serializable):
return cls(os.urandom(cls.__SIZE))
@classmethod
def __take__(cls, data):
id_, data = cls.__take_bytes__(data, cls.__SIZE)
return cls(id_), data
def serialized_size(cls):
return cls.__SIZE
@classmethod
def _from_exact_bytes(cls, data):
return cls(data)
def __bytes__(self):
return self._id
@ -102,22 +105,25 @@ class KeyFragProof(Serializable):
def __eq__(self, other):
return self._components() == other._components()
@classmethod
def __take__(cls, data):
types = [CurvePoint, Signature, Signature]
(commitment, sig_proxy, sig_bob), data = cls.__take_types__(data, *types)
delegating_key_signed, data = take_bool(data)
receiving_key_signed, data = take_bool(data)
_SERIALIZED_SIZE = (CurvePoint.serialized_size() +
Signature.serialized_size() * 2 +
bool_serialized_size() * 2)
obj = cls(commitment, sig_proxy, sig_bob, delegating_key_signed, receiving_key_signed)
return obj, data
@classmethod
def serialized_size(cls):
return cls._SERIALIZED_SIZE
@classmethod
def _from_exact_bytes(cls, data):
types = [CurvePoint, Signature, Signature, bool, bool]
return cls(*cls._split(data, *types))
def __bytes__(self):
return (bytes(self.commitment) +
bytes(self.signature_for_proxy) +
bytes(self.signature_for_receiver) +
serialize_bool(self.delegating_key_signed) +
serialize_bool(self.receiving_key_signed)
bool_bytes(self.delegating_key_signed) +
bool_bytes(self.receiving_key_signed)
)
@ -144,18 +150,24 @@ class KeyFrag(Serializable):
self.precursor = precursor
self.proof = proof
@classmethod
def __take__(cls, data):
types = [KeyFragID, CurveScalar, CurvePoint, KeyFragProof]
components, data = cls.__take_types__(data, *types)
return cls(*components), data
def __bytes__(self):
return bytes(self.id) + bytes(self.key) + bytes(self.precursor) + bytes(self.proof)
def _components(self):
return self.id, self.key, self.precursor, self.proof
_COMPONENT_TYPES: Tuple[Type[Serializable], ...] = (
KeyFragID, CurveScalar, CurvePoint, KeyFragProof)
_SERIALIZED_SIZE = sum(tp.serialized_size() for tp in _COMPONENT_TYPES)
@classmethod
def serialized_size(cls):
return cls._SERIALIZED_SIZE
@classmethod
def _from_exact_bytes(cls, data):
return cls(*cls._split(data, *cls._COMPONENT_TYPES))
def __bytes__(self):
return b''.join(bytes(comp) for comp in self._components())
def __eq__(self, other):
return self._components() == other._components()

View File

@ -43,9 +43,12 @@ class SecretKey(Serializable):
return self._scalar_key
@classmethod
def __take__(cls, data: bytes) -> Tuple['SecretKey', bytes]:
(scalar_key,), data = cls.__take_types__(data, CurveScalar)
return cls(scalar_key), data
def serialized_size(cls):
return CurveScalar.serialized_size()
@classmethod
def _from_exact_bytes(cls, data: bytes):
return cls(CurveScalar._from_exact_bytes(data))
def __bytes__(self) -> bytes:
return bytes(self._scalar_key)
@ -70,9 +73,12 @@ class PublicKey(Serializable):
return cls(sk._public_key_point)
@classmethod
def __take__(cls, data: bytes) -> Tuple['PublicKey', bytes]:
(point_key,), data = cls.__take_types__(data, CurvePoint)
return cls(point_key), data
def serialized_size(cls):
return CurvePoint.serialized_size()
@classmethod
def _from_exact_bytes(cls, data: bytes):
return cls(CurvePoint._from_exact_bytes(data))
def __bytes__(self) -> bytes:
return bytes(self._point_key)
@ -122,9 +128,12 @@ class SecretKeyFactory(Serializable):
return SecretKey(scalar_key)
@classmethod
def __take__(cls, data: bytes) -> Tuple['SecretKeyFactory', bytes]:
key_seed, data = cls.__take_bytes__(data, cls._KEY_SEED_SIZE)
return cls(key_seed), data
def serialized_size(cls):
return cls._KEY_SEED_SIZE
@classmethod
def _from_exact_bytes(cls, data: bytes):
return cls(data)
def __bytes__(self) -> bytes:
return bytes(self.__key_seed)

View File

@ -14,41 +14,54 @@ class Serializable(ABC):
"""
Restores the object from serialized bytes.
"""
obj, remainder = cls.__take__(data)
if len(remainder) != 0:
raise ValueError(f"{len(remainder)} bytes remaining after deserializing {cls}")
return obj
expected_size = cls.serialized_size()
if len(data) != expected_size:
raise ValueError(f"Expected {expected_size} bytes, got {len(data)}")
return cls._from_exact_bytes(data)
@classmethod
def __take_bytes__(cls, data: bytes, size: int) -> Tuple[bytes, bytes]:
"""
Takes ``size`` bytes from the bytestring and returns them along with the remainder.
"""
if len(data) < size:
raise ValueError(f"{cls} cannot take {size} bytes "
f"from a bytestring of size {len(data)}")
return data[:size], data[size:]
@classmethod
def __take_types__(cls, data: bytes, *types: Type) -> Tuple[List[Any], bytes]:
@staticmethod
def _split(data: bytes, *types: Type) -> List[Any]:
"""
Given a list of ``Serializable`` types, attempts to deserialize them from the bytestring
one by one and returns the list of the resulting objects and the remaining bytestring.
"""
objs = []
pos = 0
for tp in types:
obj, data = tp.__take__(data)
if issubclass(tp, bool):
size = bool_serialized_size()
else:
size = tp.serialized_size()
chunk = data[pos:pos+size]
if issubclass(tp, bool):
obj = bool_from_exact_bytes(chunk)
else:
obj = tp._from_exact_bytes(chunk)
objs.append(obj)
return objs, data
pos += size
return objs
@classmethod
@abstractmethod
def __take__(cls: Type[_T], data: bytes) -> Tuple[_T, bytes]:
def serialized_size(cls) -> int:
"""
Take however much is necessary from ``data`` and instantiate the object,
returning it and the remaining bytestring.
Returns the size in bytes of the serialized representation of this object
(obtained with ``bytes()``).
"""
raise NotImplementedError
Must be implemented by the derived class.
@classmethod
@abstractmethod
def _from_exact_bytes(cls: Type[_T], data: bytes) -> _T:
"""
Deserializes the object from a bytestring of exactly the expected length
(defined by ``serialized_size()``).
"""
raise NotImplementedError
@ -60,17 +73,20 @@ class Serializable(ABC):
raise NotImplementedError
def serialize_bool(b: bool) -> bytes:
def bool_serialized_size() -> int:
return 1
def bool_bytes(b: bool) -> bytes:
return b'\x01' if b else b'\x00'
def take_bool(data: bytes) -> Tuple[bool, bytes]:
bool_bytes, data = Serializable.__take_bytes__(data, 1)
if bool_bytes == b'\x01':
def bool_from_exact_bytes(data: bytes) -> bool:
if data == b'\x01':
b = True
elif bool_bytes == b'\x00':
elif data == b'\x00':
b = False
else:
raise ValueError("Incorrectly serialized boolean; "
f"expected b'\\x00' or b'\\x01', got {repr(bool_bytes)}")
return b, data
f"expected b'\\x00' or b'\\x01', got {repr(data)}")
return b

View File

@ -92,9 +92,12 @@ class Signature(Serializable):
return self.verify_digest(verifying_key, digest)
@classmethod
def __take__(cls, data):
(r, s), data = cls.__take_types__(data, CurveScalar, CurveScalar)
return cls(r, s), data
def serialized_size(cls):
return CurveScalar.serialized_size() * 2
@classmethod
def _from_exact_bytes(cls, data: bytes):
return cls(*cls._split(data, CurveScalar, CurveScalar))
def __bytes__(self):
return bytes(self.r) + bytes(self.s)