mirror of https://github.com/nucypher/nucypher.git
Define types for serialization/deserialization of calls and variables.
Fix call schemas; some missed a post_load processing method for creating the correct type.pull/3500/head
parent
4e50fc1bef
commit
1f85461d28
|
@ -93,19 +93,11 @@ class AccessControlCondition(_Serializable, ABC):
|
|||
|
||||
|
||||
class ExecutionCall(_Serializable, ABC):
|
||||
CALL_TYPE = NotImplemented
|
||||
|
||||
class Schema(CamelCaseSchema):
|
||||
call_type = fields.Str(required=True)
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, *args, **kwargs) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ExecutionVariable(_Serializable, ABC):
|
||||
class Schema(CamelCaseSchema):
|
||||
var_name = fields.Str(required=True)
|
||||
call = NotImplemented
|
||||
|
||||
def __init__(self, var_name: str, call: ExecutionCall):
|
||||
self.var_name = var_name
|
||||
self.call = call
|
||||
|
|
|
@ -106,6 +106,8 @@ def _validate_chain(chain: int) -> None:
|
|||
|
||||
|
||||
class RPCCall(ExecutionCall):
|
||||
CALL_TYPE = "rpc"
|
||||
|
||||
LOG = logging.Logger(__name__)
|
||||
|
||||
ALLOWED_METHODS = {
|
||||
|
@ -115,6 +117,7 @@ class RPCCall(ExecutionCall):
|
|||
|
||||
class Schema(ExecutionCall.Schema):
|
||||
SKIP_VALUES = (None,)
|
||||
call_type = fields.Str(validate=validate.Equal("rpc"), required=True)
|
||||
chain = fields.Int(
|
||||
required=True, strict=True, validate=validate.OneOf(_CONDITION_CHAINS)
|
||||
)
|
||||
|
@ -129,12 +132,15 @@ class RPCCall(ExecutionCall):
|
|||
self,
|
||||
chain: int,
|
||||
method: str,
|
||||
call_type: str = "rpc",
|
||||
call_type: str = CALL_TYPE,
|
||||
parameters: Optional[List[Any]] = None,
|
||||
):
|
||||
# Validate input
|
||||
self._validate_call_type(call_type)
|
||||
# TODO: Additional validation (function is valid for ABI, RVT validity, standard contract name validity, etc.)
|
||||
if call_type != self.CALL_TYPE:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} must be instantiated with the '{self.CALL_TYPE}' type; '{call_type}' is invalid"
|
||||
)
|
||||
|
||||
_validate_chain(chain=chain)
|
||||
|
||||
self.call_type = call_type
|
||||
|
@ -143,10 +149,6 @@ class RPCCall(ExecutionCall):
|
|||
self.method = self._validate_method(method=method)
|
||||
self.parameters = parameters or None
|
||||
|
||||
def _validate_call_type(self, call_type):
|
||||
if call_type != "rpc":
|
||||
raise ValueError(f"Invalid execution call type: {call_type}")
|
||||
|
||||
def _validate_method(self, method):
|
||||
if not method:
|
||||
raise ValueError("Undefined method name")
|
||||
|
@ -343,7 +345,10 @@ class RPCCondition(AccessControlCondition):
|
|||
|
||||
|
||||
class ContractCall(RPCCall):
|
||||
CALL_TYPE = "contract"
|
||||
|
||||
class Schema(RPCCall.Schema):
|
||||
call_type = fields.Str(validate=validate.Equal("contract"), required=True)
|
||||
contract_address = fields.Str(required=True)
|
||||
standard_contract_type = fields.Str(required=False)
|
||||
function_abi = fields.Dict(required=False)
|
||||
|
@ -367,7 +372,7 @@ class ContractCall(RPCCall):
|
|||
self,
|
||||
method: str,
|
||||
contract_address: ChecksumAddress,
|
||||
call_type: str = "contract",
|
||||
call_type: str = CALL_TYPE,
|
||||
standard_contract_type: Optional[str] = None,
|
||||
function_abi: Optional[ABIFunction] = None,
|
||||
*args,
|
||||
|
@ -389,10 +394,6 @@ class ContractCall(RPCCall):
|
|||
super().__init__(method=method, call_type=call_type, *args, **kwargs)
|
||||
self.contract_function = self._get_unbound_contract_function()
|
||||
|
||||
def _validate_call_type(self, call_type):
|
||||
if call_type != "contract":
|
||||
raise ValueError(f"Invalid execution call type: {call_type}")
|
||||
|
||||
def _validate_method(self, method):
|
||||
return method
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from web3 import HTTPProvider
|
|||
|
||||
from nucypher.policy.conditions.base import (
|
||||
AccessControlCondition,
|
||||
ExecutionVariable,
|
||||
ExecutionCall,
|
||||
_Serializable,
|
||||
)
|
||||
from nucypher.policy.conditions.context import (
|
||||
|
@ -35,7 +35,7 @@ from nucypher.policy.conditions.exceptions import (
|
|||
InvalidConditionLingo,
|
||||
ReturnValueEvaluationError,
|
||||
)
|
||||
from nucypher.policy.conditions.types import ConditionDict, Lingo
|
||||
from nucypher.policy.conditions.types import ConditionDict, ExecutionCallDict, Lingo
|
||||
from nucypher.policy.conditions.utils import CamelCaseSchema
|
||||
|
||||
|
||||
|
@ -57,6 +57,7 @@ class _ConditionField(fields.Dict):
|
|||
instance = condition_class.from_dict(condition_data)
|
||||
return instance
|
||||
|
||||
|
||||
#
|
||||
# CONDITION = BASE_CONDITION | COMPOUND_CONDITION
|
||||
#
|
||||
|
@ -248,6 +249,56 @@ _COMPARATOR_FUNCTIONS = {
|
|||
# }
|
||||
|
||||
|
||||
class _ExecutionCallField(fields.Dict):
|
||||
"""Serializes/Deserializes Conditions to/from dictionaries"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _serialize(self, value, attr, obj, **kwargs):
|
||||
return value.to_dict()
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
execution_call_dict = value
|
||||
execution_call_class = self.resolve_execution_call_class(
|
||||
execution_call_dict=execution_call_dict
|
||||
)
|
||||
instance = execution_call_class.from_dict(execution_call_dict)
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def resolve_execution_call_class(cls, execution_call_dict: ExecutionCallDict):
|
||||
from nucypher.policy.conditions.evm import (
|
||||
ContractCall,
|
||||
RPCCall,
|
||||
)
|
||||
from nucypher.policy.conditions.time import TimeRPCCall
|
||||
|
||||
call_type = execution_call_dict.get("callType")
|
||||
|
||||
for execution_call_type in (
|
||||
RPCCall,
|
||||
TimeRPCCall,
|
||||
ContractCall,
|
||||
):
|
||||
if execution_call_type.CALL_TYPE == call_type:
|
||||
return execution_call_type
|
||||
|
||||
raise InvalidConditionLingo(
|
||||
f"Cannot resolve condition lingo with call type {call_type}"
|
||||
)
|
||||
|
||||
|
||||
class ExecutionVariable(_Serializable):
|
||||
class Schema(CamelCaseSchema):
|
||||
var_name = fields.Str(required=True)
|
||||
call = _ExecutionCallField(required=True)
|
||||
|
||||
def __init__(self, var_name: str, call: ExecutionCall):
|
||||
self.var_name = var_name
|
||||
self.call = call
|
||||
|
||||
|
||||
class SequentialAccessControlCondition(AccessControlCondition):
|
||||
CONDITION_TYPE = ConditionType.SEQUENTIAL.value
|
||||
MAX_NUM_VARIABLES = 5
|
||||
|
@ -272,7 +323,9 @@ class SequentialAccessControlCondition(AccessControlCondition):
|
|||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.SEQUENTIAL.value), required=True
|
||||
)
|
||||
variables = fields.List(fields.Str) # TODO placeholder; fixme
|
||||
variables = fields.List(
|
||||
fields.Nested(ExecutionVariable.Schema(), required=True)
|
||||
)
|
||||
name = fields.Str(required=False)
|
||||
condition = _ConditionField(required=True)
|
||||
|
||||
|
@ -532,7 +585,6 @@ class ConditionLingo(_Serializable):
|
|||
cls, condition: ConditionDict, version: int = None
|
||||
) -> Type[AccessControlCondition]:
|
||||
"""
|
||||
TODO: This feels like a jenky way to resolve data types from JSON blobs, but it works.
|
||||
Inspects a given bloc of JSON and attempts to resolve it's intended datatype within the
|
||||
conditions expression framework.
|
||||
"""
|
||||
|
|
|
@ -10,18 +10,24 @@ from nucypher.policy.conditions.lingo import ConditionType
|
|||
|
||||
|
||||
class TimeRPCCall(RPCCall):
|
||||
CALL_TYPE = "time"
|
||||
METHOD = "blocktime"
|
||||
|
||||
class Schema(RPCCall.Schema):
|
||||
call_type = fields.Str(validate=validate.Equal("time"), required=True)
|
||||
method = fields.Str(
|
||||
dump_default="blocktime", required=True, validate=Equal("blocktime")
|
||||
)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return TimeRPCCall(**data)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chain: int,
|
||||
method: str = METHOD,
|
||||
call_type: str = "time",
|
||||
call_type: str = CALL_TYPE,
|
||||
parameters: Optional[List[Any]] = None,
|
||||
):
|
||||
if parameters:
|
||||
|
@ -31,10 +37,6 @@ class TimeRPCCall(RPCCall):
|
|||
chain=chain, method=method, parameters=parameters, call_type=call_type
|
||||
)
|
||||
|
||||
def _validate_call_type(self, call_type):
|
||||
if call_type != "time":
|
||||
raise ValueError(f"Invalid execution call type: {call_type}")
|
||||
|
||||
def _validate_method(self, method):
|
||||
if method != self.METHOD:
|
||||
raise ValueError(
|
||||
|
|
|
@ -34,6 +34,37 @@ class ReturnValueTestDict(TypedDict):
|
|||
key: NotRequired[Union[str, int]]
|
||||
|
||||
|
||||
# Calls
|
||||
class BaseExecutionCallDict(TypedDict):
|
||||
callType: str
|
||||
|
||||
|
||||
class RPCCallDict(BaseExecutionCallDict):
|
||||
chain: int
|
||||
method: str
|
||||
parameters: NotRequired[List[Any]]
|
||||
|
||||
|
||||
class TimeRPCCallDict(RPCCallDict):
|
||||
pass
|
||||
|
||||
|
||||
class ContractCallDict(RPCCallDict):
|
||||
contractAddress: str
|
||||
standardContractType: NotRequired[str]
|
||||
functionAbi: NotRequired[ABIFunction]
|
||||
|
||||
|
||||
ExecutionCallDict = Union[RPCCallDict, TimeRPCCallDict, ContractCallDict]
|
||||
|
||||
|
||||
# Variable
|
||||
class ExecutionVariableDict(TypedDict):
|
||||
varName: str
|
||||
call: ExecutionCallDict
|
||||
|
||||
|
||||
# Conditions
|
||||
class _AccessControlCondition(TypedDict):
|
||||
name: NotRequired[str]
|
||||
|
||||
|
@ -63,20 +94,30 @@ class ContractConditionDict(RPCConditionDict):
|
|||
# "operands": List[AccessControlCondition | CompoundCondition]
|
||||
#
|
||||
#
|
||||
class CompoundConditionDict(TypedDict):
|
||||
class CompoundConditionDict(_AccessControlCondition):
|
||||
conditionType: str
|
||||
operator: Literal["and", "or"]
|
||||
operands: List["Lingo"]
|
||||
|
||||
|
||||
class SequentialConditionDict(_AccessControlCondition):
|
||||
variables = List[ExecutionVariableDict]
|
||||
condition: "Lingo"
|
||||
|
||||
|
||||
#
|
||||
# ConditionDict is a dictionary of:
|
||||
# - TimeCondition
|
||||
# - RPCCondition
|
||||
# - ContractCondition
|
||||
# - CompoundConditionDict
|
||||
# - SequentialConditionDict
|
||||
ConditionDict = Union[
|
||||
TimeConditionDict, RPCConditionDict, ContractConditionDict, CompoundConditionDict
|
||||
TimeConditionDict,
|
||||
RPCConditionDict,
|
||||
ContractConditionDict,
|
||||
CompoundConditionDict,
|
||||
SequentialConditionDict,
|
||||
]
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue