Define types for serialization/deserialization of calls and variables.

Fix call schemas; some missed a post_load processing method for creating the correct type.
pull/3500/head
derekpierre 2024-05-10 13:35:15 -04:00
parent 4e50fc1bef
commit 1f85461d28
No known key found for this signature in database
5 changed files with 121 additions and 33 deletions

View File

@ -93,19 +93,11 @@ class AccessControlCondition(_Serializable, ABC):
class ExecutionCall(_Serializable, ABC):
CALL_TYPE = NotImplemented
class Schema(CamelCaseSchema):
call_type = fields.Str(required=True)
@abstractmethod
def execute(self, *args, **kwargs) -> Any:
raise NotImplementedError
class ExecutionVariable(_Serializable, ABC):
class Schema(CamelCaseSchema):
var_name = fields.Str(required=True)
call = NotImplemented
def __init__(self, var_name: str, call: ExecutionCall):
self.var_name = var_name
self.call = call

View File

@ -106,6 +106,8 @@ def _validate_chain(chain: int) -> None:
class RPCCall(ExecutionCall):
CALL_TYPE = "rpc"
LOG = logging.Logger(__name__)
ALLOWED_METHODS = {
@ -115,6 +117,7 @@ class RPCCall(ExecutionCall):
class Schema(ExecutionCall.Schema):
SKIP_VALUES = (None,)
call_type = fields.Str(validate=validate.Equal("rpc"), required=True)
chain = fields.Int(
required=True, strict=True, validate=validate.OneOf(_CONDITION_CHAINS)
)
@ -129,12 +132,15 @@ class RPCCall(ExecutionCall):
self,
chain: int,
method: str,
call_type: str = "rpc",
call_type: str = CALL_TYPE,
parameters: Optional[List[Any]] = None,
):
# Validate input
self._validate_call_type(call_type)
# TODO: Additional validation (function is valid for ABI, RVT validity, standard contract name validity, etc.)
if call_type != self.CALL_TYPE:
raise ValueError(
f"{self.__class__.__name__} must be instantiated with the '{self.CALL_TYPE}' type; '{call_type}' is invalid"
)
_validate_chain(chain=chain)
self.call_type = call_type
@ -143,10 +149,6 @@ class RPCCall(ExecutionCall):
self.method = self._validate_method(method=method)
self.parameters = parameters or None
def _validate_call_type(self, call_type):
if call_type != "rpc":
raise ValueError(f"Invalid execution call type: {call_type}")
def _validate_method(self, method):
if not method:
raise ValueError("Undefined method name")
@ -343,7 +345,10 @@ class RPCCondition(AccessControlCondition):
class ContractCall(RPCCall):
CALL_TYPE = "contract"
class Schema(RPCCall.Schema):
call_type = fields.Str(validate=validate.Equal("contract"), required=True)
contract_address = fields.Str(required=True)
standard_contract_type = fields.Str(required=False)
function_abi = fields.Dict(required=False)
@ -367,7 +372,7 @@ class ContractCall(RPCCall):
self,
method: str,
contract_address: ChecksumAddress,
call_type: str = "contract",
call_type: str = CALL_TYPE,
standard_contract_type: Optional[str] = None,
function_abi: Optional[ABIFunction] = None,
*args,
@ -389,10 +394,6 @@ class ContractCall(RPCCall):
super().__init__(method=method, call_type=call_type, *args, **kwargs)
self.contract_function = self._get_unbound_contract_function()
def _validate_call_type(self, call_type):
if call_type != "contract":
raise ValueError(f"Invalid execution call type: {call_type}")
def _validate_method(self, method):
return method

View File

@ -23,7 +23,7 @@ from web3 import HTTPProvider
from nucypher.policy.conditions.base import (
AccessControlCondition,
ExecutionVariable,
ExecutionCall,
_Serializable,
)
from nucypher.policy.conditions.context import (
@ -35,7 +35,7 @@ from nucypher.policy.conditions.exceptions import (
InvalidConditionLingo,
ReturnValueEvaluationError,
)
from nucypher.policy.conditions.types import ConditionDict, Lingo
from nucypher.policy.conditions.types import ConditionDict, ExecutionCallDict, Lingo
from nucypher.policy.conditions.utils import CamelCaseSchema
@ -57,6 +57,7 @@ class _ConditionField(fields.Dict):
instance = condition_class.from_dict(condition_data)
return instance
#
# CONDITION = BASE_CONDITION | COMPOUND_CONDITION
#
@ -248,6 +249,56 @@ _COMPARATOR_FUNCTIONS = {
# }
class _ExecutionCallField(fields.Dict):
"""Serializes/Deserializes Conditions to/from dictionaries"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _serialize(self, value, attr, obj, **kwargs):
return value.to_dict()
def _deserialize(self, value, attr, data, **kwargs):
execution_call_dict = value
execution_call_class = self.resolve_execution_call_class(
execution_call_dict=execution_call_dict
)
instance = execution_call_class.from_dict(execution_call_dict)
return instance
@classmethod
def resolve_execution_call_class(cls, execution_call_dict: ExecutionCallDict):
from nucypher.policy.conditions.evm import (
ContractCall,
RPCCall,
)
from nucypher.policy.conditions.time import TimeRPCCall
call_type = execution_call_dict.get("callType")
for execution_call_type in (
RPCCall,
TimeRPCCall,
ContractCall,
):
if execution_call_type.CALL_TYPE == call_type:
return execution_call_type
raise InvalidConditionLingo(
f"Cannot resolve condition lingo with call type {call_type}"
)
class ExecutionVariable(_Serializable):
class Schema(CamelCaseSchema):
var_name = fields.Str(required=True)
call = _ExecutionCallField(required=True)
def __init__(self, var_name: str, call: ExecutionCall):
self.var_name = var_name
self.call = call
class SequentialAccessControlCondition(AccessControlCondition):
CONDITION_TYPE = ConditionType.SEQUENTIAL.value
MAX_NUM_VARIABLES = 5
@ -272,7 +323,9 @@ class SequentialAccessControlCondition(AccessControlCondition):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.SEQUENTIAL.value), required=True
)
variables = fields.List(fields.Str) # TODO placeholder; fixme
variables = fields.List(
fields.Nested(ExecutionVariable.Schema(), required=True)
)
name = fields.Str(required=False)
condition = _ConditionField(required=True)
@ -532,7 +585,6 @@ class ConditionLingo(_Serializable):
cls, condition: ConditionDict, version: int = None
) -> Type[AccessControlCondition]:
"""
TODO: This feels like a jenky way to resolve data types from JSON blobs, but it works.
Inspects a given bloc of JSON and attempts to resolve it's intended datatype within the
conditions expression framework.
"""

View File

@ -10,18 +10,24 @@ from nucypher.policy.conditions.lingo import ConditionType
class TimeRPCCall(RPCCall):
CALL_TYPE = "time"
METHOD = "blocktime"
class Schema(RPCCall.Schema):
call_type = fields.Str(validate=validate.Equal("time"), required=True)
method = fields.Str(
dump_default="blocktime", required=True, validate=Equal("blocktime")
)
@post_load
def make(self, data, **kwargs):
return TimeRPCCall(**data)
def __init__(
self,
chain: int,
method: str = METHOD,
call_type: str = "time",
call_type: str = CALL_TYPE,
parameters: Optional[List[Any]] = None,
):
if parameters:
@ -31,10 +37,6 @@ class TimeRPCCall(RPCCall):
chain=chain, method=method, parameters=parameters, call_type=call_type
)
def _validate_call_type(self, call_type):
if call_type != "time":
raise ValueError(f"Invalid execution call type: {call_type}")
def _validate_method(self, method):
if method != self.METHOD:
raise ValueError(

View File

@ -34,6 +34,37 @@ class ReturnValueTestDict(TypedDict):
key: NotRequired[Union[str, int]]
# Calls
class BaseExecutionCallDict(TypedDict):
callType: str
class RPCCallDict(BaseExecutionCallDict):
chain: int
method: str
parameters: NotRequired[List[Any]]
class TimeRPCCallDict(RPCCallDict):
pass
class ContractCallDict(RPCCallDict):
contractAddress: str
standardContractType: NotRequired[str]
functionAbi: NotRequired[ABIFunction]
ExecutionCallDict = Union[RPCCallDict, TimeRPCCallDict, ContractCallDict]
# Variable
class ExecutionVariableDict(TypedDict):
varName: str
call: ExecutionCallDict
# Conditions
class _AccessControlCondition(TypedDict):
name: NotRequired[str]
@ -63,20 +94,30 @@ class ContractConditionDict(RPCConditionDict):
# "operands": List[AccessControlCondition | CompoundCondition]
#
#
class CompoundConditionDict(TypedDict):
class CompoundConditionDict(_AccessControlCondition):
conditionType: str
operator: Literal["and", "or"]
operands: List["Lingo"]
class SequentialConditionDict(_AccessControlCondition):
variables = List[ExecutionVariableDict]
condition: "Lingo"
#
# ConditionDict is a dictionary of:
# - TimeCondition
# - RPCCondition
# - ContractCondition
# - CompoundConditionDict
# - SequentialConditionDict
ConditionDict = Union[
TimeConditionDict, RPCConditionDict, ContractConditionDict, CompoundConditionDict
TimeConditionDict,
RPCConditionDict,
ContractConditionDict,
CompoundConditionDict,
SequentialConditionDict,
]