Use semver for lingo versioning. Ensure that major version changes are recognized as incompatible.

pull/3145/head
derekpierre 2023-06-13 19:28:08 -04:00
parent d69b62d8b3
commit 53955cb80b
4 changed files with 65 additions and 17 deletions

View File

@ -67,14 +67,14 @@ class AccessControlCondition(_Serializable, ABC):
raise InvalidCondition(f"Invalid {cls.__name__}: {errors}")
@classmethod
def from_dict(cls, data) -> "_Serializable":
def from_dict(cls, data) -> "AccessControlCondition":
try:
return super().from_dict(data)
except ValidationError as e:
raise InvalidConditionLingo(f"Invalid condition grammar: {e}")
@classmethod
def from_json(cls, data) -> "_Serializable":
def from_json(cls, data) -> "AccessControlCondition":
try:
return super().from_json(data)
except ValidationError as e:

View File

@ -4,7 +4,16 @@ import operator as pyoperator
from hashlib import md5
from typing import Any, List, Optional, Tuple
from marshmallow import Schema, ValidationError, fields, post_load, pre_load, validate
from marshmallow import (
Schema,
ValidationError,
fields,
post_load,
pre_load,
validate,
validates,
)
from packaging.version import parse as parse_version
from nucypher.policy.conditions.base import AccessControlCondition, _Serializable
from nucypher.policy.conditions.context import is_context_variable
@ -212,16 +221,20 @@ class ReturnValueTest:
class ConditionLingo(_Serializable):
VERSION = 1
VERSION = "1.0.0"
class Schema(Schema):
version = fields.Int(required=True) # TODO validation here
version = fields.Str(required=True)
condition = _ConditionField(required=True)
# maintain field declaration ordering
class Meta:
ordered = True
@validates("version")
def validate_version(self, version):
ConditionLingo.check_version_compatibility(version)
@pre_load
def set_lingo_version(self, data, **kwargs):
version = data.get("version")
@ -239,7 +252,7 @@ class ConditionLingo(_Serializable):
the Lit Protocol (https://github.com/LIT-Protocol); credit to the authors for inspiring this work.
"""
def __init__(self, condition: AccessControlCondition, version: int = VERSION):
def __init__(self, condition: AccessControlCondition, version: str = VERSION):
"""
CONDITION = BASE_CONDITION | COMPOUND_CONDITION
BASE_CONDITION = {
@ -251,10 +264,7 @@ class ConditionLingo(_Serializable):
}
"""
self.condition = condition
if version > self.VERSION:
raise ValueError(
f"Version provided is in the future {version} > {self.VERSION}"
)
self.check_version_compatibility(version)
self.version = version
self.id = md5(bytes(self)).hexdigest()[:6]
@ -312,11 +322,6 @@ class ConditionLingo(_Serializable):
from nucypher.policy.conditions.time import TimeCondition
# version logical adjustments can be made here as required
if version and version > ConditionLingo.VERSION:
raise InvalidConditionLingo(
f"Version is in the future: {version} > {ConditionLingo.VERSION}"
)
# Inspect
method = condition.get("method")
operator = condition.get("operator")
@ -338,3 +343,10 @@ class ConditionLingo(_Serializable):
raise InvalidConditionLingo(
f"Cannot resolve condition lingo type from data {condition}"
)
@classmethod
def check_version_compatibility(cls, version: str):
if parse_version(version).major > parse_version(cls.VERSION).major:
raise InvalidConditionLingo(
f"Version provided, {version}, is incompatible with current version {cls.VERSION}"
)

View File

@ -84,5 +84,5 @@ ConditionDict = Union[
# - condition
# - ConditionDict
class Lingo(TypedDict):
version: int
version: str
condition: ConditionDict

View File

@ -1,6 +1,7 @@
import json
import pytest
from packaging.version import parse as parse_version
import nucypher
from nucypher.blockchain.eth.constants import NULL_ADDRESS
@ -32,13 +33,16 @@ def lingo():
}
def test_invalid_condition():
def test_invalid_condition(lingo):
# no version or condition
with pytest.raises(InvalidConditionLingo):
ConditionLingo.from_dict({})
# no condition
with pytest.raises(InvalidConditionLingo):
ConditionLingo.from_dict({"version": ConditionLingo.VERSION})
# invalid condition
with pytest.raises(InvalidConditionLingo):
ConditionLingo.from_dict(
{
@ -65,6 +69,38 @@ def test_invalid_condition():
ConditionLingo.from_dict(invalid_operator_position_lingo)
@pytest.mark.parametrize("case", ["major", "minor", "patch"])
def test_invalid_condition_version(case):
# version in the future
current_version = parse_version(ConditionLingo.VERSION)
major = current_version.major
minor = current_version.minor
patch = current_version.micro
if case == "major":
major += 1
elif case == "minor":
minor += 1
else:
patch += 1
newer_version_string = f"{major}.{minor}.{patch}"
lingo_dict = {
"version": newer_version_string,
"condition": {
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
},
}
if case == "major":
# exception should be thrown since incompatible:
with pytest.raises(InvalidConditionLingo):
ConditionLingo.from_dict(lingo_dict)
else:
# no exception thrown
ConditionLingo.from_dict(lingo_dict)
def test_condition_lingo_to_from_dict(lingo):
clingo = ConditionLingo.from_dict(lingo)
clingo_dict = clingo.to_dict()