mirror of https://github.com/nucypher/nucypher.git
commit
957250c6d8
|
@ -0,0 +1 @@
|
|||
Extend brand size in ``Versioned`` to 4 bytes
|
|
@ -192,7 +192,7 @@ class ReencryptionRequest(Versioned):
|
|||
|
||||
@classmethod
|
||||
def _brand(cls) -> bytes:
|
||||
return b'RQ'
|
||||
return b'ReRq'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
|
@ -244,7 +244,7 @@ class ReencryptionResponse(Versioned):
|
|||
|
||||
@classmethod
|
||||
def _brand(cls) -> bytes:
|
||||
return b'RR'
|
||||
return b'ReRs'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
|
|
|
@ -122,7 +122,7 @@ class MessageKit(Versioned):
|
|||
|
||||
@classmethod
|
||||
def _brand(cls) -> bytes:
|
||||
return b'MK'
|
||||
return b'MKit'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
|
@ -181,7 +181,7 @@ class RetrievalKit(Versioned):
|
|||
|
||||
@classmethod
|
||||
def _brand(cls) -> bytes:
|
||||
return b'RK'
|
||||
return b'RKit'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
|
|
|
@ -105,7 +105,7 @@ class TreasureMap(Versioned):
|
|||
|
||||
@classmethod
|
||||
def _brand(cls) -> bytes:
|
||||
return b'TM'
|
||||
return b'TMap'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
|
@ -205,7 +205,7 @@ class AuthorizedKeyFrag(Versioned):
|
|||
|
||||
@classmethod
|
||||
def _brand(cls) -> bytes:
|
||||
return b'KF'
|
||||
return b'AKF_'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
|
@ -330,7 +330,7 @@ class EncryptedTreasureMap(Versioned):
|
|||
|
||||
@classmethod
|
||||
def _brand(cls) -> bytes:
|
||||
return b'EM'
|
||||
return b'EMap'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
|
|
|
@ -58,7 +58,7 @@ class Arrangement(Versioned):
|
|||
|
||||
@classmethod
|
||||
def _brand(cls) -> bytes:
|
||||
return b'AR'
|
||||
return b'Arng'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
|
|
|
@ -74,7 +74,7 @@ class RevocationOrder(Versioned):
|
|||
|
||||
@classmethod
|
||||
def _brand(cls) -> bytes:
|
||||
return b'RV'
|
||||
return b'Revo'
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
|
|
|
@ -17,16 +17,17 @@
|
|||
|
||||
|
||||
from abc import abstractmethod, ABC
|
||||
import re
|
||||
from typing import Dict, Tuple, Callable
|
||||
|
||||
|
||||
class Versioned(ABC):
|
||||
"""Base class for serializable entities"""
|
||||
|
||||
_PARTS = 2 # bytes
|
||||
_PART_SIZE = 2
|
||||
_BRAND_SIZE = 2
|
||||
_VERSION_SIZE = _PART_SIZE * _PARTS
|
||||
_VERSION_PARTS = 2
|
||||
_VERSION_PART_SIZE = 2 # bytes
|
||||
_BRAND_SIZE = 4
|
||||
_VERSION_SIZE = _VERSION_PART_SIZE * _VERSION_PARTS
|
||||
_HEADER_SIZE = _BRAND_SIZE + _VERSION_SIZE
|
||||
|
||||
class InvalidHeader(ValueError):
|
||||
|
@ -65,8 +66,8 @@ class Versioned(ABC):
|
|||
def _header(cls) -> bytes:
|
||||
"""The entire bytes header to prepend to the instance payload."""
|
||||
major, minor = cls._version()
|
||||
major_bytes = major.to_bytes(cls._PART_SIZE, 'big')
|
||||
minor_bytes = minor.to_bytes(cls._PART_SIZE, 'big')
|
||||
major_bytes = major.to_bytes(cls._VERSION_PART_SIZE, 'big')
|
||||
minor_bytes = minor.to_bytes(cls._VERSION_PART_SIZE, 'big')
|
||||
header = cls._brand() + major_bytes + minor_bytes
|
||||
return header
|
||||
|
||||
|
@ -135,7 +136,7 @@ class Versioned(ABC):
|
|||
brand = data[:cls._BRAND_SIZE]
|
||||
if brand != cls._brand():
|
||||
error = f"Incorrect brand. Expected {cls._brand()}, Got {brand}."
|
||||
if not brand.isalpha():
|
||||
if not re.fullmatch(rb'\w+', brand):
|
||||
# unversioned entities for older versions will most likely land here.
|
||||
error = f"Incompatible bytes for {cls.__name__}."
|
||||
raise cls.InvalidHeader(error)
|
||||
|
@ -144,7 +145,7 @@ class Versioned(ABC):
|
|||
@classmethod
|
||||
def _parse_version(cls, data: bytes) -> Tuple[int, int]:
|
||||
version_data = data[cls._BRAND_SIZE:cls._HEADER_SIZE]
|
||||
major, minor = version_data[:cls._PART_SIZE], version_data[cls._PART_SIZE:]
|
||||
major, minor = version_data[:cls._VERSION_PART_SIZE], version_data[cls._VERSION_PART_SIZE:]
|
||||
major, minor = int.from_bytes(major, 'big'), int.from_bytes(minor, 'big')
|
||||
version = major, minor
|
||||
return version
|
||||
|
|
|
@ -90,12 +90,12 @@ def test_treasure_map_validation(enacted_federated_policy,
|
|||
assert "Could not parse tmap" in str(e)
|
||||
assert "Invalid base64-encoded string" in str(e)
|
||||
|
||||
base64_header = base64.b64encode(EncryptedTreasureMapClass._header()).decode()
|
||||
|
||||
# valid base64 but invalid treasuremap
|
||||
bad_map = base64_header + "VGhpcyBpcWgb3RhbGx5IG5vdCBhIHRyZWFzdXJlbWg=="
|
||||
bad_map = EncryptedTreasureMapClass._header() + b"your face looks like a treasure map"
|
||||
bad_map_b64 = base64.b64encode(bad_map).decode()
|
||||
|
||||
with pytest.raises(InvalidInputData) as e:
|
||||
EncryptedTreasureMapsOnly().load({'tmap': bad_map})
|
||||
EncryptedTreasureMapsOnly().load({'tmap': bad_map_b64})
|
||||
|
||||
assert "Could not convert input for tmap to an EncryptedTreasureMap" in str(e)
|
||||
assert "Invalid encrypted treasure map contents." in str(e)
|
||||
|
@ -121,10 +121,11 @@ def test_treasure_map_validation(enacted_federated_policy,
|
|||
assert "Invalid base64-encoded string" in str(e)
|
||||
|
||||
# valid base64 but invalid treasuremap
|
||||
base64_header = base64.b64encode(TreasureMapClass._header()).decode()
|
||||
bad_map = base64_header + "VGhpcyBpcyB0b3RhbGx5IG5vdCBhIHRyZWFzdXJlbWFwLg=="
|
||||
bad_map = TreasureMapClass._header() + b"your face looks like a treasure map"
|
||||
bad_map_b64 = base64.b64encode(bad_map).decode()
|
||||
|
||||
with pytest.raises(InvalidInputData) as e:
|
||||
UnenncryptedTreasureMapsOnly().load({'tmap': bad_map})
|
||||
UnenncryptedTreasureMapsOnly().load({'tmap': bad_map_b64})
|
||||
|
||||
assert "Could not convert input for tmap to a TreasureMap" in str(e)
|
||||
assert "Invalid treasure map contents." in str(e)
|
||||
|
@ -153,9 +154,11 @@ def test_messagekit_validation(capsule_side_channel):
|
|||
assert "Incorrect padding" in str(e)
|
||||
|
||||
# valid base64 but invalid messagekit
|
||||
b64header = base64.b64encode(MessageKit._header()).decode()
|
||||
bad_kit = MessageKit._header() + b"I got a message for you"
|
||||
bad_kit_b64 = base64.b64encode(bad_kit).decode()
|
||||
|
||||
with pytest.raises(SpecificationError) as e:
|
||||
MessageKitsOnly().load({'mkit': b64header + "V3da=="})
|
||||
MessageKitsOnly().load({'mkit': bad_kit_b64})
|
||||
|
||||
assert "Could not parse mkit" in str(e)
|
||||
assert "Not enough bytes to constitute message types" in str(e)
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
"""
|
||||
|
||||
|
||||
import re
|
||||
from typing import Tuple, Any, Type
|
||||
|
||||
import pytest
|
||||
|
@ -26,8 +27,8 @@ from nucypher.utilities.versioning import Versioned
|
|||
def _check_valid_version_tuple(version: Any, cls: Type):
|
||||
if not isinstance(version, tuple):
|
||||
pytest.fail(f"Old version handlers keys for {cls.__name__} must be a tuple")
|
||||
if not len(version) == Versioned._PARTS:
|
||||
pytest.fail(f"Old version handlers keys for {cls.__name__} must be a {str(Versioned._PARTS)}-tuple")
|
||||
if not len(version) == Versioned._VERSION_PARTS:
|
||||
pytest.fail(f"Old version handlers keys for {cls.__name__} must be a {str(Versioned._VERSION_PARTS)}-tuple")
|
||||
if not all(isinstance(part, int) for part in version):
|
||||
pytest.fail(f"Old version handlers version parts {cls.__name__} must be integers")
|
||||
|
||||
|
@ -39,7 +40,7 @@ class A(Versioned):
|
|||
|
||||
@classmethod
|
||||
def _brand(cls):
|
||||
return b"AA"
|
||||
return b"ABCD"
|
||||
|
||||
@classmethod
|
||||
def _version(cls) -> Tuple[int, int]:
|
||||
|
@ -77,10 +78,9 @@ def test_valid_branding():
|
|||
for cls in Versioned.__subclasses__():
|
||||
if len(cls._brand()) != cls._BRAND_SIZE:
|
||||
pytest.fail(f"Brand must be exactly {str(Versioned._BRAND_SIZE)} bytes.")
|
||||
if not cls._brand().isalpha():
|
||||
if not re.fullmatch(rb'\w+', cls._brand()):
|
||||
pytest.fail(f"Brand must be alphanumeric; Got {cls._brand()}")
|
||||
|
||||
|
||||
def test_valid_version_implementation():
|
||||
for cls in Versioned.__subclasses__():
|
||||
_check_valid_version_tuple(version=cls._version(), cls=cls)
|
||||
|
@ -109,38 +109,44 @@ def test_versioning_header_prepend():
|
|||
assert brand == A._brand()
|
||||
|
||||
version = header[Versioned._BRAND_SIZE:]
|
||||
major, minor = version[:Versioned._PART_SIZE], version[Versioned._PART_SIZE:]
|
||||
major, minor = version[:Versioned._VERSION_PART_SIZE], version[Versioned._VERSION_PART_SIZE:]
|
||||
major_number = int.from_bytes(major, 'big')
|
||||
minor_number = int.from_bytes(minor, 'big')
|
||||
assert (major_number, minor_number) == A._version()
|
||||
|
||||
|
||||
def test_versioning_input_too_short():
|
||||
empty = b'AA\x00\x01'
|
||||
empty = b'ABCD\x00\x01'
|
||||
with pytest.raises(ValueError, match='Invalid bytes for A.'):
|
||||
A.from_bytes(empty)
|
||||
|
||||
|
||||
def test_versioning_empty_payload():
|
||||
empty = b'AA\x00\x02\x00\x01'
|
||||
empty = b'ABCD\x00\x02\x00\x01'
|
||||
with pytest.raises(ValueError, match='No content to deserialize A.'):
|
||||
A.from_bytes(empty)
|
||||
|
||||
|
||||
def test_versioning_invalid_brand():
|
||||
invalid = b'\x00\x03\x00\x0112'
|
||||
invalid = b'\x01\x02\x00\x03\x00\x0112'
|
||||
with pytest.raises(Versioned.InvalidHeader, match="Incompatible bytes for A."):
|
||||
A.from_bytes(invalid)
|
||||
|
||||
# A partially invalid brand, to check that the regexp validates
|
||||
# the whole brand and not just the beginning of it.
|
||||
invalid = b'ABC \x00\x02\x00\x0112'
|
||||
with pytest.raises(Versioned.InvalidHeader, match="Incompatible bytes for A."):
|
||||
A.from_bytes(invalid)
|
||||
|
||||
|
||||
def test_versioning_incorrect_brand():
|
||||
incorrect = b'AB\x00\x0112'
|
||||
with pytest.raises(Versioned.InvalidHeader, match="Incorrect brand. Expected b'AA', Got b'AB'."):
|
||||
incorrect = b'ABAB\x00\x0112'
|
||||
with pytest.raises(Versioned.InvalidHeader, match="Incorrect brand. Expected b'ABCD', Got b'ABAB'."):
|
||||
A.from_bytes(incorrect)
|
||||
|
||||
|
||||
def test_unknown_future_major_version():
|
||||
empty = b'AA\x00\x03\x00\x0212'
|
||||
empty = b'ABCD\x00\x03\x00\x0212'
|
||||
message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 3.2.'
|
||||
with pytest.raises(ValueError, match=message):
|
||||
A.from_bytes(empty)
|
||||
|
@ -148,7 +154,7 @@ def test_unknown_future_major_version():
|
|||
|
||||
def test_incompatible_old_major_version(mocker):
|
||||
current_spy = mocker.spy(A, "_from_bytes_current")
|
||||
v1_data = b'AA\x00\x01\x00\x0012'
|
||||
v1_data = b'ABCD\x00\x01\x00\x0012'
|
||||
message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 1.0.'
|
||||
with pytest.raises(Versioned.IncompatibleVersion, match=message):
|
||||
A.from_bytes(v1_data)
|
||||
|
@ -157,7 +163,7 @@ def test_incompatible_old_major_version(mocker):
|
|||
|
||||
def test_incompatible_future_major_version(mocker):
|
||||
current_spy = mocker.spy(A, "_from_bytes_current")
|
||||
v1_data = b'AA\x00\x03\x00\x0012'
|
||||
v1_data = b'ABCD\x00\x03\x00\x0012'
|
||||
message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 3.0.'
|
||||
with pytest.raises(Versioned.IncompatibleVersion, match=message):
|
||||
A.from_bytes(v1_data)
|
||||
|
@ -186,7 +192,7 @@ def test_old_minor_version_handler_routing(mocker):
|
|||
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
|
||||
|
||||
# Old minor version
|
||||
v2_0_data = b'AA\x00\x02\x00\x0012'
|
||||
v2_0_data = b'ABCD\x00\x02\x00\x0012'
|
||||
a = A.from_bytes(v2_0_data)
|
||||
assert a.x == 18
|
||||
|
||||
|
@ -200,7 +206,7 @@ def test_current_minor_version_handler_routing(mocker):
|
|||
current_spy = mocker.spy(A, "_from_bytes_current")
|
||||
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
|
||||
|
||||
v2_1_data = b'AA\x00\x02\x00\x0112'
|
||||
v2_1_data = b'ABCD\x00\x02\x00\x0112'
|
||||
a = A.from_bytes(v2_1_data)
|
||||
assert a.x == '18'
|
||||
|
||||
|
@ -214,7 +220,7 @@ def test_future_minor_version_handler_routing(mocker):
|
|||
current_spy = mocker.spy(A, "_from_bytes_current")
|
||||
v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")
|
||||
|
||||
v2_2_data = b'AA\x00\x02\x02\x0112'
|
||||
v2_2_data = b'ABCD\x00\x02\x02\x0112'
|
||||
a = A.from_bytes(v2_2_data)
|
||||
assert a.x == '18'
|
||||
|
||||
|
|
Loading…
Reference in New Issue