mirror of https://github.com/nucypher/nucypher.git
Refine typing used for LingoList; better define constituent parts.
Use TypeDict either directly via typing or via typing_extensions. Some cleanup based on mypy output.pull/3026/head
parent
186c9ad5cd
commit
fabc2f76aa
|
@ -1,12 +1,12 @@
|
|||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from base64 import b64decode, b64encode
|
||||
from typing import Any, Dict, Tuple
|
||||
from typing import Any, Tuple
|
||||
|
||||
from marshmallow import Schema
|
||||
|
||||
from nucypher.policy.conditions.exceptions import InvalidCondition
|
||||
from nucypher.policy.conditions.types import ConditionDict
|
||||
from nucypher.policy.conditions.types import LingoEntry
|
||||
|
||||
|
||||
class _Serializable:
|
||||
|
@ -25,7 +25,7 @@ class _Serializable:
|
|||
instance = schema.load(data)
|
||||
return instance
|
||||
|
||||
def to_dict(self) -> Dict[str, str]:
|
||||
def to_dict(self):
|
||||
schema = self.Schema()
|
||||
data = schema.dump(self)
|
||||
return data
|
||||
|
@ -59,7 +59,7 @@ class ReencryptionCondition(_Serializable, ABC):
|
|||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
def validate(cls, data: ConditionDict) -> None:
|
||||
def validate(cls, data: LingoEntry) -> None:
|
||||
errors = cls.Schema().validate(data=data)
|
||||
if errors:
|
||||
raise InvalidCondition(f"Invalid {cls.__name__}: {errors}")
|
||||
|
|
|
@ -72,10 +72,10 @@ def _resolve_any_context_variables(
|
|||
processed_parameters.append(p)
|
||||
|
||||
v = return_value_test.value
|
||||
k = return_value_test.key
|
||||
if is_context_variable(return_value_test.value):
|
||||
if is_context_variable(v):
|
||||
v = get_context_value(context_variable=v, **context)
|
||||
if is_context_variable(return_value_test.key):
|
||||
k = return_value_test.key
|
||||
if is_context_variable(k):
|
||||
k = get_context_value(context_variable=k, **context)
|
||||
processed_return_value_test = ReturnValueTest(
|
||||
return_value_test.comparator, value=v, key=k
|
||||
|
|
|
@ -3,7 +3,7 @@ import base64
|
|||
import json
|
||||
import operator as pyoperator
|
||||
from hashlib import md5
|
||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
||||
from typing import Any, Iterator, List, Optional, Union
|
||||
|
||||
from marshmallow import fields, post_load
|
||||
|
||||
|
@ -14,7 +14,12 @@ from nucypher.policy.conditions.exceptions import (
|
|||
InvalidLogicalOperator,
|
||||
ReturnValueEvaluationError,
|
||||
)
|
||||
from nucypher.policy.conditions.types import ConditionDict, LingoList
|
||||
from nucypher.policy.conditions.types import (
|
||||
LingoEntry,
|
||||
LingoEntryObject,
|
||||
LingoList,
|
||||
OperatorDict,
|
||||
)
|
||||
from nucypher.policy.conditions.utils import (
|
||||
CamelCaseSchema,
|
||||
deserialize_condition_lingo,
|
||||
|
@ -23,7 +28,6 @@ from nucypher.policy.conditions.utils import (
|
|||
|
||||
class Operator:
|
||||
OPERATORS = ("and", "or")
|
||||
_KEY = "operator"
|
||||
|
||||
def __init__(self, _operator: str):
|
||||
if _operator not in self.OPERATORS:
|
||||
|
@ -33,13 +37,13 @@ class Operator:
|
|||
def __str__(self) -> str:
|
||||
return self.operator
|
||||
|
||||
def to_dict(self) -> Dict[str, str]:
|
||||
return {self._KEY: self.operator}
|
||||
def to_dict(self) -> OperatorDict:
|
||||
return {"operator": self.operator}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: ConditionDict) -> "Operator":
|
||||
def from_dict(cls, data: OperatorDict) -> "Operator":
|
||||
cls.validate(data)
|
||||
instance = cls(_operator=data[cls._KEY])
|
||||
instance = cls(_operator=data["operator"])
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
|
@ -50,13 +54,13 @@ class Operator:
|
|||
|
||||
def to_json(self) -> str:
|
||||
data = self.to_dict()
|
||||
data = json.dumps(data)
|
||||
return data
|
||||
json_data = json.dumps(data)
|
||||
return json_data
|
||||
|
||||
@classmethod
|
||||
def validate(cls, data: ConditionDict) -> None:
|
||||
def validate(cls, data: LingoEntry) -> None:
|
||||
try:
|
||||
_operator = data[cls._KEY] # underscore prefix to avoid name shadowing
|
||||
_operator = data["operator"]
|
||||
except KeyError:
|
||||
raise InvalidLogicalOperator(f"Invalid operator data: {data}")
|
||||
|
||||
|
@ -168,7 +172,7 @@ class ConditionLingo:
|
|||
the Lit Protocol (https://github.com/LIT-Protocol); credit to the authors for inspiring this work.
|
||||
"""
|
||||
|
||||
def __init__(self, conditions: List[Union[ReencryptionCondition, Operator, Any]]):
|
||||
def __init__(self, conditions: List[LingoEntryObject]):
|
||||
"""
|
||||
The input list *must* be structured as follows:
|
||||
condition
|
||||
|
@ -202,7 +206,7 @@ class ConditionLingo:
|
|||
instance = cls(conditions=conditions)
|
||||
return instance
|
||||
|
||||
def to_list(self): # TODO: __iter__ ?
|
||||
def to_list(self) -> LingoList: # TODO: __iter__ ?
|
||||
payload = [c.to_dict() for c in self.conditions]
|
||||
return payload
|
||||
|
||||
|
@ -212,8 +216,8 @@ class ConditionLingo:
|
|||
|
||||
@classmethod
|
||||
def from_json(cls, data: str) -> 'ConditionLingo':
|
||||
data = json.loads(data)
|
||||
instance = cls.from_list(payload=data)
|
||||
payload = json.loads(data)
|
||||
instance = cls.from_list(payload=payload)
|
||||
return instance
|
||||
|
||||
def to_base64(self) -> bytes:
|
||||
|
@ -222,8 +226,8 @@ class ConditionLingo:
|
|||
|
||||
@classmethod
|
||||
def from_base64(cls, data: bytes) -> 'ConditionLingo':
|
||||
data = base64.b64decode(data).decode()
|
||||
instance = cls.from_json(data)
|
||||
decoded_json = base64.b64decode(data).decode()
|
||||
instance = cls.from_json(decoded_json)
|
||||
return instance
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
|
@ -260,5 +264,6 @@ class ConditionLingo:
|
|||
result = self.__eval(eval_string=eval_string)
|
||||
return result
|
||||
|
||||
|
||||
OR = Operator('or')
|
||||
AND = Operator('and')
|
||||
|
|
|
@ -1,6 +1,66 @@
|
|||
from typing import Dict, List, Union
|
||||
import sys
|
||||
|
||||
ConditionValues = Union[str, int, bool]
|
||||
ReturnValueDict = Dict[str, ConditionValues]
|
||||
ConditionDict = Dict[str, Union[ConditionValues, ReturnValueDict]]
|
||||
LingoList = List[ConditionDict]
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from typing import Any, Dict, List, Type, Union
|
||||
|
||||
from web3.types import ABIFunction
|
||||
|
||||
#########
|
||||
# Context
|
||||
#########
|
||||
ContextDict = Dict[str, Any]
|
||||
|
||||
|
||||
################
|
||||
# ConditionLingo
|
||||
################
|
||||
|
||||
#
|
||||
# OperatorDict represents:
|
||||
# - {"operator": "and" | "or"}
|
||||
class OperatorDict(TypedDict):
|
||||
operator: str
|
||||
|
||||
|
||||
#
|
||||
# ConditionDict is a dictionary of:
|
||||
# - str -> Simple values (str, int, bool), or parameter list which can be anything ('Any')
|
||||
# - "returnValueTest" -> Return Value Test definitions
|
||||
# - "functionAbi" -> ABI function definitions (already defined by web3)
|
||||
#
|
||||
BaseValue = Union[str, int, bool]
|
||||
MethodParameters = List[Any]
|
||||
|
||||
ConditionValue = Union[BaseValue, MethodParameters] # base value or list of base values
|
||||
|
||||
|
||||
# Return Value Test
|
||||
class ReturnValueTestDict(TypedDict, total=False):
|
||||
comparator: str
|
||||
value: Any
|
||||
key: Union[str, int]
|
||||
|
||||
|
||||
ConditionDict = Dict[str, Union[ConditionValue, ReturnValueTestDict, ABIFunction]]
|
||||
|
||||
#
|
||||
# LingoEntry is:
|
||||
# - Condition
|
||||
# - Operator
|
||||
#
|
||||
LingoEntry = Union[OperatorDict, ConditionDict]
|
||||
|
||||
#
|
||||
# LingoList contains a list of LingoEntries
|
||||
LingoList = List[LingoEntry]
|
||||
|
||||
|
||||
#
|
||||
# Object Types
|
||||
#
|
||||
LingoEntryObjectType = Union[Type["Operator"], Type["ReencryptionCondition"]]
|
||||
LingoEntryObject = Union["Operator", "ReencryptionCondition"]
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
import json
|
||||
import re
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, NamedTuple, Optional, Type, Union
|
||||
from typing import Dict, NamedTuple, Optional, Tuple
|
||||
|
||||
from marshmallow import Schema, post_dump
|
||||
from web3.providers import BaseProvider
|
||||
|
||||
from nucypher.policy.conditions.base import ReencryptionCondition
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
ConditionEvaluationFailed,
|
||||
ContextVariableVerificationFailed,
|
||||
|
@ -17,7 +16,13 @@ from nucypher.policy.conditions.exceptions import (
|
|||
RequiredContextVariable,
|
||||
ReturnValueEvaluationError,
|
||||
)
|
||||
from nucypher.policy.conditions.types import ConditionDict, LingoList
|
||||
from nucypher.policy.conditions.types import (
|
||||
ContextDict,
|
||||
LingoEntry,
|
||||
LingoEntryObject,
|
||||
LingoEntryObjectType,
|
||||
LingoList,
|
||||
)
|
||||
from nucypher.utilities.logging import Logger
|
||||
|
||||
_ETH = "eth_"
|
||||
|
@ -44,7 +49,7 @@ class CamelCaseSchema(Schema):
|
|||
and snake-case for its internal representation.
|
||||
"""
|
||||
|
||||
SKIP_VALUES = tuple()
|
||||
SKIP_VALUES: Tuple = tuple()
|
||||
|
||||
def on_bind_field(self, field_name, field_obj):
|
||||
field_obj.data_key = to_camelcase(field_obj.data_key or field_name)
|
||||
|
@ -57,8 +62,8 @@ class CamelCaseSchema(Schema):
|
|||
|
||||
|
||||
def resolve_condition_lingo(
|
||||
data: ConditionDict,
|
||||
) -> Union[Type["Operator"], Type["ReencryptionCondition"]]:
|
||||
data: LingoEntry,
|
||||
) -> LingoEntryObjectType:
|
||||
"""
|
||||
TODO: This feels like a jenky way to resolve data types from JSON blobs, but it works.
|
||||
Inspects a given bloc of JSON and attempts to resolve it's intended datatype within the
|
||||
|
@ -90,9 +95,7 @@ def resolve_condition_lingo(
|
|||
)
|
||||
|
||||
|
||||
def deserialize_condition_lingo(
|
||||
data: Union[str, ConditionDict]
|
||||
) -> Union["Operator", "ReencryptionCondition"]:
|
||||
def deserialize_condition_lingo(data: LingoEntry) -> LingoEntryObject:
|
||||
"""Deserialization helper for condition lingo"""
|
||||
if isinstance(data, str):
|
||||
data = json.loads(data)
|
||||
|
@ -110,7 +113,7 @@ def validate_condition_lingo(conditions: LingoList) -> None:
|
|||
def evaluate_condition_lingo(
|
||||
lingo: "ConditionLingo",
|
||||
providers: Optional[Dict[int, BaseProvider]] = None,
|
||||
context: Optional[Dict[Union[str, int], Union[str, int]]] = None,
|
||||
context: Optional[ContextDict] = None,
|
||||
log: Logger = __LOGGER,
|
||||
) -> Optional[EvalError]:
|
||||
"""
|
||||
|
@ -131,49 +134,53 @@ def evaluate_condition_lingo(
|
|||
result = lingo.eval(providers=providers, **context)
|
||||
if not result:
|
||||
# explicit condition failure
|
||||
error = ("Decryption conditions not satisfied", HTTPStatus.FORBIDDEN)
|
||||
error = EvalError(
|
||||
"Decryption conditions not satisfied", HTTPStatus.FORBIDDEN
|
||||
)
|
||||
except ReturnValueEvaluationError as e:
|
||||
error = (
|
||||
error = EvalError(
|
||||
f"Unable to evaluate return value: {e}",
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
except InvalidConditionLingo as e:
|
||||
error = (
|
||||
error = EvalError(
|
||||
f"Invalid condition grammar: {e}",
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
except InvalidCondition as e:
|
||||
error = (
|
||||
error = EvalError(
|
||||
f"Incorrect value provided for condition: {e}",
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
except RequiredContextVariable as e:
|
||||
# TODO: be more specific and name the missing inputs, etc
|
||||
error = (f"Missing required inputs: {e}", HTTPStatus.BAD_REQUEST)
|
||||
error = EvalError(f"Missing required inputs: {e}", HTTPStatus.BAD_REQUEST)
|
||||
except InvalidContextVariableData as e:
|
||||
error = (
|
||||
error = EvalError(
|
||||
f"Invalid data provided for context variable: {e}",
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
except ContextVariableVerificationFailed as e:
|
||||
error = (
|
||||
error = EvalError(
|
||||
f"Context variable data could not be verified: {e}",
|
||||
HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
except NoConnectionToChain as e:
|
||||
error = (
|
||||
error = EvalError(
|
||||
f"Node does not have a connection to chain ID {e.chain}: {e}",
|
||||
HTTPStatus.NOT_IMPLEMENTED,
|
||||
)
|
||||
except ConditionEvaluationFailed as e:
|
||||
error = (f"Decryption condition not evaluated: {e}", HTTPStatus.BAD_REQUEST)
|
||||
error = EvalError(
|
||||
f"Decryption condition not evaluated: {e}", HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO: Unsure why we ended up here
|
||||
message = (
|
||||
f"Unexpected exception while evaluating "
|
||||
f"decryption condition ({e.__class__.__name__}): {e}"
|
||||
)
|
||||
error = (message, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
error = EvalError(message, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
log.warn(message)
|
||||
|
||||
if error:
|
||||
|
|
|
@ -1,15 +1,10 @@
|
|||
|
||||
|
||||
|
||||
from typing import Dict, Set
|
||||
|
||||
from eth_typing import ChecksumAddress
|
||||
from eth_utils import to_canonical_address
|
||||
from nucypher_core import Address, Conditions, MessageKit, RetrievalKit
|
||||
from nucypher_core.umbral import PublicKey, SecretKey, VerifiedCapsuleFrag
|
||||
|
||||
from nucypher_core import Address, MessageKit, RetrievalKit
|
||||
from nucypher.policy.conditions.lingo import ConditionLingo
|
||||
|
||||
|
||||
class PolicyMessageKit:
|
||||
|
||||
|
@ -54,7 +49,7 @@ class PolicyMessageKit:
|
|||
message_kit=self.message_kit)
|
||||
|
||||
@property
|
||||
def conditions(self) -> ConditionLingo:
|
||||
def conditions(self) -> Conditions:
|
||||
return self.message_kit.conditions
|
||||
|
||||
|
||||
|
@ -71,7 +66,7 @@ class RetrievalResult:
|
|||
def __init__(self, cfrags: Dict[ChecksumAddress, VerifiedCapsuleFrag]):
|
||||
self.cfrags = cfrags
|
||||
|
||||
def canonical_addresses(self) -> Set[bytes]:
|
||||
def canonical_addresses(self) -> Set[Address]:
|
||||
# TODO (#1995): propagate this to use canonical addresses everywhere
|
||||
return set([Address(to_canonical_address(address)) for address in self.cfrags])
|
||||
|
||||
|
|
Loading…
Reference in New Issue