Split out each condition into their own dict type.

pull/3026/head
derekpierre 2022-11-22 13:29:10 -05:00
parent df2dd0819b
commit b2a38762eb
4 changed files with 43 additions and 23 deletions

View File

@ -6,7 +6,7 @@ from typing import Any, Tuple
from marshmallow import Schema
from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.types import LingoEntry
from nucypher.policy.conditions.types import LingoListEntry
class _Serializable:
@ -59,7 +59,7 @@ class ReencryptionCondition(_Serializable, ABC):
return NotImplemented
@classmethod
def validate(cls, data: LingoEntry) -> None:
def validate(cls, data: LingoListEntry) -> None:
errors = cls.Schema().validate(data=data)
if errors:
raise InvalidCondition(f"Invalid {cls.__name__}: {errors}")

View File

@ -15,9 +15,9 @@ from nucypher.policy.conditions.exceptions import (
ReturnValueEvaluationError,
)
from nucypher.policy.conditions.types import (
LingoEntry,
LingoEntryObject,
LingoList,
LingoListEntry,
OperatorDict,
)
from nucypher.policy.conditions.utils import (
@ -38,7 +38,11 @@ class Operator:
return self.operator
def to_dict(self) -> OperatorDict:
return {"operator": self.operator}
# strict typing of operator value (must be a literal)
if self.operator == "and":
return {"operator": "and"}
else:
return {"operator": "or"}
@classmethod
def from_dict(cls, data: OperatorDict) -> "Operator":
@ -58,7 +62,7 @@ class Operator:
return json_data
@classmethod
def validate(cls, data: LingoEntry) -> None:
def validate(cls, data: LingoListEntry) -> None:
try:
_operator = data["operator"]
except KeyError:

View File

@ -1,9 +1,9 @@
import sys
if sys.version_info >= (3, 8):
from typing import TypedDict
from typing import Literal, TypedDict
else:
from typing_extensions import TypedDict
from typing_extensions import Literal, TypedDict
from typing import Any, Dict, List, Type, Union
@ -23,20 +23,14 @@ ContextDict = Dict[str, Any]
# OperatorDict represents:
# - {"operator": "and" | "or"}
class OperatorDict(TypedDict):
operator: str
operator: Literal["and", "or"]
#
# ConditionDict is a dictionary of:
# - str -> Simple values (str, int, bool), or parameter list which can be anything ('Any')
# - "returnValueTest" -> Return Value Test definitions
# - "functionAbi" -> ABI function definitions (already defined by web3)
#
BaseValue = Union[str, int, bool]
MethodParameters = List[Any]
ConditionValue = Union[BaseValue, MethodParameters] # base value or list of base values
# - TimeCondition
# - RPCCondition
# - ContractCondition
# Return Value Test
class ReturnValueTestDict(TypedDict, total=False):
@ -45,18 +39,40 @@ class ReturnValueTestDict(TypedDict, total=False):
key: Union[str, int]
ConditionDict = Dict[str, Union[ConditionValue, ReturnValueTestDict, ABIFunction]]
class _ReencryptionConditionDict(TypedDict, total=False):
name: str
class TimeConditionDict(_ReencryptionConditionDict, total=False):
method: Literal["timelock"]
returnValueTest: ReturnValueTestDict
class RPCConditionDict(_ReencryptionConditionDict, total=False):
chain: int
method: str
parameters: List[Any]
returnValueTest: ReturnValueTestDict
class ContractConditionDict(RPCConditionDict, total=False):
standardContractType: str
contractAddress: str
functionAbi: ABIFunction
ConditionDict = Union[TimeConditionDict, RPCConditionDict, ContractConditionDict]
#
# LingoEntry is:
# - Condition
# - Operator
#
LingoEntry = Union[OperatorDict, ConditionDict]
LingoListEntry = Union[OperatorDict, ConditionDict]
#
# LingoList contains a list of LingoEntries
LingoList = List[LingoEntry]
LingoList = List[LingoListEntry]
#

View File

@ -18,10 +18,10 @@ from nucypher.policy.conditions.exceptions import (
)
from nucypher.policy.conditions.types import (
ContextDict,
LingoEntry,
LingoEntryObject,
LingoEntryObjectType,
LingoList,
LingoListEntry,
)
from nucypher.utilities.logging import Logger
@ -62,7 +62,7 @@ class CamelCaseSchema(Schema):
def resolve_condition_lingo(
data: LingoEntry,
data: LingoListEntry,
) -> LingoEntryObjectType:
"""
TODO: This feels like a jenky way to resolve data types from JSON blobs, but it works.
@ -95,7 +95,7 @@ def resolve_condition_lingo(
)
def deserialize_condition_lingo(data: LingoEntry) -> LingoEntryObject:
def deserialize_condition_lingo(data: LingoListEntry) -> LingoEntryObject:
"""Deserialization helper for condition lingo"""
if isinstance(data, str):
data = json.loads(data)