mirror of https://github.com/nucypher/nucypher.git
237 lines
7.6 KiB
Python
237 lines
7.6 KiB
Python
"""
This file is part of nucypher.

nucypher is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

nucypher is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
|
|
|
|
|
|
import re
|
|
from typing import Tuple, Any, Type
|
|
|
|
import pytest
|
|
|
|
from nucypher.utilities.versioning import Versioned
|
|
|
|
|
|
def _check_valid_version_tuple(version: Any, cls: Type):
    """Fail the running test unless ``version`` is a well-formed version tuple.

    A well-formed version is a tuple holding exactly
    ``Versioned._VERSION_PARTS`` integers. ``cls`` is only used to name the
    offending class in the failure message.
    """
    # ``pytest.fail`` raises, so at most one branch ever reports.
    if not isinstance(version, tuple):
        pytest.fail(f"Old version handlers keys for {cls.__name__} must be a tuple")
    elif len(version) != Versioned._VERSION_PARTS:
        pytest.fail(f"Old version handlers keys for {cls.__name__} must be a {Versioned._VERSION_PARTS}-tuple")
    elif any(not isinstance(part, int) for part in version):
        pytest.fail(f"Old version handlers version parts {cls.__name__} must be integers")
|
|
|
|
|
|
class A(Versioned):
    """Minimal ``Versioned`` subclass used as a fixture by the tests in this module.

    It wraps a single integer ``x``. The current format (v2.1) stores ``x``
    as one raw byte; the legacy v2.0 handler reads the same value from two
    ASCII hex characters.
    """

    def __init__(self, x: int):
        self.x = x

    @classmethod
    def _brand(cls):
        return b"ABCD"

    @classmethod
    def _version(cls) -> Tuple[int, int]:
        return 2, 1

    def _payload(self) -> bytes:
        # One big-endian byte: ``int.to_bytes`` raises OverflowError
        # if x does not fit in 0..255.
        return self.x.to_bytes(1, 'big')

    @classmethod
    def _old_version_handlers(cls):
        # Maps each superseded (major, minor) version to its deserializer.
        return {
            (2, 0): cls._from_bytes_v2_0,
        }

    @classmethod
    def _from_bytes_v2_0(cls, data):
        # v2.0 saved the integer as two ASCII hex characters
        # (a single byte once decoded), e.g. b'12' -> 0x12 -> 18.
        int_hex, remainder = data[:2], data[2:]
        int_bytes = bytes.fromhex(int_hex.decode())
        return cls(int.from_bytes(int_bytes, 'big')), remainder

    @classmethod
    def _from_bytes_current(cls, data):
        # v2.1 saves the integer as one raw byte, e.g. b'\x12' -> 18.
        int_bytes, remainder = data[:1], data[1:]
        return cls(int.from_bytes(int_bytes, 'big')), remainder
|
|
|
|
|
|
def test_unique_branding():
    """No two direct ``Versioned`` subclasses may share the same brand."""
    all_brands = tuple(subclass._brand() for subclass in Versioned.__subclasses__())

    # Collect every repeat occurrence (second, third, ...) in declaration order.
    seen = set()
    duplicate_brands = []
    for brand in all_brands:
        if brand in seen:
            duplicate_brands.append(brand)
        else:
            seen.add(brand)

    if duplicate_brands:
        pytest.fail(f"Duplicated brand(s) {duplicate_brands}.")
|
|
|
|
|
|
def test_valid_branding():
    """Every direct ``Versioned`` subclass must declare a well-formed brand.

    A brand must be exactly ``_BRAND_SIZE`` bytes long and consist solely of
    word characters. Failure messages name the offending class and report the
    same size that was actually compared against.
    """
    for cls in Versioned.__subclasses__():
        brand = cls._brand()
        if len(brand) != cls._BRAND_SIZE:
            pytest.fail(f"Brand of {cls.__name__} must be exactly {cls._BRAND_SIZE} bytes; Got {brand}")
        # NOTE: for bytes patterns ``\w`` is [a-zA-Z0-9_], so an underscore
        # is accepted here as well as letters and digits.
        if not re.fullmatch(rb'\w+', brand):
            pytest.fail(f"Brand of {cls.__name__} must be alphanumeric; Got {brand}")
|
|
|
|
def test_valid_version_implementation():
    """The current version of every direct ``Versioned`` subclass is a valid version tuple."""
    for versioned_class in Versioned.__subclasses__():
        current_version = versioned_class._version()
        _check_valid_version_tuple(version=current_version, cls=versioned_class)
|
|
|
|
|
|
def test_valid_old_handlers_index():
    """Every entry yielded by ``_deserializers()`` must be a valid version tuple."""
    for versioned_class in Versioned.__subclasses__():
        # Iterating the deserializers yields the versions they are indexed by.
        for handler_version in versioned_class._deserializers():
            _check_valid_version_tuple(version=handler_version, cls=versioned_class)
|
|
|
|
|
|
def test_version_metadata():
    """``version_string()`` renders the current version as ``"<major>.<minor>"``."""
    major, minor = A._version()
    expected_string = f'{major}.{minor}'
    assert A.version_string() == expected_string
|
|
|
|
|
|
def test_versioning_header_prepend():
    """Serializing an instance prepends a header holding the brand and version."""
    instance = A(1)
    assert instance.x == 1

    blob = bytes(instance)
    header_size = Versioned._HEADER_SIZE
    assert len(blob) > header_size

    # The header starts with the brand...
    brand_size = Versioned._BRAND_SIZE
    header = blob[:header_size]
    assert header[:brand_size] == A._brand()

    # ...followed by the major and minor parts as big-endian integers.
    part_size = Versioned._VERSION_PART_SIZE
    version_bytes = header[brand_size:]
    decoded_version = (
        int.from_bytes(version_bytes[:part_size], 'big'),
        int.from_bytes(version_bytes[part_size:], 'big'),
    )
    assert decoded_version == A._version()
|
|
|
|
|
|
def test_versioning_input_too_short():
    """Input consisting of the brand plus only two more bytes is rejected."""
    truncated_header = b'ABCD\x00\x01'
    expected_error = 'Invalid bytes for A.'
    with pytest.raises(ValueError, match=expected_error):
        A.from_bytes(truncated_header)
|
|
|
|
|
|
def test_versioning_empty_payload():
    """A v2.1 header with nothing after it is rejected."""
    header_only = b'ABCD\x00\x02\x00\x01'
    expected_error = 'No content to deserialize A.'
    with pytest.raises(ValueError, match=expected_error):
        A.from_bytes(header_only)
|
|
|
|
|
|
def test_versioning_invalid_brand():
    """Bytes whose leading brand is not a valid brand raise ``InvalidHeader``."""
    invalid_inputs = (
        b'\x01\x02\x00\x03\x00\x0112',
        # A partially invalid brand, to check that the regexp validates
        # the whole brand and not just the beginning of it.
        b'ABC \x00\x02\x00\x0112',
    )
    for raw in invalid_inputs:
        with pytest.raises(Versioned.InvalidHeader, match="Incompatible bytes for A."):
            A.from_bytes(raw)
|
|
|
|
|
|
def test_versioning_incorrect_brand():
    """A well-formed brand that belongs to another class raises ``InvalidHeader``."""
    # Everything except the brand is valid (v2.1 header plus a payload), so
    # the brand mismatch is the only possible reason for the failure.
    incorrect = b'ABAB\x00\x02\x00\x0112'
    message = "Incorrect brand. Expected b'ABCD', Got b'ABAB'."
    # ``match`` is applied as a regular expression; escape the expected text
    # so that the dots only match literal dots.
    with pytest.raises(Versioned.InvalidHeader, match=re.escape(message)):
        A.from_bytes(incorrect)
|
|
|
|
|
|
def test_unknown_future_major_version():
    """Bytes declaring an unknown future major version raise a ``ValueError``."""
    # Header declares v3.2 and is followed by a two-byte payload.
    future_data = b'ABCD\x00\x03\x00\x0212'
    message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 3.2.'
    # ``match`` is applied as a regular expression; escape the expected text
    # so that the dots in the version numbers only match literal dots.
    with pytest.raises(ValueError, match=re.escape(message)):
        A.from_bytes(future_data)
|
|
|
|
|
|
def test_incompatible_old_major_version(mocker):
    """Bytes from an older major version are rejected before any deserialization."""
    current_spy = mocker.spy(A, "_from_bytes_current")
    v1_data = b'ABCD\x00\x01\x00\x0012'
    message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 1.0.'
    # ``match`` is applied as a regular expression; escape the expected text
    # so that the dots in the version numbers only match literal dots.
    with pytest.raises(Versioned.IncompatibleVersion, match=re.escape(message)):
        A.from_bytes(v1_data)
    # The current handler must never see incompatible data.
    assert not current_spy.call_count
|
|
|
|
|
|
def test_incompatible_future_major_version(mocker):
    """Bytes from a newer major version are rejected before any deserialization."""
    current_spy = mocker.spy(A, "_from_bytes_current")
    # Header declares v3.0, one major version ahead of A's 2.x.
    v3_data = b'ABCD\x00\x03\x00\x0012'
    message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 3.0.'
    # ``match`` is applied as a regular expression; escape the expected text
    # so that the dots in the version numbers only match literal dots.
    with pytest.raises(Versioned.IncompatibleVersion, match=re.escape(message)):
        A.from_bytes(v3_data)
    # The current handler must never see incompatible data.
    assert not current_spy.call_count
|
|
|
|
|
|
def test_resolve_version():
    """``_resolve_version`` maps a requested version to the handler version used."""
    cases = (
        # (requested, expected)
        ((2, 0), (2, 0)),  # past: an old minor version resolves to itself
        ((2, 1), (2, 1)),  # present: the current version resolves to itself
        ((2, 2), (2, 1)),  # future minor version resolves to the latest minor version
    )
    for requested, expected in cases:
        assert A._resolve_version(version=requested) == expected
|
|
|
|
|
|
def test_old_minor_version_handler_routing(mocker):
    """v2.0 bytes are deserialized by the v2.0 handler, not the current one."""
    current_spy = mocker.spy(A, "_from_bytes_current")
    legacy_spy = mocker.spy(A, "_from_bytes_v2_0")

    # v2.0 header followed by the payload as two hex characters.
    legacy_bytes = b'ABCD\x00\x02\x00\x0012'
    restored = A.from_bytes(legacy_bytes)
    assert restored.x == 18

    # Only the legacy handler ran, and it received just the payload.
    legacy_spy.assert_called_once_with(b'12')
    assert current_spy.call_count == 0
|
|
|
|
|
|
def test_current_minor_version_handler_routing(mocker):
    """v2.1 bytes are deserialized by the current handler, not the v2.0 one."""
    current_spy = mocker.spy(A, "_from_bytes_current")
    legacy_spy = mocker.spy(A, "_from_bytes_v2_0")

    # v2.1 header followed by the payload as one raw byte.
    current_bytes = b'ABCD\x00\x02\x00\x01\x12'
    restored = A.from_bytes(current_bytes)
    assert restored.x == 18

    # Only the current handler ran, and it received just the payload.
    current_spy.assert_called_once_with(b'\x12')
    assert legacy_spy.call_count == 0
|
|
|
|
|
|
def test_future_minor_version_handler_routing(mocker):
    """Bytes from a future minor version are deserialized by the current handler."""
    current_spy = mocker.spy(A, "_from_bytes_current")
    v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")

    # Header declares v2.2: major b'\x00\x02', minor b'\x00\x02',
    # one minor version ahead of A's current 2.1.
    v2_2_data = b'ABCD\x00\x02\x00\x02\x12'
    a = A.from_bytes(v2_2_data)
    assert a.x == 18

    # Future minor version was correctly routed to
    # the current minor version handler.
    assert current_spy.call_count == 1
    current_spy.assert_called_with(b'\x12')
    assert not v2_0_spy.call_count
|