Refine typing used for LingoList; better define constituent parts.

Use TypeDict either directly via typing or via typing_extensions.
Some cleanup based on mypy output.
pull/3026/head
derekpierre 2022-11-21 18:15:20 -05:00
parent 186c9ad5cd
commit fabc2f76aa
6 changed files with 124 additions and 57 deletions

View File

@ -1,12 +1,12 @@
import json
from abc import ABC, abstractmethod
from base64 import b64decode, b64encode
from typing import Any, Dict, Tuple
from typing import Any, Tuple
from marshmallow import Schema
from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.types import ConditionDict
from nucypher.policy.conditions.types import LingoEntry
class _Serializable:
@ -25,7 +25,7 @@ class _Serializable:
instance = schema.load(data)
return instance
def to_dict(self) -> Dict[str, str]:
def to_dict(self):
schema = self.Schema()
data = schema.dump(self)
return data
@ -59,7 +59,7 @@ class ReencryptionCondition(_Serializable, ABC):
return NotImplemented
@classmethod
def validate(cls, data: ConditionDict) -> None:
def validate(cls, data: LingoEntry) -> None:
errors = cls.Schema().validate(data=data)
if errors:
raise InvalidCondition(f"Invalid {cls.__name__}: {errors}")

View File

@ -72,10 +72,10 @@ def _resolve_any_context_variables(
processed_parameters.append(p)
v = return_value_test.value
k = return_value_test.key
if is_context_variable(return_value_test.value):
if is_context_variable(v):
v = get_context_value(context_variable=v, **context)
if is_context_variable(return_value_test.key):
k = return_value_test.key
if is_context_variable(k):
k = get_context_value(context_variable=k, **context)
processed_return_value_test = ReturnValueTest(
return_value_test.comparator, value=v, key=k

View File

@ -3,7 +3,7 @@ import base64
import json
import operator as pyoperator
from hashlib import md5
from typing import Any, Dict, Iterator, List, Optional, Union
from typing import Any, Iterator, List, Optional, Union
from marshmallow import fields, post_load
@ -14,7 +14,12 @@ from nucypher.policy.conditions.exceptions import (
InvalidLogicalOperator,
ReturnValueEvaluationError,
)
from nucypher.policy.conditions.types import ConditionDict, LingoList
from nucypher.policy.conditions.types import (
LingoEntry,
LingoEntryObject,
LingoList,
OperatorDict,
)
from nucypher.policy.conditions.utils import (
CamelCaseSchema,
deserialize_condition_lingo,
@ -23,7 +28,6 @@ from nucypher.policy.conditions.utils import (
class Operator:
OPERATORS = ("and", "or")
_KEY = "operator"
def __init__(self, _operator: str):
if _operator not in self.OPERATORS:
@ -33,13 +37,13 @@ class Operator:
def __str__(self) -> str:
return self.operator
def to_dict(self) -> Dict[str, str]:
return {self._KEY: self.operator}
def to_dict(self) -> OperatorDict:
return {"operator": self.operator}
@classmethod
def from_dict(cls, data: ConditionDict) -> "Operator":
def from_dict(cls, data: OperatorDict) -> "Operator":
cls.validate(data)
instance = cls(_operator=data[cls._KEY])
instance = cls(_operator=data["operator"])
return instance
@classmethod
@ -50,13 +54,13 @@ class Operator:
def to_json(self) -> str:
data = self.to_dict()
data = json.dumps(data)
return data
json_data = json.dumps(data)
return json_data
@classmethod
def validate(cls, data: ConditionDict) -> None:
def validate(cls, data: LingoEntry) -> None:
try:
_operator = data[cls._KEY] # underscore prefix to avoid name shadowing
_operator = data["operator"]
except KeyError:
raise InvalidLogicalOperator(f"Invalid operator data: {data}")
@ -168,7 +172,7 @@ class ConditionLingo:
the Lit Protocol (https://github.com/LIT-Protocol); credit to the authors for inspiring this work.
"""
def __init__(self, conditions: List[Union[ReencryptionCondition, Operator, Any]]):
def __init__(self, conditions: List[LingoEntryObject]):
"""
The input list *must* be structured as follows:
condition
@ -202,7 +206,7 @@ class ConditionLingo:
instance = cls(conditions=conditions)
return instance
def to_list(self): # TODO: __iter__ ?
def to_list(self) -> LingoList: # TODO: __iter__ ?
payload = [c.to_dict() for c in self.conditions]
return payload
@ -212,8 +216,8 @@ class ConditionLingo:
@classmethod
def from_json(cls, data: str) -> 'ConditionLingo':
data = json.loads(data)
instance = cls.from_list(payload=data)
payload = json.loads(data)
instance = cls.from_list(payload=payload)
return instance
def to_base64(self) -> bytes:
@ -222,8 +226,8 @@ class ConditionLingo:
@classmethod
def from_base64(cls, data: bytes) -> 'ConditionLingo':
data = base64.b64decode(data).decode()
instance = cls.from_json(data)
decoded_json = base64.b64decode(data).decode()
instance = cls.from_json(decoded_json)
return instance
def __bytes__(self) -> bytes:
@ -260,5 +264,6 @@ class ConditionLingo:
result = self.__eval(eval_string=eval_string)
return result
OR = Operator('or')
AND = Operator('and')

View File

@ -1,6 +1,66 @@
from typing import Dict, List, Union
import sys
ConditionValues = Union[str, int, bool]
ReturnValueDict = Dict[str, ConditionValues]
ConditionDict = Dict[str, Union[ConditionValues, ReturnValueDict]]
LingoList = List[ConditionDict]
if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict
from typing import Any, Dict, List, Type, Union
from web3.types import ABIFunction
#########
# Context
#########
ContextDict = Dict[str, Any]
################
# ConditionLingo
################
#
# OperatorDict represents:
# - {"operator": "and" | "or"}
class OperatorDict(TypedDict):
operator: str
#
# ConditionDict is a dictionary of:
# - str -> Simple values (str, int, bool), or parameter list which can be anything ('Any')
# - "returnValueTest" -> Return Value Test definitions
# - "functionAbi" -> ABI function definitions (already defined by web3)
#
BaseValue = Union[str, int, bool]
MethodParameters = List[Any]
ConditionValue = Union[BaseValue, MethodParameters] # base value or list of base values
# Return Value Test
class ReturnValueTestDict(TypedDict, total=False):
comparator: str
value: Any
key: Union[str, int]
ConditionDict = Dict[str, Union[ConditionValue, ReturnValueTestDict, ABIFunction]]
#
# LingoEntry is:
# - Condition
# - Operator
#
LingoEntry = Union[OperatorDict, ConditionDict]
#
# LingoList contains a list of LingoEntries
LingoList = List[LingoEntry]
#
# Object Types
#
LingoEntryObjectType = Union[Type["Operator"], Type["ReencryptionCondition"]]
LingoEntryObject = Union["Operator", "ReencryptionCondition"]

View File

@ -1,12 +1,11 @@
import json
import re
from http import HTTPStatus
from typing import Dict, NamedTuple, Optional, Type, Union
from typing import Dict, NamedTuple, Optional, Tuple
from marshmallow import Schema, post_dump
from web3.providers import BaseProvider
from nucypher.policy.conditions.base import ReencryptionCondition
from nucypher.policy.conditions.exceptions import (
ConditionEvaluationFailed,
ContextVariableVerificationFailed,
@ -17,7 +16,13 @@ from nucypher.policy.conditions.exceptions import (
RequiredContextVariable,
ReturnValueEvaluationError,
)
from nucypher.policy.conditions.types import ConditionDict, LingoList
from nucypher.policy.conditions.types import (
ContextDict,
LingoEntry,
LingoEntryObject,
LingoEntryObjectType,
LingoList,
)
from nucypher.utilities.logging import Logger
_ETH = "eth_"
@ -44,7 +49,7 @@ class CamelCaseSchema(Schema):
and snake-case for its internal representation.
"""
SKIP_VALUES = tuple()
SKIP_VALUES: Tuple = tuple()
def on_bind_field(self, field_name, field_obj):
field_obj.data_key = to_camelcase(field_obj.data_key or field_name)
@ -57,8 +62,8 @@ class CamelCaseSchema(Schema):
def resolve_condition_lingo(
data: ConditionDict,
) -> Union[Type["Operator"], Type["ReencryptionCondition"]]:
data: LingoEntry,
) -> LingoEntryObjectType:
"""
TODO: This feels like a jenky way to resolve data types from JSON blobs, but it works.
Inspects a given bloc of JSON and attempts to resolve it's intended datatype within the
@ -90,9 +95,7 @@ def resolve_condition_lingo(
)
def deserialize_condition_lingo(
data: Union[str, ConditionDict]
) -> Union["Operator", "ReencryptionCondition"]:
def deserialize_condition_lingo(data: LingoEntry) -> LingoEntryObject:
"""Deserialization helper for condition lingo"""
if isinstance(data, str):
data = json.loads(data)
@ -110,7 +113,7 @@ def validate_condition_lingo(conditions: LingoList) -> None:
def evaluate_condition_lingo(
lingo: "ConditionLingo",
providers: Optional[Dict[int, BaseProvider]] = None,
context: Optional[Dict[Union[str, int], Union[str, int]]] = None,
context: Optional[ContextDict] = None,
log: Logger = __LOGGER,
) -> Optional[EvalError]:
"""
@ -131,49 +134,53 @@ def evaluate_condition_lingo(
result = lingo.eval(providers=providers, **context)
if not result:
# explicit condition failure
error = ("Decryption conditions not satisfied", HTTPStatus.FORBIDDEN)
error = EvalError(
"Decryption conditions not satisfied", HTTPStatus.FORBIDDEN
)
except ReturnValueEvaluationError as e:
error = (
error = EvalError(
f"Unable to evaluate return value: {e}",
HTTPStatus.BAD_REQUEST,
)
except InvalidConditionLingo as e:
error = (
error = EvalError(
f"Invalid condition grammar: {e}",
HTTPStatus.BAD_REQUEST,
)
except InvalidCondition as e:
error = (
error = EvalError(
f"Incorrect value provided for condition: {e}",
HTTPStatus.BAD_REQUEST,
)
except RequiredContextVariable as e:
# TODO: be more specific and name the missing inputs, etc
error = (f"Missing required inputs: {e}", HTTPStatus.BAD_REQUEST)
error = EvalError(f"Missing required inputs: {e}", HTTPStatus.BAD_REQUEST)
except InvalidContextVariableData as e:
error = (
error = EvalError(
f"Invalid data provided for context variable: {e}",
HTTPStatus.BAD_REQUEST,
)
except ContextVariableVerificationFailed as e:
error = (
error = EvalError(
f"Context variable data could not be verified: {e}",
HTTPStatus.FORBIDDEN,
)
except NoConnectionToChain as e:
error = (
error = EvalError(
f"Node does not have a connection to chain ID {e.chain}: {e}",
HTTPStatus.NOT_IMPLEMENTED,
)
except ConditionEvaluationFailed as e:
error = (f"Decryption condition not evaluated: {e}", HTTPStatus.BAD_REQUEST)
error = EvalError(
f"Decryption condition not evaluated: {e}", HTTPStatus.BAD_REQUEST
)
except Exception as e:
# TODO: Unsure why we ended up here
message = (
f"Unexpected exception while evaluating "
f"decryption condition ({e.__class__.__name__}): {e}"
)
error = (message, HTTPStatus.INTERNAL_SERVER_ERROR)
error = EvalError(message, HTTPStatus.INTERNAL_SERVER_ERROR)
log.warn(message)
if error:

View File

@ -1,15 +1,10 @@
from typing import Dict, Set
from eth_typing import ChecksumAddress
from eth_utils import to_canonical_address
from nucypher_core import Address, Conditions, MessageKit, RetrievalKit
from nucypher_core.umbral import PublicKey, SecretKey, VerifiedCapsuleFrag
from nucypher_core import Address, MessageKit, RetrievalKit
from nucypher.policy.conditions.lingo import ConditionLingo
class PolicyMessageKit:
@ -54,7 +49,7 @@ class PolicyMessageKit:
message_kit=self.message_kit)
@property
def conditions(self) -> ConditionLingo:
def conditions(self) -> Conditions:
return self.message_kit.conditions
@ -71,7 +66,7 @@ class RetrievalResult:
def __init__(self, cfrags: Dict[ChecksumAddress, VerifiedCapsuleFrag]):
self.cfrags = cfrags
def canonical_addresses(self) -> Set[bytes]:
def canonical_addresses(self) -> Set[Address]:
# TODO (#1995): propagate this to use canonical addresses everywhere
return set([Address(to_canonical_address(address)) for address in self.cfrags])