Clean up inheritance and proper validation of conditions.

pull/3500/head
derekpierre 2024-09-16 11:02:40 -04:00
parent 9499039d94
commit 468aae93dc
No known key found for this signature in database
5 changed files with 102 additions and 138 deletions

View File

@ -1,7 +1,7 @@
import json
from abc import ABC, abstractmethod
from base64 import b64decode, b64encode
from typing import Any, Dict, Tuple
from typing import Any, Dict, Optional, Tuple
from marshmallow import Schema, ValidationError, fields
@ -52,14 +52,23 @@ class _Serializable:
class AccessControlCondition(_Serializable, ABC):
CONDITION_TYPE = NotImplemented
class Schema(CamelCaseSchema):
name = NotImplemented
def __init__(self):
SKIP_VALUES = (None,)
name = fields.Str(required=False)
condition_type = NotImplemented
def __init__(self, condition_type: str, name: Optional[str] = None):
super().__init__()
if condition_type != self.CONDITION_TYPE:
raise InvalidCondition(
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
)
self.condition_type = condition_type
self.name = name
# validate inputs using marshmallow schema
schema = self.Schema()
errors = schema.validate(self.to_dict())
@ -93,10 +102,9 @@ class AccessControlCondition(_Serializable, ABC):
class ExecutionCall(_Serializable, ABC):
CALL_TYPE = NotImplemented
class Schema(CamelCaseSchema):
call_type = fields.Str(required=True)
SKIP_VALUES = (None,)
pass
@abstractmethod
def execute(self, *args, **kwargs) -> Any:

View File

@ -19,7 +19,6 @@ from web3.types import ABIFunction
from nucypher.policy.conditions import STANDARD_ABI_CONTRACT_TYPES, STANDARD_ABIS
from nucypher.policy.conditions.base import (
AccessControlCondition,
ExecutionCall,
)
from nucypher.policy.conditions.context import (
@ -32,7 +31,11 @@ from nucypher.policy.conditions.exceptions import (
RequiredContextVariable,
RPCExecutionFailed,
)
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
from nucypher.policy.conditions.lingo import (
BaseAccessControlCondition,
ConditionType,
ReturnValueTest,
)
from nucypher.policy.conditions.utils import camel_case_to_snake
from nucypher.policy.conditions.validation import (
_align_comparator_value_with_abi,
@ -106,8 +109,6 @@ def _validate_chain(chain: int) -> None:
class RPCCall(ExecutionCall):
CALL_TYPE = "rpc"
LOG = logging.Logger(__name__)
ALLOWED_METHODS = {
@ -116,8 +117,6 @@ class RPCCall(ExecutionCall):
} # TODO other allowed methods (tDEC #64)
class Schema(ExecutionCall.Schema):
SKIP_VALUES = (None,)
call_type = fields.Str(validate=validate.Equal("rpc"), required=True)
chain = fields.Int(
required=True, strict=True, validate=validate.OneOf(_CONDITION_CHAINS)
)
@ -132,19 +131,11 @@ class RPCCall(ExecutionCall):
self,
chain: int,
method: str,
call_type: str = CALL_TYPE,
parameters: Optional[List[Any]] = None,
):
# Validate input
if call_type != self.CALL_TYPE:
raise ValueError(
f"{self.__class__.__name__} must be instantiated with the '{self.CALL_TYPE}' type; '{call_type}' is invalid"
)
_validate_chain(chain=chain)
self.call_type = call_type
self.chain = chain
self.method = self._validate_method(method=method)
self.parameters = parameters or None
@ -245,20 +236,13 @@ class RPCCall(ExecutionCall):
return rpc_result
class RPCCondition(AccessControlCondition):
class RPCCondition(BaseAccessControlCondition):
CONDITION_TYPE = ConditionType.RPC.value
class Schema(RPCCall.Schema):
name = fields.Str(required=False)
class Schema(BaseAccessControlCondition.Schema, RPCCall.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.RPC.value), required=True
)
return_value_test = fields.Nested(
ReturnValueTest.ReturnValueTestSchema(), required=True
)
class Meta:
exclude = ("call_type",) # don't serialize call_type
@post_load
def make(self, data, **kwargs):
@ -270,43 +254,28 @@ class RPCCondition(AccessControlCondition):
def __init__(
self,
return_value_test: ReturnValueTest,
condition_type: str = CONDITION_TYPE,
name: Optional[str] = None,
*args,
**kwargs,
):
# internal
if condition_type != self.CONDITION_TYPE:
raise InvalidCondition(
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
)
super().__init__(condition_type=condition_type, *args, **kwargs)
try:
self.rpc_call = self._create_rpc_call(*args, **kwargs)
except ValueError as e:
raise InvalidCondition(str(e))
self.name = name
self.condition_type = condition_type
self.return_value_test = return_value_test # output
self._validate_expected_return_type()
def _create_rpc_call(self, *args, **kwargs):
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
return RPCCall(*args, **kwargs)
@property
def method(self):
return self.rpc_call.method
return self.execution_call.method
@property
def chain(self):
return self.rpc_call.chain
return self.execution_call.chain
@property
def parameters(self):
return self.rpc_call.parameters
return self.execution_call.parameters
def _validate_expected_return_type(self):
expected_return_type = RPCCall.ALLOWED_METHODS[self.method]
@ -328,27 +297,20 @@ class RPCCondition(AccessControlCondition):
def verify(
self, providers: Dict[int, Set[HTTPProvider]], **context
) -> Tuple[bool, Any]:
"""
Verifies the onchain condition is met by performing a
read operation and evaluating the return value test.
"""
resolved_return_value_test = self.return_value_test.with_resolved_context(
**context
)
return_value_test = self._align_comparator_value_with_abi(
resolved_return_value_test
)
result = self.rpc_call.execute(providers=providers, **context)
result = self.execution_call.execute(providers=providers, **context)
eval_result = return_value_test.eval(result) # test
return eval_result, result
class ContractCall(RPCCall):
CALL_TYPE = "contract"
class Schema(RPCCall.Schema):
call_type = fields.Str(validate=validate.Equal("contract"), required=True)
contract_address = fields.Str(required=True)
standard_contract_type = fields.Str(required=False)
function_abi = fields.Dict(required=False)
@ -372,7 +334,6 @@ class ContractCall(RPCCall):
self,
method: str,
contract_address: ChecksumAddress,
call_type: str = CALL_TYPE,
standard_contract_type: Optional[str] = None,
function_abi: Optional[ABIFunction] = None,
*args,
@ -391,7 +352,7 @@ class ContractCall(RPCCall):
self.standard_contract_type = standard_contract_type
self.function_abi = function_abi
super().__init__(method=method, call_type=call_type, *args, **kwargs)
super().__init__(method=method, *args, **kwargs)
self.contract_function = self._get_unbound_contract_function()
def _validate_method(self, method):
@ -435,9 +396,6 @@ class ContractCondition(RPCCondition):
validate=validate.Equal(ConditionType.CONTRACT.value), required=True
)
class Meta:
exclude = ("call_type",) # don't serialize call_type
@post_load
def make(self, data, **kwargs):
return ContractCondition(**data)
@ -451,24 +409,24 @@ class ContractCondition(RPCCondition):
# call to super must be at the end for proper validation
super().__init__(condition_type=condition_type, *args, **kwargs)
def _create_rpc_call(self, *args, **kwargs) -> ContractCall:
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
return ContractCall(*args, **kwargs)
@property
def function_abi(self):
return self.rpc_call.function_abi
return self.execution_call.function_abi
@property
def standard_contract_type(self):
return self.rpc_call.standard_contract_type
return self.execution_call.standard_contract_type
@property
def contract_function(self):
return self.rpc_call.contract_function
return self.execution_call.contract_function
@property
def contract_address(self):
return self.rpc_call.contract_address
return self.execution_call.contract_address
def _validate_expected_return_type(self) -> None:
_validate_contract_function_expected_return_type(

View File

@ -2,6 +2,7 @@ import ast
import base64
import json
import operator as pyoperator
from abc import abstractmethod
from enum import Enum
from hashlib import md5
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
@ -23,6 +24,7 @@ from web3 import HTTPProvider
from nucypher.policy.conditions.base import (
AccessControlCondition,
ExecutionCall,
_Serializable,
)
from nucypher.policy.conditions.context import (
@ -126,12 +128,10 @@ class CompoundAccessControlCondition(AccessControlCondition):
# TODO nested operands
class Schema(CamelCaseSchema):
SKIP_VALUES = (None,)
class Schema(AccessControlCondition.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.COMPOUND.value), required=True
)
name = fields.Str(required=False)
operator = fields.Str(required=True)
operands = fields.List(_ConditionField, required=True)
@ -164,11 +164,6 @@ class CompoundAccessControlCondition(AccessControlCondition):
"operands": [CONDITION*]
}
"""
if condition_type != self.CONDITION_TYPE:
raise InvalidCondition(
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
)
self._validate_operator_and_operands(operator, operands, InvalidCondition)
self.operator = operator
@ -177,6 +172,8 @@ class CompoundAccessControlCondition(AccessControlCondition):
self.name = name
self.id = md5(bytes(self)).hexdigest()[:6]
super().__init__(condition_type=condition_type, name=name)
def __repr__(self):
return f"Operator={self.operator} (NumOperands={len(self.operands)}), id={self.id})"
@ -274,9 +271,7 @@ class SequentialAccessControlCondition(AccessControlCondition):
f"Maximum of {cls.MAX_NUM_CONDITION_VARIABLES} condition variables are allowed"
)
class Schema(CamelCaseSchema):
SKIP_VALUES = (None,)
name = fields.Str(required=False)
class Schema(AccessControlCondition.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.SEQUENTIAL.value), required=True
)
@ -305,17 +300,12 @@ class SequentialAccessControlCondition(AccessControlCondition):
condition_type: str = CONDITION_TYPE,
name: Optional[str] = None,
):
if condition_type != self.CONDITION_TYPE:
raise InvalidCondition(
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
)
self._validate_condition_variables(
condition_variables=condition_variables, exception_class=InvalidCondition
)
self.name = name
self.condition_variables = condition_variables
self.condition_type = condition_type
super().__init__(condition_type=condition_type, name=name)
def __repr__(self):
r = f"{self.__class__.__name__}(num_condition_variables={len(self.condition_variables)})"
@ -575,3 +565,46 @@ class ConditionLingo(_Serializable):
raise InvalidConditionLingo(
f"Version provided, {version}, is incompatible with current version {cls.VERSION}"
)
class BaseAccessControlCondition(AccessControlCondition):
class Schema(AccessControlCondition.Schema):
return_value_test = fields.Nested(
ReturnValueTest.ReturnValueTestSchema(), required=True
)
def __init__(
self,
condition_type: str,
return_value_test: ReturnValueTest,
name: Optional[str] = None,
*args,
**kwargs,
):
self.return_value_test = return_value_test
try:
self.execution_call = self._create_execution_call(*args, **kwargs)
except ValueError as e:
raise InvalidCondition(str(e))
super().__init__(condition_type=condition_type, name=name)
@abstractmethod
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
"""
Returns the execution call that the condition executes.
"""
raise NotImplementedError
def verify(self, *args, **kwargs) -> Tuple[bool, Any]:
"""
Verifies the condition is met by performing execution call and
evaluating the return value test.
"""
result = self.execution_call.execute(*args, **kwargs)
resolved_return_value_test = self.return_value_test.with_resolved_context(
**kwargs
)
eval_result = resolved_return_value_test.eval(result) # test
return eval_result, result

View File

@ -6,12 +6,15 @@ from jsonpath_ng.ext import parse
from marshmallow import fields, post_load, validate
from marshmallow.fields import Field, Url
from nucypher.policy.conditions.base import AccessControlCondition, ExecutionCall
from nucypher.policy.conditions.base import ExecutionCall
from nucypher.policy.conditions.exceptions import (
ConditionEvaluationFailed,
InvalidCondition,
)
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
from nucypher.policy.conditions.lingo import (
BaseAccessControlCondition,
ConditionType,
)
from nucypher.utilities.logging import Logger
@ -32,13 +35,9 @@ class JSONPathField(Field):
class JsonApiCall(ExecutionCall):
CALL_TYPE = "json-api"
TIMEOUT = 5 # seconds
class Schema(ExecutionCall.Schema):
SKIP_VALUES = (None,)
call_type = fields.Str(validate=validate.Equal("json-api"), required=True)
endpoint = Url(required=True, relative=False, schemes=["https"])
parameters = fields.Dict(required=False, allow_none=True)
query = JSONPathField(required=False, allow_none=True)
@ -50,12 +49,9 @@ class JsonApiCall(ExecutionCall):
def __init__(
self,
endpoint: str,
call_type: str = CALL_TYPE,
parameters: Optional[dict] = None,
query: Optional[str] = None,
):
self.call_type = call_type
self.endpoint = endpoint
self.parameters = parameters or {}
self.query = query
@ -135,7 +131,7 @@ class JsonApiCall(ExecutionCall):
return result
class JsonApiCondition(AccessControlCondition):
class JsonApiCondition(BaseAccessControlCondition):
"""
A JSON API condition is a condition that can be evaluated by reading from a JSON
HTTPS endpoint. The response must return an HTTP 200 with valid JSON in the response body.
@ -144,17 +140,10 @@ class JsonApiCondition(AccessControlCondition):
CONDITION_TYPE = ConditionType.JSONAPI.value
class Schema(JsonApiCall.Schema):
name = fields.Str(required=False)
class Schema(BaseAccessControlCondition.Schema, JsonApiCall.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.JSONAPI.value), required=True
)
return_value_test = fields.Nested(
ReturnValueTest.ReturnValueTestSchema(), required=True
)
class Meta:
exclude = ("call_type",) # don't serialize call_type
@post_load
def make(self, data, **kwargs):
@ -162,47 +151,30 @@ class JsonApiCondition(AccessControlCondition):
def __init__(
self,
return_value_test: ReturnValueTest,
condition_type: str = ConditionType.JSONAPI.value,
name: Optional[str] = None,
*args,
**kwargs,
):
if condition_type != self.CONDITION_TYPE:
raise InvalidCondition(
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
)
super().__init__(condition_type=condition_type, *args, **kwargs)
try:
self.json_api_call = self._create_json_api_call(*args, **kwargs)
except ValueError as e:
raise InvalidCondition(str(e))
self.name = name
self.condition_type = condition_type
self.return_value_test = return_value_test
super().__init__()
def _create_json_api_call(self, *args, **kwargs):
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
return JsonApiCall(*args, **kwargs)
@property
def endpoint(self):
return self.json_api_call.endpoint
return self.execution_call.endpoint
@property
def query(self):
return self.json_api_call.query
return self.execution_call.query
@property
def parameters(self):
return self.json_api_call.parameters
return self.execution_call.parameters
@property
def timeout(self):
return self.json_api_call.timeout
return self.execution_call.timeout
@staticmethod
def _process_result_for_eval(result: Any):
@ -226,7 +198,7 @@ class JsonApiCondition(AccessControlCondition):
and evaluating the return value test with the result. Parses the endpoint's JSON response using
JSONPath.
"""
result = self.json_api_call.execute(**context)
result = self.execution_call.execute(**context)
result_for_eval = self._process_result_for_eval(result)
resolved_return_value_test = self.return_value_test.with_resolved_context(

View File

@ -4,17 +4,16 @@ from marshmallow import fields, post_load, validate
from marshmallow.validate import Equal
from web3 import Web3
from nucypher.policy.conditions.base import ExecutionCall
from nucypher.policy.conditions.evm import RPCCall, RPCCondition
from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.lingo import ConditionType
class TimeRPCCall(RPCCall):
CALL_TYPE = "time"
METHOD = "blocktime"
class Schema(RPCCall.Schema):
call_type = fields.Str(validate=validate.Equal("time"), required=True)
method = fields.Str(
dump_default="blocktime", required=True, validate=Equal("blocktime")
)
@ -27,15 +26,12 @@ class TimeRPCCall(RPCCall):
self,
chain: int,
method: str = METHOD,
call_type: str = CALL_TYPE,
parameters: Optional[List[Any]] = None,
):
if parameters:
raise ValueError(f"{self.METHOD} does not take any parameters")
super().__init__(
chain=chain, method=method, parameters=parameters, call_type=call_type
)
super().__init__(chain=chain, method=method, parameters=parameters)
def _validate_method(self, method):
if method != self.METHOD:
@ -59,9 +55,6 @@ class TimeCondition(RPCCondition):
validate=validate.Equal(ConditionType.TIME.value), required=True
)
class Meta:
exclude = ("call_type",) # don't serialize call_type
@post_load
def make(self, data, **kwargs):
return TimeCondition(**data)
@ -85,7 +78,7 @@ class TimeCondition(RPCCondition):
**kwargs,
)
def _create_rpc_call(self, *args, **kwargs):
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
return TimeRPCCall(*args, **kwargs)
def _validate_expected_return_type(self):