Split out each condition into their own dict type.

pull/3026/head
derekpierre 2022-11-22 13:29:10 -05:00
parent df2dd0819b
commit b2a38762eb
4 changed files with 43 additions and 23 deletions

View File

@ -6,7 +6,7 @@ from typing import Any, Tuple
from marshmallow import Schema from marshmallow import Schema
from nucypher.policy.conditions.exceptions import InvalidCondition from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.types import LingoEntry from nucypher.policy.conditions.types import LingoListEntry
class _Serializable: class _Serializable:
@ -59,7 +59,7 @@ class ReencryptionCondition(_Serializable, ABC):
return NotImplemented return NotImplemented
@classmethod @classmethod
def validate(cls, data: LingoEntry) -> None: def validate(cls, data: LingoListEntry) -> None:
errors = cls.Schema().validate(data=data) errors = cls.Schema().validate(data=data)
if errors: if errors:
raise InvalidCondition(f"Invalid {cls.__name__}: {errors}") raise InvalidCondition(f"Invalid {cls.__name__}: {errors}")

View File

@ -15,9 +15,9 @@ from nucypher.policy.conditions.exceptions import (
ReturnValueEvaluationError, ReturnValueEvaluationError,
) )
from nucypher.policy.conditions.types import ( from nucypher.policy.conditions.types import (
LingoEntry,
LingoEntryObject, LingoEntryObject,
LingoList, LingoList,
LingoListEntry,
OperatorDict, OperatorDict,
) )
from nucypher.policy.conditions.utils import ( from nucypher.policy.conditions.utils import (
@ -38,7 +38,11 @@ class Operator:
return self.operator return self.operator
def to_dict(self) -> OperatorDict: def to_dict(self) -> OperatorDict:
return {"operator": self.operator} # strict typing of operator value (must be a literal)
if self.operator == "and":
return {"operator": "and"}
else:
return {"operator": "or"}
@classmethod @classmethod
def from_dict(cls, data: OperatorDict) -> "Operator": def from_dict(cls, data: OperatorDict) -> "Operator":
@ -58,7 +62,7 @@ class Operator:
return json_data return json_data
@classmethod @classmethod
def validate(cls, data: LingoEntry) -> None: def validate(cls, data: LingoListEntry) -> None:
try: try:
_operator = data["operator"] _operator = data["operator"]
except KeyError: except KeyError:

View File

@ -1,9 +1,9 @@
import sys import sys
if sys.version_info >= (3, 8): if sys.version_info >= (3, 8):
from typing import TypedDict from typing import Literal, TypedDict
else: else:
from typing_extensions import TypedDict from typing_extensions import Literal, TypedDict
from typing import Any, Dict, List, Type, Union from typing import Any, Dict, List, Type, Union
@ -23,20 +23,14 @@ ContextDict = Dict[str, Any]
# OperatorDict represents: # OperatorDict represents:
# - {"operator": "and" | "or"} # - {"operator": "and" | "or"}
class OperatorDict(TypedDict): class OperatorDict(TypedDict):
operator: str operator: Literal["and", "or"]
# #
# ConditionDict is a dictionary of: # ConditionDict is a dictionary of:
# - str -> Simple values (str, int, bool), or parameter list which can be anything ('Any') # - TimeCondition
# - "returnValueTest" -> Return Value Test definitions # - RPCCondition
# - "functionAbi" -> ABI function definitions (already defined by web3) # - ContractCondition
#
BaseValue = Union[str, int, bool]
MethodParameters = List[Any]
ConditionValue = Union[BaseValue, MethodParameters] # base value or list of base values
# Return Value Test # Return Value Test
class ReturnValueTestDict(TypedDict, total=False): class ReturnValueTestDict(TypedDict, total=False):
@ -45,18 +39,40 @@ class ReturnValueTestDict(TypedDict, total=False):
key: Union[str, int] key: Union[str, int]
ConditionDict = Dict[str, Union[ConditionValue, ReturnValueTestDict, ABIFunction]] class _ReencryptionConditionDict(TypedDict, total=False):
name: str
class TimeConditionDict(_ReencryptionConditionDict, total=False):
method: Literal["timelock"]
returnValueTest: ReturnValueTestDict
class RPCConditionDict(_ReencryptionConditionDict, total=False):
chain: int
method: str
parameters: List[Any]
returnValueTest: ReturnValueTestDict
class ContractConditionDict(RPCConditionDict, total=False):
standardContractType: str
contractAddress: str
functionAbi: ABIFunction
ConditionDict = Union[TimeConditionDict, RPCConditionDict, ContractConditionDict]
# #
# LingoEntry is: # LingoEntry is:
# - Condition # - Condition
# - Operator # - Operator
# #
LingoEntry = Union[OperatorDict, ConditionDict] LingoListEntry = Union[OperatorDict, ConditionDict]
# #
# LingoList contains a list of LingoEntries # LingoList contains a list of LingoEntries
LingoList = List[LingoEntry] LingoList = List[LingoListEntry]
# #

View File

@ -18,10 +18,10 @@ from nucypher.policy.conditions.exceptions import (
) )
from nucypher.policy.conditions.types import ( from nucypher.policy.conditions.types import (
ContextDict, ContextDict,
LingoEntry,
LingoEntryObject, LingoEntryObject,
LingoEntryObjectType, LingoEntryObjectType,
LingoList, LingoList,
LingoListEntry,
) )
from nucypher.utilities.logging import Logger from nucypher.utilities.logging import Logger
@ -62,7 +62,7 @@ class CamelCaseSchema(Schema):
def resolve_condition_lingo( def resolve_condition_lingo(
data: LingoEntry, data: LingoListEntry,
) -> LingoEntryObjectType: ) -> LingoEntryObjectType:
""" """
TODO: This feels like a jenky way to resolve data types from JSON blobs, but it works. TODO: This feels like a jenky way to resolve data types from JSON blobs, but it works.
@ -95,7 +95,7 @@ def resolve_condition_lingo(
) )
def deserialize_condition_lingo(data: LingoEntry) -> LingoEntryObject: def deserialize_condition_lingo(data: LingoListEntry) -> LingoEntryObject:
"""Deserialization helper for condition lingo""" """Deserialization helper for condition lingo"""
if isinstance(data, str): if isinstance(data, str):
data = json.loads(data) data = json.loads(data)