# nucypher/tests/unit/test_versioning.py
#
# 237 lines
# 7.6 KiB
# Python

"""
This file is part of nucypher.
nucypher is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
nucypher is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
import re
from typing import Tuple, Any, Type
import pytest
from nucypher.utilities.versioning import Versioned
def _check_valid_version_tuple(version: Any, cls: Type):
    """Fail the current test unless ``version`` is a well-formed version tuple.

    A well-formed version is a ``tuple`` holding exactly
    ``Versioned._VERSION_PARTS`` items, each of which is an ``int``.

    :param version: the candidate version value to validate.
    :param cls: the class the version belongs to; only its ``__name__`` is
        used, to label the failure message.
    """
    # Check 1: the container type.
    if not isinstance(version, tuple):
        pytest.fail(f"Old version handlers keys for {cls.__name__} must be a tuple")

    # Check 2: the number of parts.
    has_expected_length = len(version) == Versioned._VERSION_PARTS
    if not has_expected_length:
        pytest.fail(f"Old version handlers keys for {cls.__name__} must be a {str(Versioned._VERSION_PARTS)}-tuple")

    # Check 3: every part is an integer.
    non_integer_parts = [part for part in version if not isinstance(part, int)]
    if non_integer_parts:
        pytest.fail(f"Old version handlers version parts {cls.__name__} must be integers")
class A(Versioned):
    """Minimal ``Versioned`` subclass used as the subject of the tests below.

    It wraps a single small integer ``x``. The current version is 2.1 and a
    handler for the older 2.0 payload layout is registered through
    ``_old_version_handlers``.
    """

    def __init__(self, x: int):
        self.x = x

    @classmethod
    def _brand(cls) -> bytes:
        """Return the brand bytes identifying this class."""
        return b"ABCD"

    @classmethod
    def _version(cls) -> Tuple[int, int]:
        """Return the current ``(major, minor)`` version."""
        return 2, 1

    def _payload(self) -> bytes:
        """Serialize ``x`` in the current (v2.1) layout: one big-endian byte."""
        return self.x.to_bytes(1, 'big')

    @classmethod
    def _old_version_handlers(cls):
        """Map each supported old version to the classmethod that parses it."""
        return {
            (2, 0): cls._from_bytes_v2_0,
        }

    @classmethod
    def _from_bytes_v2_0(cls, data: bytes) -> Tuple["A", bytes]:
        """Parse the v2.0 payload layout.

        v2.0 stored the 1-byte integer as two hexadecimal characters
        (e.g. ``b'12'`` decodes to ``0x12 == 18``).

        :param data: payload bytes, starting at the hex-encoded integer.
        :return: the decoded instance and the bytes left after the integer.
        """
        int_hex, remainder = data[:2], data[2:]
        int_bytes = bytes.fromhex(int_hex.decode())
        return cls(int.from_bytes(int_bytes, 'big')), remainder

    @classmethod
    def _from_bytes_current(cls, data: bytes) -> Tuple["A", bytes]:
        """Parse the current (v2.1) payload layout.

        v2.1 stores the 1-byte integer as a single raw byte.

        :param data: payload bytes, starting at the raw integer byte.
        :return: the decoded instance and the bytes left after the integer.
        """
        int_bytes, remainder = data[:1], data[1:]
        return cls(int.from_bytes(int_bytes, 'big')), remainder
def test_unique_branding():
    """No two direct subclasses of ``Versioned`` may share a brand."""
    seen_brands = set()
    duplicate_brands = []
    for versioned_cls in Versioned.__subclasses__():
        brand = versioned_cls._brand()
        if brand in seen_brands:
            # Second or later occurrence of this brand: record it.
            duplicate_brands.append(brand)
        else:
            seen_brands.add(brand)

    if duplicate_brands:
        pytest.fail(f"Duplicated brand(s) {duplicate_brands}.")
def test_valid_branding():
    """Every direct ``Versioned`` subclass must declare a well-formed brand.

    A brand is well-formed when it is exactly ``cls._BRAND_SIZE`` bytes long
    and consists only of ``\\w`` bytes (ASCII letters, digits and underscore).
    Failure messages name the offending class and report the same size
    constant that the check itself uses.
    """
    for cls in Versioned.__subclasses__():
        # Read the brand once so both checks and both messages agree on it.
        brand = cls._brand()
        if len(brand) != cls._BRAND_SIZE:
            pytest.fail(
                f"Brand of {cls.__name__} must be exactly "
                f"{cls._BRAND_SIZE} bytes; Got {brand!r}"
            )
        if not re.fullmatch(rb'\w+', brand):
            pytest.fail(f"Brand of {cls.__name__} must be alphanumeric; Got {brand!r}")
def test_valid_version_implementation():
    """The current version of each direct ``Versioned`` subclass is a valid tuple."""
    for versioned_cls in Versioned.__subclasses__():
        current_version = versioned_cls._version()
        _check_valid_version_tuple(version=current_version, cls=versioned_cls)
def test_valid_old_handlers_index():
    """Each item yielded by ``_deserializers()`` must be a valid version tuple."""
    for versioned_cls in Versioned.__subclasses__():
        # NOTE(review): presumably iterating the result yields version keys.
        handler_index = versioned_cls._deserializers()
        for version_key in handler_index:
            _check_valid_version_tuple(version=version_key, cls=versioned_cls)
def test_version_metadata():
    """``version_string()`` renders the current version as ``<major>.<minor>``."""
    major, minor = A._version()
    expected_string = f'{major}.{minor}'
    assert A.version_string() == expected_string
def test_versioning_header_prepend():
    """Serialized bytes start with a header made of the brand and the version."""
    instance = A(1)
    assert instance.x == 1

    serialized = bytes(instance)
    # There must be something after the header.
    assert len(serialized) > Versioned._HEADER_SIZE

    header = serialized[:Versioned._HEADER_SIZE]
    brand_end = Versioned._BRAND_SIZE

    # The brand comes first.
    assert header[:brand_end] == A._brand()

    # The rest of the header is the version: major part, then minor part,
    # each decoded as a big-endian integer.
    version_bytes = header[brand_end:]
    part_size = Versioned._VERSION_PART_SIZE
    decoded_version = (
        int.from_bytes(version_bytes[:part_size], 'big'),
        int.from_bytes(version_bytes[part_size:], 'big'),
    )
    assert decoded_version == A._version()
def test_versioning_input_too_short():
    """Input that stops right after the brand and two more bytes is rejected."""
    truncated = b'ABCD\x00\x01'
    expected_message = 'Invalid bytes for A.'
    with pytest.raises(ValueError, match=expected_message):
        A.from_bytes(truncated)
def test_versioning_empty_payload():
    """A 2.1 header with nothing after it is rejected."""
    header_only = b'ABCD\x00\x02\x00\x01'
    expected_message = 'No content to deserialize A.'
    with pytest.raises(ValueError, match=expected_message):
        A.from_bytes(header_only)
def test_versioning_invalid_brand():
    """Bytes whose brand is malformed raise ``InvalidHeader``."""
    malformed_inputs = (
        # Brand made entirely of non-word bytes.
        b'\x01\x02\x00\x03\x00\x0112',
        # A partially invalid brand, to check that the regexp validates
        # the whole brand and not just the beginning of it.
        b'ABC \x00\x02\x00\x0112',
    )
    for invalid in malformed_inputs:
        with pytest.raises(Versioned.InvalidHeader, match="Incompatible bytes for A."):
            A.from_bytes(invalid)
def test_versioning_incorrect_brand():
    """A word-like brand other than ``A``'s own raises ``InvalidHeader``."""
    wrong_brand_data = b'ABAB\x00\x0112'
    expected_message = "Incorrect brand. Expected b'ABCD', Got b'ABAB'."
    with pytest.raises(Versioned.InvalidHeader, match=expected_message):
        A.from_bytes(wrong_brand_data)
def test_unknown_future_major_version():
    """Bytes tagged with major version 3 (here 3.2) raise ``ValueError``."""
    v3_2_data = b'ABCD\x00\x03\x00\x0212'
    expected_message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 3.2.'
    with pytest.raises(ValueError, match=expected_message):
        A.from_bytes(v3_2_data)
def test_incompatible_old_major_version(mocker):
    """Version 1.0 bytes are rejected without reaching the current handler."""
    current_handler_spy = mocker.spy(A, "_from_bytes_current")

    v1_0_data = b'ABCD\x00\x01\x00\x0012'
    expected_message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 1.0.'
    with pytest.raises(Versioned.IncompatibleVersion, match=expected_message):
        A.from_bytes(v1_0_data)

    # The current-version deserializer must never have been invoked.
    assert current_handler_spy.call_count == 0
def test_incompatible_future_major_version(mocker):
    """Version 3.0 bytes are rejected without reaching the current handler."""
    current_handler_spy = mocker.spy(A, "_from_bytes_current")

    v3_0_data = b'ABCD\x00\x03\x00\x0012'
    expected_message = 'Incompatible versioned bytes for A. Compatible version is 2.x, Got 3.0.'
    with pytest.raises(Versioned.IncompatibleVersion, match=expected_message):
        A.from_bytes(v3_0_data)

    # The current-version deserializer must never have been invoked.
    assert current_handler_spy.call_count == 0
def test_resolve_version():
    """``_resolve_version`` maps a requested version to the one that handles it."""
    # (requested version, expected resolved version)
    cases = (
        ((2, 0), (2, 0)),  # past: resolves to itself
        ((2, 1), (2, 1)),  # present: resolves to itself
        ((2, 2), (2, 1)),  # future minor version resolves to the latest minor version
    )
    for requested, expected in cases:
        assert A._resolve_version(version=requested) == expected
def test_old_minor_version_handler_routing(mocker):
    """Version 2.0 bytes are deserialized by the v2.0 handler only."""
    old_handler_spy = mocker.spy(A, "_from_bytes_v2_0")
    current_handler_spy = mocker.spy(A, "_from_bytes_current")

    # Old minor version: header for 2.0 followed by the hex text b'12'.
    old_minor_data = b'ABCD\x00\x02\x00\x0012'
    instance = A.from_bytes(old_minor_data)
    assert instance.x == 18

    # Old minor version was correctly routed to the v2.0 handler,
    # which received only the payload.
    old_handler_spy.assert_called_once_with(b'12')
    current_handler_spy.assert_not_called()
def test_current_minor_version_handler_routing(mocker):
    """Version 2.1 bytes are deserialized by the current handler only."""
    old_handler_spy = mocker.spy(A, "_from_bytes_v2_0")
    current_handler_spy = mocker.spy(A, "_from_bytes_current")

    # Header for 2.1 followed by the single raw byte 0x12.
    current_data = b'ABCD\x00\x02\x00\x01\x12'
    instance = A.from_bytes(current_data)
    assert instance.x == 18

    # Current version was correctly routed to the v2.1 handler,
    # which received only the payload.
    current_handler_spy.assert_called_once_with(b'\x12')
    old_handler_spy.assert_not_called()
def test_future_minor_version_handler_routing(mocker):
    """Version 2.2 bytes (a future minor) are deserialized by the current handler."""
    current_spy = mocker.spy(A, "_from_bytes_current")
    v2_0_spy = mocker.spy(A, "_from_bytes_v2_0")

    # Header: brand, major = 2, minor = 2; payload: the raw byte 0x12.
    # The minor part is b'\x00\x02' so the data really is version 2.2,
    # matching the variable name and the 2.2 case in test_resolve_version.
    v2_2_data = b'ABCD\x00\x02\x00\x02\x12'
    a = A.from_bytes(v2_2_data)
    assert a.x == 18

    # Future minor version was correctly routed to
    # the current minor version handler.
    assert current_spy.call_count == 1
    current_spy.assert_called_with(b'\x12')
    assert not v2_0_spy.call_count