mirror of https://github.com/nucypher/nucypher.git
Merge pull request #3556 from derekpierre/cleanup-validation
Clean up/Simplify condition validationpull/3563/head
commit
34842fd88c
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from base64 import b64decode, b64encode
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from marshmallow import Schema, ValidationError, fields
|
||||
|
||||
|
@ -9,7 +9,10 @@ from nucypher.policy.conditions.exceptions import (
|
|||
InvalidCondition,
|
||||
InvalidConditionLingo,
|
||||
)
|
||||
from nucypher.policy.conditions.utils import CamelCaseSchema
|
||||
from nucypher.policy.conditions.utils import (
|
||||
CamelCaseSchema,
|
||||
extract_single_error_message_from_schema_errors,
|
||||
)
|
||||
|
||||
|
||||
class _Serializable:
|
||||
|
@ -56,62 +59,49 @@ class AccessControlCondition(_Serializable, ABC):
|
|||
|
||||
class Schema(CamelCaseSchema):
|
||||
SKIP_VALUES = (None,)
|
||||
name = fields.Str(required=False)
|
||||
name = fields.Str(required=False, allow_none=True)
|
||||
condition_type = NotImplemented
|
||||
|
||||
def __init__(self, condition_type: str, name: Optional[str] = None):
|
||||
super().__init__()
|
||||
|
||||
if condition_type != self.CONDITION_TYPE:
|
||||
raise InvalidCondition(
|
||||
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
|
||||
)
|
||||
self.condition_type = condition_type
|
||||
self.name = name
|
||||
|
||||
# validate inputs using marshmallow schema
|
||||
self.validate(self.to_dict())
|
||||
self._validate()
|
||||
|
||||
@abstractmethod
|
||||
def verify(self, *args, **kwargs) -> Tuple[bool, Any]:
|
||||
"""Returns the boolean result of the evaluation and the returned value in a two-tuple."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def validate(cls, data: Dict) -> None:
|
||||
errors = cls.Schema().validate(data=data)
|
||||
def _validate(self, **kwargs):
|
||||
errors = self.Schema().validate(data=self.to_dict())
|
||||
if errors:
|
||||
raise InvalidCondition(f"Invalid {cls.__name__}: {errors}")
|
||||
error_message = extract_single_error_message_from_schema_errors(errors)
|
||||
raise InvalidCondition(
|
||||
f"Invalid {self.__class__.__name__}: {error_message}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data) -> "AccessControlCondition":
|
||||
try:
|
||||
return super().from_dict(data)
|
||||
except ValidationError as e:
|
||||
raise InvalidConditionLingo(f"Invalid condition grammar: {e}")
|
||||
raise InvalidConditionLingo(f"Invalid condition grammar: {e}") from e
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, data) -> "AccessControlCondition":
|
||||
try:
|
||||
return super().from_json(data)
|
||||
except ValidationError as e:
|
||||
raise InvalidConditionLingo(f"Invalid condition grammar: {e}")
|
||||
|
||||
|
||||
class ExecutionCall(ABC):
|
||||
@abstractmethod
|
||||
def execute(self, *args, **kwargs) -> Any:
|
||||
raise NotImplementedError
|
||||
raise InvalidConditionLingo(f"Invalid condition grammar: {e}") from e
|
||||
|
||||
|
||||
class MultiConditionAccessControl(AccessControlCondition):
|
||||
MAX_NUM_CONDITIONS = 5
|
||||
MAX_MULTI_CONDITION_NESTED_LEVEL = 2
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._validate_multi_condition_nesting(conditions=self.conditions)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def conditions(self) -> List[AccessControlCondition]:
|
||||
|
@ -121,11 +111,13 @@ class MultiConditionAccessControl(AccessControlCondition):
|
|||
def _validate_multi_condition_nesting(
|
||||
cls,
|
||||
conditions: List[AccessControlCondition],
|
||||
field_name: str,
|
||||
current_level: int = 1,
|
||||
):
|
||||
if len(conditions) > cls.MAX_NUM_CONDITIONS:
|
||||
raise InvalidCondition(
|
||||
f"Maximum of {cls.MAX_NUM_CONDITIONS} conditions are allowed"
|
||||
raise ValidationError(
|
||||
field_name=field_name,
|
||||
message=f"Maximum of {cls.MAX_NUM_CONDITIONS} conditions are allowed",
|
||||
)
|
||||
|
||||
for condition in conditions:
|
||||
|
@ -134,10 +126,31 @@ class MultiConditionAccessControl(AccessControlCondition):
|
|||
|
||||
level = current_level + 1
|
||||
if level > cls.MAX_MULTI_CONDITION_NESTED_LEVEL:
|
||||
raise InvalidCondition(
|
||||
f"Only {cls.MAX_MULTI_CONDITION_NESTED_LEVEL} nested levels of multi-conditions are allowed"
|
||||
raise ValidationError(
|
||||
field_name=field_name,
|
||||
message=f"Only {cls.MAX_MULTI_CONDITION_NESTED_LEVEL} nested levels of multi-conditions are allowed",
|
||||
)
|
||||
cls._validate_multi_condition_nesting(
|
||||
conditions=condition.conditions,
|
||||
field_name=field_name,
|
||||
current_level=level,
|
||||
)
|
||||
|
||||
|
||||
class ExecutionCall(_Serializable, ABC):
|
||||
class InvalidExecutionCall(ValueError):
|
||||
pass
|
||||
|
||||
class Schema(CamelCaseSchema):
|
||||
pass
|
||||
|
||||
def __init__(self):
|
||||
# validate call using marshmallow schema before creating
|
||||
errors = self.Schema().validate(data=self.to_dict())
|
||||
if errors:
|
||||
error_message = extract_single_error_message_from_schema_errors(errors)
|
||||
raise self.InvalidExecutionCall(f"{error_message}")
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, *args, **kwargs) -> Any:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -10,14 +10,22 @@ from typing import (
|
|||
|
||||
from eth_typing import ChecksumAddress
|
||||
from eth_utils import to_checksum_address
|
||||
from marshmallow import ValidationError, fields, post_load, validate, validates_schema
|
||||
from marshmallow import (
|
||||
ValidationError,
|
||||
fields,
|
||||
post_load,
|
||||
validate,
|
||||
validates,
|
||||
validates_schema,
|
||||
)
|
||||
from marshmallow.validate import OneOf
|
||||
from typing_extensions import override
|
||||
from web3 import HTTPProvider, Web3
|
||||
from web3.contract.contract import ContractFunction
|
||||
from web3.middleware import geth_poa_middleware
|
||||
from web3.providers import BaseProvider
|
||||
from web3.types import ABIFunction
|
||||
|
||||
from nucypher.policy.conditions import STANDARD_ABI_CONTRACT_TYPES, STANDARD_ABIS
|
||||
from nucypher.policy.conditions import STANDARD_ABI_CONTRACT_TYPES
|
||||
from nucypher.policy.conditions.base import (
|
||||
ExecutionCall,
|
||||
)
|
||||
|
@ -26,7 +34,6 @@ from nucypher.policy.conditions.context import (
|
|||
resolve_parameter_context_variables,
|
||||
)
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
InvalidCondition,
|
||||
NoConnectionToChain,
|
||||
RequiredContextVariable,
|
||||
RPCExecutionFailed,
|
||||
|
@ -38,11 +45,10 @@ from nucypher.policy.conditions.lingo import (
|
|||
)
|
||||
from nucypher.policy.conditions.utils import camel_case_to_snake
|
||||
from nucypher.policy.conditions.validation import (
|
||||
_align_comparator_value_with_abi,
|
||||
_get_abi_types,
|
||||
_validate_contract_call_abi,
|
||||
_validate_multiple_output_types,
|
||||
_validate_single_output_type,
|
||||
align_comparator_value_with_abi,
|
||||
get_unbound_contract_function,
|
||||
validate_contract_function_expected_return_type,
|
||||
validate_function_abi,
|
||||
)
|
||||
|
||||
# TODO: Move this to a more appropriate location,
|
||||
|
@ -61,53 +67,6 @@ _CONDITION_CHAINS = {
|
|||
}
|
||||
|
||||
|
||||
def _resolve_abi(
|
||||
w3: Web3,
|
||||
method: str,
|
||||
standard_contract_type: Optional[str] = None,
|
||||
function_abi: Optional[ABIFunction] = None,
|
||||
) -> ABIFunction:
|
||||
"""Resolves the contract an/or function ABI from a standard contract name"""
|
||||
|
||||
if not (function_abi or standard_contract_type):
|
||||
raise InvalidCondition(
|
||||
f"Ambiguous ABI - Supply either an ABI or a standard contract type ({STANDARD_ABI_CONTRACT_TYPES})."
|
||||
)
|
||||
|
||||
if standard_contract_type:
|
||||
try:
|
||||
# Lookup the standard ABI given it's ERC standard name (standard contract type)
|
||||
contract_abi = STANDARD_ABIS[standard_contract_type]
|
||||
except KeyError:
|
||||
raise InvalidCondition(
|
||||
f"Invalid standard contract type {standard_contract_type}; Must be one of {STANDARD_ABI_CONTRACT_TYPES}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Extract all function ABIs from the contract's ABI.
|
||||
# Will raise a ValueError if there is not exactly one match.
|
||||
function_abi = (
|
||||
w3.eth.contract(abi=contract_abi).get_function_by_name(method).abi
|
||||
)
|
||||
except ValueError as e:
|
||||
raise InvalidCondition(str(e))
|
||||
|
||||
return ABIFunction(function_abi)
|
||||
|
||||
|
||||
def _validate_chain(chain: int) -> None:
|
||||
if not isinstance(chain, int):
|
||||
raise ValueError(
|
||||
f'The "chain" field of a condition must be the '
|
||||
f'integer chain ID (got "{chain}").'
|
||||
)
|
||||
if chain not in _CONDITION_CHAINS:
|
||||
raise InvalidCondition(
|
||||
f"chain ID {chain} is not a permitted "
|
||||
f"blockchain for condition evaluation."
|
||||
)
|
||||
|
||||
|
||||
class RPCCall(ExecutionCall):
|
||||
LOG = logging.Logger(__name__)
|
||||
|
||||
|
@ -116,28 +75,47 @@ class RPCCall(ExecutionCall):
|
|||
"eth_getBalance": int,
|
||||
} # TODO other allowed methods (tDEC #64)
|
||||
|
||||
class Schema(ExecutionCall.Schema):
|
||||
chain = fields.Int(required=True, strict=True)
|
||||
method = fields.Str(
|
||||
required=True,
|
||||
error_messages={
|
||||
"required": "Undefined method name",
|
||||
"null": "Undefined method name",
|
||||
},
|
||||
)
|
||||
parameters = fields.List(
|
||||
fields.Field, attribute="parameters", required=False, allow_none=True
|
||||
)
|
||||
|
||||
@validates("chain")
|
||||
def validate_chain(self, value):
|
||||
if value not in _CONDITION_CHAINS:
|
||||
raise ValidationError(
|
||||
f"chain ID {value} is not a permitted blockchain for condition evaluation"
|
||||
)
|
||||
|
||||
@validates("method")
|
||||
def validate_method(self, value):
|
||||
if value not in RPCCall.ALLOWED_METHODS:
|
||||
raise ValidationError(
|
||||
f"'{value}' is not a permitted RPC endpoint for condition evaluation."
|
||||
)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return RPCCall(**data)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chain: int,
|
||||
method: str,
|
||||
parameters: Optional[List[Any]] = None,
|
||||
):
|
||||
# Validate input
|
||||
_validate_chain(chain=chain)
|
||||
|
||||
self.chain = chain
|
||||
self.method = self._validate_method(method=method)
|
||||
self.parameters = parameters or None
|
||||
|
||||
def _validate_method(self, method):
|
||||
if not method:
|
||||
raise ValueError("Undefined method name")
|
||||
|
||||
if method not in self.ALLOWED_METHODS:
|
||||
raise ValueError(
|
||||
f"'{method}' is not a permitted RPC endpoint for condition evaluation."
|
||||
)
|
||||
return method
|
||||
self.method = method
|
||||
self.parameters = parameters
|
||||
super().__init__()
|
||||
|
||||
def _get_web3_py_function(self, w3: Web3, rpc_method: str):
|
||||
web3_py_method = camel_case_to_snake(rpc_method)
|
||||
|
@ -226,17 +204,30 @@ class RPCCall(ExecutionCall):
|
|||
|
||||
|
||||
class RPCCondition(ExecutionCallAccessControlCondition):
|
||||
EXECUTION_CALL_TYPE = RPCCall
|
||||
CONDITION_TYPE = ConditionType.RPC.value
|
||||
|
||||
class Schema(ExecutionCallAccessControlCondition.Schema):
|
||||
class Schema(ExecutionCallAccessControlCondition.Schema, RPCCall.Schema):
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.RPC.value), required=True
|
||||
)
|
||||
chain = fields.Int(
|
||||
required=True, strict=True, validate=validate.OneOf(_CONDITION_CHAINS)
|
||||
)
|
||||
method = fields.Str(required=True)
|
||||
parameters = fields.List(fields.Field, attribute="parameters", required=False)
|
||||
|
||||
@validates_schema()
|
||||
def validate_expected_return_type(self, data, **kwargs):
|
||||
method = data.get("method")
|
||||
return_value_test = data.get("return_value_test")
|
||||
|
||||
expected_return_type = RPCCall.ALLOWED_METHODS[method]
|
||||
comparator_value = return_value_test.value
|
||||
if is_context_variable(comparator_value):
|
||||
return
|
||||
|
||||
if not isinstance(return_value_test.value, expected_return_type):
|
||||
raise ValidationError(
|
||||
field_name="return_value_test",
|
||||
message=f"Return value comparison for '{method}' call output "
|
||||
f"should be '{expected_return_type}' and not '{type(comparator_value)}'.",
|
||||
)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
|
@ -248,16 +239,25 @@ class RPCCondition(ExecutionCallAccessControlCondition):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
condition_type: str = CONDITION_TYPE,
|
||||
chain: int,
|
||||
method: str,
|
||||
return_value_test: ReturnValueTest,
|
||||
condition_type: str = ConditionType.RPC.value,
|
||||
name: Optional[str] = None,
|
||||
parameters: Optional[List[Any]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(condition_type=condition_type, *args, **kwargs)
|
||||
|
||||
self._validate_expected_return_type()
|
||||
|
||||
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
|
||||
return RPCCall(*args, **kwargs)
|
||||
super().__init__(
|
||||
chain=chain,
|
||||
method=method,
|
||||
return_value_test=return_value_test,
|
||||
condition_type=condition_type,
|
||||
name=name,
|
||||
parameters=parameters,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def method(self):
|
||||
|
@ -271,18 +271,6 @@ class RPCCondition(ExecutionCallAccessControlCondition):
|
|||
def parameters(self):
|
||||
return self.execution_call.parameters
|
||||
|
||||
def _validate_expected_return_type(self):
|
||||
expected_return_type = RPCCall.ALLOWED_METHODS[self.method]
|
||||
comparator_value = self.return_value_test.value
|
||||
if is_context_variable(comparator_value):
|
||||
return
|
||||
|
||||
if not isinstance(self.return_value_test.value, expected_return_type):
|
||||
raise InvalidCondition(
|
||||
f"Return value comparison for '{self.method}' call output "
|
||||
f"should be '{expected_return_type}' and not '{type(comparator_value)}'."
|
||||
)
|
||||
|
||||
def _align_comparator_value_with_abi(
|
||||
self, return_value_test: ReturnValueTest
|
||||
) -> ReturnValueTest:
|
||||
|
@ -297,6 +285,7 @@ class RPCCondition(ExecutionCallAccessControlCondition):
|
|||
return_value_test = self._align_comparator_value_with_abi(
|
||||
resolved_return_value_test
|
||||
)
|
||||
|
||||
result = self.execution_call.execute(providers=providers, **context)
|
||||
|
||||
eval_result = return_value_test.eval(result) # test
|
||||
|
@ -304,6 +293,79 @@ class RPCCondition(ExecutionCallAccessControlCondition):
|
|||
|
||||
|
||||
class ContractCall(RPCCall):
|
||||
class Schema(RPCCall.Schema):
|
||||
contract_address = fields.Str(required=True)
|
||||
standard_contract_type = fields.Str(
|
||||
required=False,
|
||||
validate=OneOf(
|
||||
STANDARD_ABI_CONTRACT_TYPES,
|
||||
error="Invalid standard contract type: {input}",
|
||||
),
|
||||
allow_none=True,
|
||||
)
|
||||
function_abi = fields.Dict(required=False, allow_none=True)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return ContractCall(**data)
|
||||
|
||||
@validates("contract_address")
|
||||
def validate_contract_address(self, value):
|
||||
try:
|
||||
to_checksum_address(value)
|
||||
except ValueError:
|
||||
raise ValidationError(f"Invalid checksum address: '{value}'")
|
||||
|
||||
@override
|
||||
@validates("method")
|
||||
def validate_method(self, value):
|
||||
return
|
||||
|
||||
@validates("function_abi")
|
||||
def validate_abi(self, value):
|
||||
# needs to be done before schema validation
|
||||
if value:
|
||||
try:
|
||||
validate_function_abi(value)
|
||||
except ValueError as e:
|
||||
raise ValidationError(
|
||||
field_name="function_abi", message=str(e)
|
||||
) from e
|
||||
|
||||
@validates_schema
|
||||
def validate_standard_contract_type_or_function_abi(self, data, **kwargs):
|
||||
method = data.get("method")
|
||||
standard_contract_type = data.get("standard_contract_type")
|
||||
function_abi = data.get("function_abi")
|
||||
|
||||
# validate xor of standard contract type and function abi
|
||||
if not (bool(standard_contract_type) ^ bool(function_abi)):
|
||||
raise ValidationError(
|
||||
field_name="standard_contract_type",
|
||||
message=f"Provide a standard contract type or function ABI; got ({standard_contract_type}, {function_abi}).",
|
||||
)
|
||||
|
||||
# validate function abi with method name (not available for field validation)
|
||||
if function_abi:
|
||||
try:
|
||||
validate_function_abi(function_abi, method_name=method)
|
||||
except ValueError as e:
|
||||
raise ValidationError(
|
||||
field_name="function_abi", message=str(e)
|
||||
) from e
|
||||
|
||||
# validate contract
|
||||
contract_address = to_checksum_address(data.get("contract_address"))
|
||||
try:
|
||||
get_unbound_contract_function(
|
||||
contract_address=contract_address,
|
||||
method=method,
|
||||
standard_contract_type=standard_contract_type,
|
||||
function_abi=function_abi,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise ValidationError(str(e)) from e
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: str,
|
||||
|
@ -313,13 +375,6 @@ class ContractCall(RPCCall):
|
|||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not method:
|
||||
raise ValueError("Undefined method name")
|
||||
|
||||
_validate_contract_call_abi(
|
||||
standard_contract_type, function_abi, method_name=method
|
||||
)
|
||||
|
||||
# preprocessing
|
||||
contract_address = to_checksum_address(contract_address)
|
||||
self.contract_address = contract_address
|
||||
|
@ -327,30 +382,14 @@ class ContractCall(RPCCall):
|
|||
self.function_abi = function_abi
|
||||
|
||||
super().__init__(method=method, *args, **kwargs)
|
||||
self.contract_function = self._get_unbound_contract_function()
|
||||
|
||||
def _validate_method(self, method):
|
||||
return method
|
||||
|
||||
def _get_unbound_contract_function(self) -> ContractFunction:
|
||||
"""Gets an unbound contract function to evaluate for this condition"""
|
||||
w3 = Web3()
|
||||
function_abi = _resolve_abi(
|
||||
w3=w3,
|
||||
standard_contract_type=self.standard_contract_type,
|
||||
# contract function already validated - so should not raise an exception
|
||||
self.contract_function = get_unbound_contract_function(
|
||||
contract_address=self.contract_address,
|
||||
method=self.method,
|
||||
standard_contract_type=self.standard_contract_type,
|
||||
function_abi=self.function_abi,
|
||||
)
|
||||
try:
|
||||
contract = w3.eth.contract(
|
||||
address=self.contract_address, abi=[function_abi]
|
||||
)
|
||||
contract_function = getattr(contract.functions, self.method)
|
||||
return contract_function
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Unable to find contract function, '{self.method}', for condition: {e}"
|
||||
)
|
||||
|
||||
def _execute(self, w3: Web3, resolved_parameters: List[Any]) -> Any:
|
||||
"""Execute onchain read and return result."""
|
||||
|
@ -363,42 +402,63 @@ class ContractCall(RPCCall):
|
|||
|
||||
|
||||
class ContractCondition(RPCCondition):
|
||||
EXECUTION_CALL_TYPE = ContractCall
|
||||
CONDITION_TYPE = ConditionType.CONTRACT.value
|
||||
|
||||
class Schema(RPCCondition.Schema):
|
||||
class Schema(RPCCondition.Schema, ContractCall.Schema):
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.CONTRACT.value), required=True
|
||||
)
|
||||
contract_address = fields.Str(required=True)
|
||||
standard_contract_type = fields.Str(required=False)
|
||||
function_abi = fields.Dict(required=False)
|
||||
|
||||
@validates_schema()
|
||||
def validate_expected_return_type(self, data, **kwargs):
|
||||
# validate that contract function is correct
|
||||
try:
|
||||
contract_function = get_unbound_contract_function(
|
||||
contract_address=data.get("contract_address"),
|
||||
method=data.get("method"),
|
||||
standard_contract_type=data.get("standard_contract_type"),
|
||||
function_abi=data.get("function_abi"),
|
||||
)
|
||||
except ValueError as e:
|
||||
raise ValidationError(str(e)) from e
|
||||
|
||||
# validate return type based on contract function
|
||||
return_value_test = data.get("return_value_test")
|
||||
try:
|
||||
validate_contract_function_expected_return_type(
|
||||
contract_function=contract_function,
|
||||
return_value_test=return_value_test,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise ValidationError(
|
||||
field_name="return_value_test",
|
||||
message=str(e),
|
||||
) from e
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return ContractCondition(**data)
|
||||
|
||||
@validates_schema
|
||||
def check_standard_contract_type_or_function_abi(self, data, **kwargs):
|
||||
standard_contract_type = data.get("standard_contract_type")
|
||||
function_abi = data.get("function_abi")
|
||||
try:
|
||||
_validate_contract_call_abi(
|
||||
standard_contract_type, function_abi, method_name=data.get("method")
|
||||
)
|
||||
except ValueError as e:
|
||||
raise ValidationError(str(e))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
condition_type: str = CONDITION_TYPE,
|
||||
method: str,
|
||||
contract_address: ChecksumAddress,
|
||||
condition_type: str = ConditionType.CONTRACT.value,
|
||||
standard_contract_type: Optional[str] = None,
|
||||
function_abi: Optional[ABIFunction] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
# call to super must be at the end for proper validation
|
||||
super().__init__(condition_type=condition_type, *args, **kwargs)
|
||||
|
||||
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
|
||||
return ContractCall(*args, **kwargs)
|
||||
super().__init__(
|
||||
method=method,
|
||||
condition_type=condition_type,
|
||||
contract_address=contract_address,
|
||||
standard_contract_type=standard_contract_type,
|
||||
function_abi=function_abi,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def function_abi(self):
|
||||
|
@ -416,12 +476,6 @@ class ContractCondition(RPCCondition):
|
|||
def contract_address(self):
|
||||
return self.execution_call.contract_address
|
||||
|
||||
def _validate_expected_return_type(self) -> None:
|
||||
_validate_contract_function_expected_return_type(
|
||||
contract_function=self.contract_function,
|
||||
return_value_test=self.return_value_test,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
r = (
|
||||
f"{self.__class__.__name__}(function={self.method}, "
|
||||
|
@ -433,29 +487,7 @@ class ContractCondition(RPCCondition):
|
|||
def _align_comparator_value_with_abi(
|
||||
self, return_value_test: ReturnValueTest
|
||||
) -> ReturnValueTest:
|
||||
return _align_comparator_value_with_abi(
|
||||
return align_comparator_value_with_abi(
|
||||
abi=self.contract_function.contract_abi[0],
|
||||
return_value_test=return_value_test,
|
||||
)
|
||||
|
||||
|
||||
def _validate_contract_function_expected_return_type(
|
||||
contract_function: ContractFunction, return_value_test: ReturnValueTest
|
||||
) -> None:
|
||||
output_abi_types = _get_abi_types(contract_function.contract_abi[0])
|
||||
comparator_value = return_value_test.value
|
||||
comparator_index = return_value_test.index
|
||||
index_string = f"@index={comparator_index}" if comparator_index is not None else ""
|
||||
failure_message = (
|
||||
f"Invalid return value comparison type '{type(comparator_value)}' for "
|
||||
f"'{contract_function.fn_name}'{index_string} based on ABI types {output_abi_types}"
|
||||
)
|
||||
|
||||
if len(output_abi_types) == 1:
|
||||
_validate_single_output_type(
|
||||
output_abi_types[0], comparator_value, comparator_index, failure_message
|
||||
)
|
||||
else:
|
||||
_validate_multiple_output_types(
|
||||
output_abi_types, comparator_value, comparator_index, failure_message
|
||||
)
|
||||
|
|
|
@ -2,10 +2,9 @@ import ast
|
|||
import base64
|
||||
import json
|
||||
import operator as pyoperator
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from hashlib import md5
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Type
|
||||
|
||||
from hexbytes import HexBytes
|
||||
from marshmallow import (
|
||||
|
@ -103,25 +102,29 @@ class CompoundAccessControlCondition(MultiConditionAccessControl):
|
|||
def _validate_operator_and_operands(
|
||||
cls,
|
||||
operator: str,
|
||||
operands: List[Union[Dict, AccessControlCondition]],
|
||||
exception_class: Union[Type[ValidationError], Type[InvalidCondition]],
|
||||
operands: List[AccessControlCondition],
|
||||
):
|
||||
if operator not in cls.OPERATORS:
|
||||
raise exception_class(f"{operator} is not a valid operator")
|
||||
raise ValidationError(
|
||||
field_name="operator", message=f"{operator} is not a valid operator"
|
||||
)
|
||||
|
||||
num_operands = len(operands)
|
||||
if operator == cls.NOT_OPERATOR:
|
||||
if num_operands != 1:
|
||||
raise exception_class(
|
||||
f"Only 1 operand permitted for '{operator}' compound condition"
|
||||
raise ValidationError(
|
||||
field_name="operands",
|
||||
message=f"Only 1 operand permitted for '{operator}' compound condition",
|
||||
)
|
||||
elif num_operands < 2:
|
||||
raise exception_class(
|
||||
f"Minimum of 2 operand needed for '{operator}' compound condition"
|
||||
raise ValidationError(
|
||||
field_name="operands",
|
||||
message=f"Minimum of 2 operand needed for '{operator}' compound condition",
|
||||
)
|
||||
elif num_operands > cls.MAX_NUM_CONDITIONS:
|
||||
raise exception_class(
|
||||
f"Maximum of {cls.MAX_NUM_CONDITIONS} operands allowed for '{operator}' compound condition"
|
||||
raise ValidationError(
|
||||
field_name="operands",
|
||||
message="Maximum of {cls.MAX_NUM_CONDITIONS} operands allowed for '{operator}' compound condition",
|
||||
)
|
||||
|
||||
|
||||
|
@ -141,7 +144,10 @@ class CompoundAccessControlCondition(MultiConditionAccessControl):
|
|||
operator = data["operator"]
|
||||
operands = data["operands"]
|
||||
CompoundAccessControlCondition._validate_operator_and_operands(
|
||||
operator, operands, ValidationError
|
||||
operator, operands
|
||||
)
|
||||
CompoundAccessControlCondition._validate_multi_condition_nesting(
|
||||
conditions=operands, field_name="operands"
|
||||
)
|
||||
|
||||
@post_load
|
||||
|
@ -161,15 +167,15 @@ class CompoundAccessControlCondition(MultiConditionAccessControl):
|
|||
"operands": [CONDITION*]
|
||||
}
|
||||
"""
|
||||
self._validate_operator_and_operands(operator, operands, InvalidCondition)
|
||||
|
||||
self.operator = operator
|
||||
self.operands = operands
|
||||
self.condition_type = condition_type
|
||||
self.name = name
|
||||
self.id = md5(bytes(self)).hexdigest()[:6]
|
||||
|
||||
super().__init__(condition_type=condition_type, name=name)
|
||||
super().__init__(
|
||||
condition_type=condition_type,
|
||||
name=name,
|
||||
)
|
||||
|
||||
self.id = md5(bytes(self)).hexdigest()[:6]
|
||||
|
||||
def __repr__(self):
|
||||
return f"Operator={self.operator} (NumOperands={len(self.operands)}), id={self.id})"
|
||||
|
@ -265,18 +271,30 @@ class SequentialAccessControlCondition(MultiConditionAccessControl):
|
|||
@classmethod
|
||||
def _validate_condition_variables(
|
||||
cls,
|
||||
condition_variables: List[Union[Dict, ConditionVariable]],
|
||||
exception_class: Union[Type[ValidationError], Type[InvalidCondition]],
|
||||
condition_variables: List[ConditionVariable],
|
||||
):
|
||||
num_condition_variables = len(condition_variables)
|
||||
if num_condition_variables < 2:
|
||||
raise exception_class("At least two conditions must be specified")
|
||||
|
||||
if num_condition_variables > cls.MAX_NUM_CONDITIONS:
|
||||
raise exception_class(
|
||||
f"Maximum of {cls.MAX_NUM_CONDITIONS} conditions are allowed"
|
||||
if not condition_variables or len(condition_variables) < 2:
|
||||
raise ValidationError(
|
||||
field_name="condition_variables",
|
||||
message="At least two conditions must be specified",
|
||||
)
|
||||
|
||||
if len(condition_variables) > cls.MAX_NUM_CONDITIONS:
|
||||
raise ValidationError(
|
||||
field_name="condition_variables",
|
||||
message=f"Maximum of {cls.MAX_NUM_CONDITIONS} conditions are allowed",
|
||||
)
|
||||
|
||||
# check for duplicate var names
|
||||
var_names = set()
|
||||
for condition_variable in condition_variables:
|
||||
if condition_variable.var_name in var_names:
|
||||
raise ValidationError(
|
||||
field_name="condition_variables",
|
||||
message=f"Duplicate variable names are not allowed - {condition_variable.var_name}",
|
||||
)
|
||||
var_names.add(condition_variable.var_name)
|
||||
|
||||
class Schema(AccessControlCondition.Schema):
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.SEQUENTIAL.value), required=True
|
||||
|
@ -289,11 +307,12 @@ class SequentialAccessControlCondition(MultiConditionAccessControl):
|
|||
class Meta:
|
||||
ordered = True
|
||||
|
||||
@validates_schema
|
||||
def validate_condition_variables(self, data, **kwargs):
|
||||
condition_variables = data["condition_variables"]
|
||||
SequentialAccessControlCondition._validate_condition_variables(
|
||||
condition_variables, ValidationError
|
||||
@validates("condition_variables")
|
||||
def validate_condition_variables(self, value):
|
||||
SequentialAccessControlCondition._validate_condition_variables(value)
|
||||
conditions = [cv.condition for cv in value]
|
||||
SequentialAccessControlCondition._validate_multi_condition_nesting(
|
||||
conditions=conditions, field_name="condition_variables"
|
||||
)
|
||||
|
||||
@post_load
|
||||
|
@ -306,11 +325,11 @@ class SequentialAccessControlCondition(MultiConditionAccessControl):
|
|||
condition_type: str = CONDITION_TYPE,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
self._validate_condition_variables(
|
||||
condition_variables=condition_variables, exception_class=InvalidCondition
|
||||
)
|
||||
self.condition_variables = condition_variables
|
||||
super().__init__(condition_type=condition_type, name=name)
|
||||
super().__init__(
|
||||
condition_type=condition_type,
|
||||
name=name,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
r = f"{self.__class__.__name__}(num_condition_variables={len(self.condition_variables)})"
|
||||
|
@ -358,7 +377,9 @@ class ReturnValueTest:
|
|||
value = fields.Raw(
|
||||
allow_none=False, required=True
|
||||
) # any valid type (excludes None)
|
||||
index = fields.Int(strict=True, required=False, validate=Range(min=0))
|
||||
index = fields.Int(
|
||||
strict=True, required=False, validate=Range(min=0), allow_none=True
|
||||
)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
|
@ -574,6 +595,8 @@ class ExecutionCallAccessControlCondition(AccessControlCondition):
|
|||
Conditions that utilize underlying ExecutionCall objects.
|
||||
"""
|
||||
|
||||
EXECUTION_CALL_TYPE = NotImplemented
|
||||
|
||||
class Schema(AccessControlCondition.Schema):
|
||||
return_value_test = fields.Nested(
|
||||
ReturnValueTest.ReturnValueTestSchema(), required=True
|
||||
|
@ -588,19 +611,14 @@ class ExecutionCallAccessControlCondition(AccessControlCondition):
|
|||
**kwargs,
|
||||
):
|
||||
self.return_value_test = return_value_test
|
||||
|
||||
try:
|
||||
self.execution_call = self._create_execution_call(*args, **kwargs)
|
||||
except ValueError as e:
|
||||
raise InvalidCondition(str(e))
|
||||
self.execution_call = self.EXECUTION_CALL_TYPE(*args, **kwargs)
|
||||
except ExecutionCall.InvalidExecutionCall as e:
|
||||
raise InvalidCondition(str(e)) from e
|
||||
|
||||
super().__init__(condition_type=condition_type, name=name)
|
||||
|
||||
@abstractmethod
|
||||
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
|
||||
"""
|
||||
Returns the execution call that the condition executes.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def verify(self, *args, **kwargs) -> Tuple[bool, Any]:
|
||||
"""
|
||||
|
|
|
@ -14,6 +14,7 @@ from nucypher.policy.conditions.exceptions import (
|
|||
from nucypher.policy.conditions.lingo import (
|
||||
ConditionType,
|
||||
ExecutionCallAccessControlCondition,
|
||||
ReturnValueTest,
|
||||
)
|
||||
from nucypher.utilities.logging import Logger
|
||||
|
||||
|
@ -37,6 +38,15 @@ class JSONPathField(Field):
|
|||
class JsonApiCall(ExecutionCall):
|
||||
TIMEOUT = 5 # seconds
|
||||
|
||||
class Schema(ExecutionCall.Schema):
|
||||
endpoint = Url(required=True, relative=False, schemes=["https"])
|
||||
parameters = fields.Dict(required=False, allow_none=True)
|
||||
query = JSONPathField(required=False, allow_none=True)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return JsonApiCall(**data)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
|
@ -50,6 +60,8 @@ class JsonApiCall(ExecutionCall):
|
|||
self.timeout = self.TIMEOUT
|
||||
self.logger = Logger(__name__)
|
||||
|
||||
super().__init__()
|
||||
|
||||
def execute(self, *args, **kwargs) -> Any:
|
||||
response = self._fetch()
|
||||
data = self._deserialize_response(response)
|
||||
|
@ -129,15 +141,13 @@ class JsonApiCondition(ExecutionCallAccessControlCondition):
|
|||
The response will be deserialized as JSON and parsed using jsonpath.
|
||||
"""
|
||||
|
||||
EXECUTION_CALL_TYPE = JsonApiCall
|
||||
CONDITION_TYPE = ConditionType.JSONAPI.value
|
||||
|
||||
class Schema(ExecutionCallAccessControlCondition.Schema):
|
||||
class Schema(ExecutionCallAccessControlCondition.Schema, JsonApiCall.Schema):
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.JSONAPI.value), required=True
|
||||
)
|
||||
endpoint = Url(required=True, relative=False, schemes=["https"])
|
||||
parameters = fields.Dict(required=False, allow_none=True)
|
||||
query = JSONPathField(required=False, allow_none=True)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
|
@ -145,14 +155,21 @@ class JsonApiCondition(ExecutionCallAccessControlCondition):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
return_value_test: ReturnValueTest,
|
||||
query: Optional[str] = None,
|
||||
parameters: Optional[dict] = None,
|
||||
condition_type: str = ConditionType.JSONAPI.value,
|
||||
*args,
|
||||
**kwargs,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
super().__init__(condition_type=condition_type, *args, **kwargs)
|
||||
|
||||
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
|
||||
return JsonApiCall(*args, **kwargs)
|
||||
super().__init__(
|
||||
endpoint=endpoint,
|
||||
return_value_test=return_value_test,
|
||||
query=query,
|
||||
parameters=parameters,
|
||||
condition_type=condition_type,
|
||||
name=name,
|
||||
)
|
||||
|
||||
@property
|
||||
def endpoint(self):
|
||||
|
|
|
@ -1,36 +1,51 @@
|
|||
from typing import Any, List, Optional
|
||||
|
||||
from marshmallow import fields, post_load, validate
|
||||
from marshmallow.validate import Equal
|
||||
from marshmallow import (
|
||||
ValidationError,
|
||||
fields,
|
||||
post_load,
|
||||
validate,
|
||||
validates,
|
||||
validates_schema,
|
||||
)
|
||||
from typing_extensions import override
|
||||
from web3 import Web3
|
||||
|
||||
from nucypher.policy.conditions.base import ExecutionCall
|
||||
from nucypher.policy.conditions.evm import RPCCall, RPCCondition
|
||||
from nucypher.policy.conditions.exceptions import InvalidCondition
|
||||
from nucypher.policy.conditions.lingo import ConditionType
|
||||
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
|
||||
|
||||
|
||||
class TimeRPCCall(RPCCall):
|
||||
METHOD = "blocktime"
|
||||
|
||||
class Schema(RPCCall.Schema):
|
||||
method = fields.Str(dump_default="blocktime", required=True)
|
||||
|
||||
@override
|
||||
@validates("method")
|
||||
def validate_method(self, value):
|
||||
if value != TimeRPCCall.METHOD:
|
||||
raise ValidationError(f"method name must be {TimeRPCCall.METHOD}.")
|
||||
|
||||
@validates("parameters")
|
||||
def validate_no_parameters(self, value):
|
||||
if value:
|
||||
raise ValidationError(
|
||||
f"{TimeRPCCall.METHOD}' does not take any parameters"
|
||||
)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return TimeRPCCall(**data)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chain: int,
|
||||
method: str = METHOD,
|
||||
parameters: Optional[List[Any]] = None,
|
||||
):
|
||||
if parameters:
|
||||
raise ValueError(f"{self.METHOD} does not take any parameters")
|
||||
|
||||
super().__init__(chain=chain, method=method, parameters=parameters)
|
||||
|
||||
def _validate_method(self, method):
|
||||
if method != self.METHOD:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} must be instantiated with the {self.METHOD} method."
|
||||
)
|
||||
return method
|
||||
|
||||
def _execute(self, w3: Web3, resolved_parameters: List[Any]) -> Any:
|
||||
"""Execute onchain read and return result."""
|
||||
# TODO may need to rethink as part of #3051 (multicall work).
|
||||
|
@ -39,15 +54,23 @@ class TimeRPCCall(RPCCall):
|
|||
|
||||
|
||||
class TimeCondition(RPCCondition):
|
||||
EXECUTION_CALL_TYPE = TimeRPCCall
|
||||
CONDITION_TYPE = ConditionType.TIME.value
|
||||
|
||||
class Schema(RPCCondition.Schema):
|
||||
class Schema(RPCCondition.Schema, TimeRPCCall.Schema):
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.TIME.value), required=True
|
||||
)
|
||||
method = fields.Str(
|
||||
dump_default="blocktime", required=True, validate=Equal("blocktime")
|
||||
)
|
||||
|
||||
@validates_schema
|
||||
def validate_expected_return_type(self, data, **kwargs):
|
||||
return_value_test = data.get("return_value_test")
|
||||
comparator_value = return_value_test.value
|
||||
if not isinstance(comparator_value, int):
|
||||
raise ValidationError(
|
||||
field_name="return_value_test",
|
||||
message=f"Invalid return value comparison type '{type(comparator_value)}'; must be an integer",
|
||||
)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
|
@ -59,29 +82,21 @@ class TimeCondition(RPCCondition):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
return_value_test: ReturnValueTest,
|
||||
chain: int,
|
||||
method: str = TimeRPCCall.METHOD,
|
||||
condition_type: str = CONDITION_TYPE,
|
||||
*args,
|
||||
**kwargs,
|
||||
condition_type: str = ConditionType.TIME.value,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
# call to super must be at the end for proper validation
|
||||
super().__init__(
|
||||
condition_type=condition_type,
|
||||
return_value_test=return_value_test,
|
||||
chain=chain,
|
||||
method=method,
|
||||
*args,
|
||||
**kwargs,
|
||||
condition_type=condition_type,
|
||||
name=name,
|
||||
)
|
||||
|
||||
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
|
||||
return TimeRPCCall(*args, **kwargs)
|
||||
|
||||
def _validate_expected_return_type(self):
|
||||
comparator_value = self.return_value_test.value
|
||||
if not isinstance(comparator_value, int):
|
||||
raise InvalidCondition(
|
||||
f"Invalid return value comparison type '{type(comparator_value)}'; must be an integer"
|
||||
)
|
||||
|
||||
@property
|
||||
def timestamp(self):
|
||||
return self.return_value_test.value
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import re
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, Optional, Set, Tuple
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from marshmallow import Schema, post_dump
|
||||
from marshmallow.exceptions import SCHEMA
|
||||
from web3.providers import BaseProvider
|
||||
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
|
@ -138,3 +139,33 @@ def evaluate_condition_lingo(
|
|||
if error:
|
||||
log.info(error.message) # log error message
|
||||
raise error
|
||||
|
||||
|
||||
def extract_single_error_message_from_schema_errors(
|
||||
errors: Dict[str, List[str]]
|
||||
) -> str:
|
||||
"""
|
||||
Extract single error message from Schema().validate() errors result.
|
||||
|
||||
The result is only for a single error type, and only the first message string for that type.
|
||||
If there are multiple error types, only one error type is used; the first field-specific (@validates)
|
||||
error type encountered is prioritized over any schema-level-specific (@validates_schema) error.
|
||||
"""
|
||||
if not errors:
|
||||
raise ValueError("Validation errors must be provided")
|
||||
|
||||
# extract error type - either field-specific (preferred) or schema-specific
|
||||
error_key_to_use = None
|
||||
for error_type in list(errors.keys()):
|
||||
error_key_to_use = error_type
|
||||
if error_key_to_use != SCHEMA:
|
||||
# actual field
|
||||
break
|
||||
|
||||
message = errors[error_key_to_use][0]
|
||||
message_prefix = (
|
||||
f"'{camel_case_to_snake(error_key_to_use)}' field - "
|
||||
if error_key_to_use != SCHEMA
|
||||
else ""
|
||||
)
|
||||
return f"{message_prefix}{message}"
|
||||
|
|
|
@ -7,42 +7,24 @@ from typing import (
|
|||
cast,
|
||||
)
|
||||
|
||||
from eth_typing import ChecksumAddress
|
||||
from web3 import Web3
|
||||
from web3.auto import w3
|
||||
from web3.contract.contract import ContractFunction
|
||||
from web3.types import ABIFunction
|
||||
|
||||
from nucypher.policy.conditions import STANDARD_ABI_CONTRACT_TYPES, STANDARD_ABIS
|
||||
from nucypher.policy.conditions.context import is_context_variable
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
InvalidCondition,
|
||||
)
|
||||
from nucypher.policy.conditions.lingo import ReturnValueTest
|
||||
|
||||
|
||||
def _validate_single_output_type(
|
||||
expected_type: str,
|
||||
comparator_value: Any,
|
||||
comparator_index: Optional[int],
|
||||
failure_message: str,
|
||||
) -> None:
|
||||
if comparator_index is not None and _is_tuple_type(expected_type):
|
||||
type_entries = _get_tuple_type_entries(expected_type)
|
||||
expected_type = type_entries[comparator_index]
|
||||
_validate_value_type(expected_type, comparator_value, failure_message)
|
||||
|
||||
#
|
||||
# Schema logic
|
||||
#
|
||||
|
||||
def _get_abi_types(abi: ABIFunction) -> List[str]:
|
||||
return [_collapse_if_tuple(cast(Dict[str, Any], arg)) for arg in abi["outputs"]]
|
||||
|
||||
|
||||
def _validate_value_type(
|
||||
expected_type: str, comparator_value: Any, failure_message: str
|
||||
) -> None:
|
||||
if is_context_variable(comparator_value):
|
||||
# context variable types cannot be known until execution time.
|
||||
return
|
||||
if not w3.is_encodable(expected_type, comparator_value):
|
||||
raise InvalidCondition(failure_message)
|
||||
|
||||
|
||||
def _collapse_if_tuple(abi: Dict[str, Any]) -> str:
|
||||
abi_type = abi["type"]
|
||||
if not abi_type.startswith("tuple"):
|
||||
|
@ -66,6 +48,28 @@ def _get_tuple_type_entries(tuple_type: str) -> List[str]:
|
|||
return result
|
||||
|
||||
|
||||
def _validate_value_type(
|
||||
expected_type: str, comparator_value: Any, failure_message: str
|
||||
) -> None:
|
||||
if is_context_variable(comparator_value):
|
||||
# context variable types cannot be known until execution time.
|
||||
return
|
||||
if not w3.is_encodable(expected_type, comparator_value):
|
||||
raise ValueError(failure_message)
|
||||
|
||||
|
||||
def _validate_single_output_type(
|
||||
expected_type: str,
|
||||
comparator_value: Any,
|
||||
comparator_index: Optional[int],
|
||||
failure_message: str,
|
||||
) -> None:
|
||||
if comparator_index is not None and _is_tuple_type(expected_type):
|
||||
type_entries = _get_tuple_type_entries(expected_type)
|
||||
expected_type = type_entries[comparator_index]
|
||||
_validate_value_type(expected_type, comparator_value, failure_message)
|
||||
|
||||
|
||||
def _validate_multiple_output_types(
|
||||
output_abi_types: List[str],
|
||||
comparator_value: Any,
|
||||
|
@ -82,15 +86,46 @@ def _validate_multiple_output_types(
|
|||
return
|
||||
|
||||
if not isinstance(comparator_value, Sequence):
|
||||
raise InvalidCondition(failure_message)
|
||||
raise ValueError(failure_message)
|
||||
|
||||
if len(output_abi_types) != len(comparator_value):
|
||||
raise InvalidCondition(failure_message)
|
||||
raise ValueError(failure_message)
|
||||
|
||||
for output_abi_type, component_value in zip(output_abi_types, comparator_value):
|
||||
_validate_value_type(output_abi_type, component_value, failure_message)
|
||||
|
||||
|
||||
def _resolve_abi(
|
||||
w3: Web3,
|
||||
method: str,
|
||||
standard_contract_type: Optional[str] = None,
|
||||
function_abi: Optional[ABIFunction] = None,
|
||||
) -> ABIFunction:
|
||||
"""Resolves the contract an/or function ABI from a standard contract name"""
|
||||
|
||||
if not (function_abi or standard_contract_type):
|
||||
raise ValueError(
|
||||
f"Ambiguous ABI - Supply either an ABI or a standard contract type ({STANDARD_ABI_CONTRACT_TYPES})."
|
||||
)
|
||||
|
||||
if standard_contract_type:
|
||||
try:
|
||||
# Lookup the standard ABI given it's ERC standard name (standard contract type)
|
||||
contract_abi = STANDARD_ABIS[standard_contract_type]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Invalid standard contract type {standard_contract_type}; Must be one of {STANDARD_ABI_CONTRACT_TYPES}"
|
||||
)
|
||||
|
||||
# Extract all function ABIs from the contract's ABI.
|
||||
# Will raise a ValueError if there is not exactly one match.
|
||||
function_abi = (
|
||||
w3.eth.contract(abi=contract_abi).get_function_by_name(method).abi
|
||||
)
|
||||
|
||||
return ABIFunction(function_abi)
|
||||
|
||||
|
||||
def _align_comparator_value_single_output(
|
||||
expected_type: str, comparator_value: Any, comparator_index: Optional[int]
|
||||
) -> Any:
|
||||
|
@ -99,7 +134,7 @@ def _align_comparator_value_single_output(
|
|||
expected_type = type_entries[comparator_index]
|
||||
|
||||
if not w3.is_encodable(expected_type, comparator_value):
|
||||
raise InvalidCondition(
|
||||
raise ValueError(
|
||||
f"Mismatched comparator type ({comparator_value} as {expected_type})"
|
||||
)
|
||||
return comparator_value
|
||||
|
@ -112,7 +147,7 @@ def _align_comparator_value_multiple_output(
|
|||
expected_type = output_abi_types[comparator_index]
|
||||
# ensure alignment
|
||||
if not w3.is_encodable(expected_type, comparator_value):
|
||||
raise InvalidCondition(
|
||||
raise ValueError(
|
||||
f"Mismatched comparator type ({comparator_value} as {expected_type})"
|
||||
)
|
||||
|
||||
|
@ -122,15 +157,20 @@ def _align_comparator_value_multiple_output(
|
|||
for output_abi_type, component_value in zip(output_abi_types, comparator_value):
|
||||
# ensure alignment
|
||||
if not w3.is_encodable(output_abi_type, component_value):
|
||||
raise InvalidCondition(
|
||||
raise ValueError(
|
||||
f"Mismatched comparator type ({component_value} as {output_abi_type})"
|
||||
)
|
||||
values.append(component_value)
|
||||
return values
|
||||
|
||||
|
||||
def _align_comparator_value_with_abi(
|
||||
abi, return_value_test: ReturnValueTest
|
||||
#
|
||||
# Public functions.
|
||||
#
|
||||
|
||||
|
||||
def align_comparator_value_with_abi(
|
||||
abi: ABIFunction, return_value_test: ReturnValueTest
|
||||
) -> ReturnValueTest:
|
||||
output_abi_types = _get_abi_types(abi)
|
||||
comparator = return_value_test.comparator
|
||||
|
@ -155,13 +195,19 @@ def _align_comparator_value_with_abi(
|
|||
)
|
||||
|
||||
|
||||
def _validate_function_abi(function_abi: Dict, method_name: str) -> None:
|
||||
"""validates a dictionary as valid for use as a condition function ABI"""
|
||||
def validate_function_abi(
|
||||
function_abi: Dict, method_name: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Validates a dictionary as valid for use as a condition function ABI.
|
||||
|
||||
Optionally validates the method_name
|
||||
"""
|
||||
abi = ABIFunction(function_abi)
|
||||
|
||||
if not abi.get("name"):
|
||||
raise ValueError(f"Invalid ABI, no function name found {abi}")
|
||||
if abi.get("name") != method_name:
|
||||
if method_name and abi.get("name") != method_name:
|
||||
raise ValueError(f"Mismatched ABI for contract function {method_name} - {abi}")
|
||||
if abi.get("type") != "function":
|
||||
raise ValueError(f"Invalid ABI type {abi}")
|
||||
|
@ -171,14 +217,47 @@ def _validate_function_abi(function_abi: Dict, method_name: str) -> None:
|
|||
raise ValueError(f"Invalid ABI stateMutability {abi}")
|
||||
|
||||
|
||||
def _validate_contract_call_abi(
|
||||
standard_contract_type: str,
|
||||
function_abi: Dict,
|
||||
method_name: str,
|
||||
) -> None:
|
||||
if not (bool(standard_contract_type) ^ bool(function_abi)):
|
||||
def get_unbound_contract_function(
|
||||
contract_address: ChecksumAddress,
|
||||
method: str,
|
||||
standard_contract_type: Optional[str] = None,
|
||||
function_abi: Optional[ABIFunction] = None,
|
||||
) -> ContractFunction:
|
||||
"""Gets an unbound contract function to evaluate"""
|
||||
w3 = Web3()
|
||||
function_abi = _resolve_abi(
|
||||
w3=w3,
|
||||
standard_contract_type=standard_contract_type,
|
||||
method=method,
|
||||
function_abi=function_abi,
|
||||
)
|
||||
try:
|
||||
contract = w3.eth.contract(address=contract_address, abi=[function_abi])
|
||||
contract_function = getattr(contract.functions, method)
|
||||
return contract_function
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Provide 'standardContractType' or 'functionAbi'; got ({standard_contract_type}, {function_abi})."
|
||||
f"Unable to find contract function, '{method}', for condition: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def validate_contract_function_expected_return_type(
|
||||
contract_function: ContractFunction, return_value_test: ReturnValueTest
|
||||
) -> None:
|
||||
output_abi_types = _get_abi_types(contract_function.contract_abi[0])
|
||||
comparator_value = return_value_test.value
|
||||
comparator_index = return_value_test.index
|
||||
index_string = f"@index={comparator_index}" if comparator_index is not None else ""
|
||||
failure_message = (
|
||||
f"Invalid return value comparison type '{type(comparator_value)}' for "
|
||||
f"'{contract_function.fn_name}'{index_string} based on ABI types {output_abi_types}"
|
||||
)
|
||||
|
||||
if len(output_abi_types) == 1:
|
||||
_validate_single_output_type(
|
||||
output_abi_types[0], comparator_value, comparator_index, failure_message
|
||||
)
|
||||
else:
|
||||
_validate_multiple_output_types(
|
||||
output_abi_types, comparator_value, comparator_index, failure_message
|
||||
)
|
||||
if function_abi:
|
||||
_validate_function_abi(function_abi, method_name=method_name)
|
||||
|
|
|
@ -4,7 +4,10 @@ from unittest.mock import Mock
|
|||
import pytest
|
||||
|
||||
from nucypher.policy.conditions.base import AccessControlCondition
|
||||
from nucypher.policy.conditions.exceptions import InvalidCondition
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
InvalidCondition,
|
||||
InvalidConditionLingo,
|
||||
)
|
||||
from nucypher.policy.conditions.lingo import (
|
||||
AndCompoundCondition,
|
||||
CompoundAccessControlCondition,
|
||||
|
@ -121,39 +124,39 @@ def test_compound_condition_schema_validation(operator, time_condition, rpc_cond
|
|||
compound_condition_dict = compound_condition.to_dict()
|
||||
|
||||
# no issues here
|
||||
CompoundAccessControlCondition.validate(compound_condition_dict)
|
||||
CompoundAccessControlCondition.from_dict(compound_condition_dict)
|
||||
|
||||
# no issues with optional name
|
||||
compound_condition_dict["name"] = "my_contract_condition"
|
||||
CompoundAccessControlCondition.validate(compound_condition_dict)
|
||||
CompoundAccessControlCondition.from_dict(compound_condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# incorrect condition type
|
||||
compound_condition_dict = compound_condition.to_dict()
|
||||
compound_condition_dict["condition_type"] = ConditionType.RPC.value
|
||||
CompoundAccessControlCondition.validate(compound_condition_dict)
|
||||
CompoundAccessControlCondition.from_dict(compound_condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# invalid operator
|
||||
compound_condition_dict = compound_condition.to_dict()
|
||||
compound_condition_dict["operator"] = "5True"
|
||||
CompoundAccessControlCondition.validate(compound_condition_dict)
|
||||
CompoundAccessControlCondition.from_dict(compound_condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# no operator
|
||||
compound_condition_dict = compound_condition.to_dict()
|
||||
del compound_condition_dict["operator"]
|
||||
CompoundAccessControlCondition.validate(compound_condition_dict)
|
||||
CompoundAccessControlCondition.from_dict(compound_condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# no operands
|
||||
compound_condition_dict = compound_condition.to_dict()
|
||||
del compound_condition_dict["operands"]
|
||||
CompoundAccessControlCondition.validate(compound_condition_dict)
|
||||
CompoundAccessControlCondition.from_dict(compound_condition_dict)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_and_condition_and_short_circuit(mocker, mock_conditions):
|
||||
def test_and_condition_and_short_circuit(mock_conditions):
|
||||
condition_1, condition_2, condition_3, condition_4 = mock_conditions
|
||||
|
||||
and_condition = AndCompoundCondition(
|
||||
|
@ -286,10 +289,9 @@ def test_compound_condition(mock_conditions):
|
|||
] # or-condition short-circuited because condition_1 was True
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_nested_compound_condition_too_many_nested_levels(mock_conditions):
|
||||
condition_1, condition_2, condition_3, condition_4 = mock_conditions
|
||||
|
||||
def test_nested_compound_condition_too_many_nested_levels(
|
||||
rpc_condition, time_condition
|
||||
):
|
||||
with pytest.raises(
|
||||
InvalidCondition, match="nested levels of multi-conditions are allowed"
|
||||
):
|
||||
|
@ -297,24 +299,23 @@ def test_nested_compound_condition_too_many_nested_levels(mock_conditions):
|
|||
operands=[
|
||||
OrCompoundCondition(
|
||||
operands=[
|
||||
condition_1,
|
||||
rpc_condition,
|
||||
AndCompoundCondition(
|
||||
operands=[
|
||||
condition_2,
|
||||
condition_3,
|
||||
time_condition,
|
||||
rpc_condition,
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
condition_4,
|
||||
time_condition,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_nested_sequential_condition_too_many_nested_levels(mock_conditions):
|
||||
condition_1, condition_2, condition_3, condition_4 = mock_conditions
|
||||
|
||||
def test_nested_sequential_condition_too_many_nested_levels(
|
||||
rpc_condition, time_condition
|
||||
):
|
||||
with pytest.raises(
|
||||
InvalidCondition, match="nested levels of multi-conditions are allowed"
|
||||
):
|
||||
|
@ -322,16 +323,16 @@ def test_nested_sequential_condition_too_many_nested_levels(mock_conditions):
|
|||
operands=[
|
||||
OrCompoundCondition(
|
||||
operands=[
|
||||
condition_1,
|
||||
rpc_condition,
|
||||
SequentialAccessControlCondition(
|
||||
condition_variables=[
|
||||
ConditionVariable("var2", condition_2),
|
||||
ConditionVariable("var3", condition_3),
|
||||
ConditionVariable("var2", time_condition),
|
||||
ConditionVariable("var3", rpc_condition),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
condition_4,
|
||||
time_condition,
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -55,6 +55,8 @@ class FakeExecutionContractCondition(ContractCondition):
|
|||
def execute(self, providers: Dict, **context) -> Any:
|
||||
return self.execution_return_value
|
||||
|
||||
EXECUTION_CALL_TYPE = FakeRPCCall
|
||||
|
||||
class Schema(ContractCondition.Schema):
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
|
@ -63,9 +65,6 @@ class FakeExecutionContractCondition(ContractCondition):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _create_execution_call(self, *args, **kwargs) -> ContractCall:
|
||||
return self.FakeRPCCall(*args, **kwargs)
|
||||
|
||||
def set_execution_return_value(self, value: Any):
|
||||
self.execution_call.set_execution_return_value(value)
|
||||
|
||||
|
@ -142,7 +141,7 @@ def test_invalid_contract_condition():
|
|||
# invalid condition type
|
||||
with pytest.raises(
|
||||
InvalidCondition,
|
||||
match=f"must be instantiated with the {ConditionType.CONTRACT.value} type",
|
||||
match=f"'condition_type' field - Must be equal to {ConditionType.CONTRACT.value}",
|
||||
):
|
||||
_ = ContractCondition(
|
||||
condition_type=ConditionType.RPC.value,
|
||||
|
@ -167,7 +166,7 @@ def test_invalid_contract_condition():
|
|||
|
||||
# no abi or contract type
|
||||
with pytest.raises(
|
||||
InvalidCondition, match="Provide 'standardContractType' or 'functionAbi'"
|
||||
InvalidCondition, match="Provide a standard contract type or function ABI"
|
||||
):
|
||||
_ = ContractCondition(
|
||||
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
|
||||
|
@ -206,7 +205,9 @@ def test_invalid_contract_condition():
|
|||
)
|
||||
|
||||
# method not in ABI
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(
|
||||
InvalidCondition, match="Could not find any function with matching name"
|
||||
):
|
||||
_ = ContractCondition(
|
||||
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
|
||||
method="getPolicy",
|
||||
|
@ -220,14 +221,55 @@ def test_invalid_contract_condition():
|
|||
|
||||
# standard contract type and function ABI
|
||||
with pytest.raises(
|
||||
InvalidCondition, match="Provide 'standardContractType' or 'functionAbi'"
|
||||
InvalidCondition, match="Provide a standard contract type or function ABI"
|
||||
):
|
||||
_ = ContractCondition(
|
||||
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
|
||||
method="balanceOf",
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
standard_contract_type="ERC20",
|
||||
function_abi={"rando": "ABI"},
|
||||
function_abi={
|
||||
"inputs": [
|
||||
{"internalType": "bytes16", "name": "_policyID", "type": "bytes16"}
|
||||
],
|
||||
"name": "getPolicy",
|
||||
"outputs": [
|
||||
{
|
||||
"components": [
|
||||
{
|
||||
"internalType": "address payable",
|
||||
"name": "sponsor",
|
||||
"type": "address",
|
||||
},
|
||||
{
|
||||
"internalType": "uint32",
|
||||
"name": "startTimestamp",
|
||||
"type": "uint32",
|
||||
},
|
||||
{
|
||||
"internalType": "uint32",
|
||||
"name": "endTimestamp",
|
||||
"type": "uint32",
|
||||
},
|
||||
{
|
||||
"internalType": "uint16",
|
||||
"name": "size",
|
||||
"type": "uint16",
|
||||
},
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "owner",
|
||||
"type": "address",
|
||||
},
|
||||
],
|
||||
"internalType": "struct SubscriptionManager.Policy",
|
||||
"name": "",
|
||||
"type": "tuple",
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function",
|
||||
},
|
||||
return_value_test=ReturnValueTest("!=", 0),
|
||||
parameters=[
|
||||
":hrac",
|
||||
|
@ -250,17 +292,17 @@ def test_contract_condition_schema_validation():
|
|||
condition_dict = contract_condition.to_dict()
|
||||
|
||||
# no issues here
|
||||
ContractCondition.validate(condition_dict)
|
||||
ContractCondition.from_dict(condition_dict)
|
||||
|
||||
# no issues with optional name
|
||||
condition_dict["name"] = "my_contract_condition"
|
||||
ContractCondition.validate(condition_dict)
|
||||
ContractCondition.from_dict(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# no contract address defined
|
||||
condition_dict = contract_condition.to_dict()
|
||||
del condition_dict["contractAddress"]
|
||||
ContractCondition.validate(condition_dict)
|
||||
ContractCondition.from_dict(condition_dict)
|
||||
|
||||
balanceOf_abi = {
|
||||
"constant": True,
|
||||
|
@ -272,29 +314,29 @@ def test_contract_condition_schema_validation():
|
|||
"type": "function",
|
||||
}
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# no function abi or standard contract type
|
||||
condition_dict = contract_condition.to_dict()
|
||||
del condition_dict["standardContractType"]
|
||||
ContractCondition.validate(condition_dict)
|
||||
ContractCondition.from_dict(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# provide both function abi and standard contract type
|
||||
condition_dict = contract_condition.to_dict()
|
||||
condition_dict["functionAbi"] = balanceOf_abi
|
||||
ContractCondition.validate(condition_dict)
|
||||
ContractCondition.from_dict(condition_dict)
|
||||
|
||||
# remove standardContractType but specify function abi; no issues with that
|
||||
condition_dict = contract_condition.to_dict()
|
||||
del condition_dict["standardContractType"]
|
||||
condition_dict["functionAbi"] = balanceOf_abi
|
||||
ContractCondition.validate(condition_dict)
|
||||
ContractCondition.from_dict(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# no returnValueTest defined
|
||||
condition_dict = contract_condition.to_dict()
|
||||
del condition_dict["returnValueTest"]
|
||||
ContractCondition.validate(condition_dict)
|
||||
ContractCondition.from_dict(condition_dict)
|
||||
|
||||
|
||||
def test_contract_condition_repr(contract_condition_dict):
|
||||
|
@ -389,7 +431,9 @@ def test_abi_bool_output(contract_condition_dict):
|
|||
assert isinstance(contract_condition.return_value_test.value, bool)
|
||||
|
||||
# invalid type fails
|
||||
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
|
||||
with pytest.raises(
|
||||
InvalidConditionLingo, match="Invalid return value comparison type"
|
||||
):
|
||||
contract_condition_dict["returnValueTest"]["value"] = 23
|
||||
ContractCondition.from_json(json.dumps(contract_condition_dict))
|
||||
|
||||
|
@ -419,7 +463,8 @@ def test_abi_bool_output(contract_condition_dict):
|
|||
)
|
||||
|
||||
# test where context var has invalid expected type(s), so only detected at decryption time
|
||||
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
|
||||
# consequently this is not an invalid condition, but rather an incorrect context value
|
||||
with pytest.raises(ValueError, match="Mismatched comparator type"):
|
||||
_check_execution_logic(
|
||||
condition_dict=contract_condition_dict,
|
||||
execution_result=True,
|
||||
|
@ -438,7 +483,9 @@ def test_abi_uint_output(contract_condition_dict):
|
|||
assert isinstance(contract_condition.return_value_test.value, int)
|
||||
|
||||
# invalid type fails
|
||||
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
|
||||
with pytest.raises(
|
||||
InvalidConditionLingo, match="Invalid return value comparison type"
|
||||
):
|
||||
contract_condition_dict["returnValueTest"]["value"] = True
|
||||
ContractCondition.from_json(json.dumps(contract_condition_dict))
|
||||
|
||||
|
@ -468,7 +515,8 @@ def test_abi_uint_output(contract_condition_dict):
|
|||
)
|
||||
|
||||
# test where context var has invalid expected type(s), so only detected at decryption time
|
||||
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
|
||||
# consequently this is not an invalid condition, but rather an incorrect context value
|
||||
with pytest.raises(ValueError, match="Mismatched comparator type"):
|
||||
_check_execution_logic(
|
||||
condition_dict=contract_condition_dict,
|
||||
execution_result=123456789,
|
||||
|
@ -487,7 +535,9 @@ def test_abi_int_output(contract_condition_dict):
|
|||
assert isinstance(contract_condition.return_value_test.value, int)
|
||||
|
||||
# invalid type fails
|
||||
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
|
||||
with pytest.raises(
|
||||
InvalidConditionLingo, match="Invalid return value comparison type"
|
||||
):
|
||||
contract_condition_dict["returnValueTest"]["value"] = [1, 2, 3]
|
||||
ContractCondition.from_json(json.dumps(contract_condition_dict))
|
||||
|
||||
|
@ -517,7 +567,8 @@ def test_abi_int_output(contract_condition_dict):
|
|||
)
|
||||
|
||||
# test where context var has invalid expected type(s), so only detected at decryption time
|
||||
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
|
||||
# consequently this is not an invalid condition, but rather an incorrect context value
|
||||
with pytest.raises(ValueError, match="Mismatched comparator type"):
|
||||
_check_execution_logic(
|
||||
condition_dict=contract_condition_dict,
|
||||
execution_result=-123456789,
|
||||
|
@ -538,7 +589,9 @@ def test_abi_address_output(contract_condition_dict, get_random_checksum_address
|
|||
assert isinstance(contract_condition.return_value_test.value, str)
|
||||
|
||||
# invalid type fails
|
||||
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
|
||||
with pytest.raises(
|
||||
InvalidConditionLingo, match="Invalid return value comparison type"
|
||||
):
|
||||
contract_condition_dict["returnValueTest"]["value"] = 1.25
|
||||
ContractCondition.from_json(json.dumps(contract_condition_dict))
|
||||
|
||||
|
@ -569,7 +622,8 @@ def test_abi_address_output(contract_condition_dict, get_random_checksum_address
|
|||
)
|
||||
|
||||
# test where context var has invalid expected type(s), so only detected at decryption time
|
||||
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
|
||||
# consequently this is not an invalid condition, but rather an incorrect context value
|
||||
with pytest.raises(ValueError, match="Mismatched comparator type"):
|
||||
_check_execution_logic(
|
||||
condition_dict=contract_condition_dict,
|
||||
execution_result=checksum_address,
|
||||
|
@ -607,7 +661,9 @@ def test_abi_bytes_output(bytes_test_scenario, contract_condition_dict):
|
|||
assert isinstance(contract_condition.return_value_test.value, str)
|
||||
|
||||
# invalid type fails
|
||||
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
|
||||
with pytest.raises(
|
||||
InvalidConditionLingo, match="Invalid return value comparison type"
|
||||
):
|
||||
contract_condition_dict["returnValueTest"]["value"] = 1.25
|
||||
ContractCondition.from_json(json.dumps(contract_condition_dict))
|
||||
|
||||
|
@ -637,7 +693,8 @@ def test_abi_bytes_output(bytes_test_scenario, contract_condition_dict):
|
|||
)
|
||||
|
||||
# test where context var has invalid expected type(s), so only detected at decryption time
|
||||
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
|
||||
# consequently this is not an invalid condition, but rather an incorrect context value
|
||||
with pytest.raises(ValueError, match="Mismatched comparator type"):
|
||||
_check_execution_logic(
|
||||
condition_dict=contract_condition_dict,
|
||||
execution_result=call_result_in_bytes,
|
||||
|
@ -667,27 +724,37 @@ def test_abi_tuple_output(contract_condition_dict):
|
|||
assert isinstance(contract_condition.return_value_test.value, Sequence)
|
||||
|
||||
# 1. invalid type
|
||||
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
|
||||
with pytest.raises(
|
||||
InvalidConditionLingo, match="Invalid return value comparison type"
|
||||
):
|
||||
contract_condition_dict["returnValueTest"]["value"] = 1
|
||||
ContractCondition.from_json(json.dumps(contract_condition_dict))
|
||||
|
||||
# 2. invalid number of values
|
||||
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
|
||||
with pytest.raises(
|
||||
InvalidConditionLingo, match="Invalid return value comparison type"
|
||||
):
|
||||
contract_condition_dict["returnValueTest"]["value"] = [1, 2]
|
||||
ContractCondition.from_json(json.dumps(contract_condition_dict))
|
||||
|
||||
# 3a. Unmatched type
|
||||
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
|
||||
with pytest.raises(
|
||||
InvalidConditionLingo, match="Invalid return value comparison type"
|
||||
):
|
||||
contract_condition_dict["returnValueTest"]["value"] = [True, 2, 3]
|
||||
ContractCondition.from_json(json.dumps(contract_condition_dict))
|
||||
|
||||
# 3b. Unmatched type
|
||||
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
|
||||
with pytest.raises(
|
||||
InvalidConditionLingo, match="Invalid return value comparison type"
|
||||
):
|
||||
contract_condition_dict["returnValueTest"]["value"] = [1, False, 3]
|
||||
ContractCondition.from_json(json.dumps(contract_condition_dict))
|
||||
|
||||
# 3c. Unmatched type
|
||||
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
|
||||
with pytest.raises(
|
||||
InvalidConditionLingo, match="Invalid return value comparison type"
|
||||
):
|
||||
contract_condition_dict["returnValueTest"]["value"] = [1, 2, 3.14159]
|
||||
ContractCondition.from_json(json.dumps(contract_condition_dict))
|
||||
|
||||
|
@ -725,7 +792,8 @@ def test_abi_tuple_output(contract_condition_dict):
|
|||
)
|
||||
|
||||
# test where context var has invalid expected type(s) - boolean is unexpected in index 1
|
||||
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
|
||||
# consequently this is not an invalid condition, but rather an incorrect context value
|
||||
with pytest.raises(ValueError, match="Mismatched comparator type"):
|
||||
_check_execution_logic(
|
||||
condition_dict=contract_condition_dict,
|
||||
execution_result=(1, 2, 3, random_bytes),
|
||||
|
@ -794,7 +862,9 @@ def test_abi_tuple_output_with_index(
|
|||
assert isinstance(contract_condition.return_value_test.value, str)
|
||||
|
||||
# invalid type at index
|
||||
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
|
||||
with pytest.raises(
|
||||
InvalidConditionLingo, match="Invalid return value comparison type"
|
||||
):
|
||||
contract_condition_dict["returnValueTest"]["index"] = 0
|
||||
contract_condition_dict["returnValueTest"][
|
||||
"value"
|
||||
|
@ -827,7 +897,8 @@ def test_abi_tuple_output_with_index(
|
|||
)
|
||||
|
||||
# using index, test where context var has invalid expected type - unexpected type in index 2
|
||||
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
|
||||
# consequently this is not an invalid condition, but rather an incorrect context value
|
||||
with pytest.raises(ValueError, match="Mismatched comparator type"):
|
||||
_check_execution_logic(
|
||||
condition_dict=contract_condition_dict,
|
||||
execution_result=tuple(result),
|
||||
|
@ -963,7 +1034,8 @@ def test_abi_multiple_output_values(
|
|||
)
|
||||
|
||||
# test where context var has invalid expected type
|
||||
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
|
||||
# consequently this is not an invalid condition, but rather an incorrect context value
|
||||
with pytest.raises(ValueError, match="Mismatched comparator type"):
|
||||
_check_execution_logic(
|
||||
condition_dict=contract_condition_dict,
|
||||
execution_result=tuple(result),
|
||||
|
@ -998,8 +1070,9 @@ def test_abi_multiple_output_values(
|
|||
)
|
||||
|
||||
# test where context var has invalid expected type
|
||||
# consequently this is not an invalid condition, but rather an incorrect context value
|
||||
comparator_value[0][0] = True # should be address but setting to bool
|
||||
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
|
||||
with pytest.raises(ValueError, match="Mismatched comparator type"):
|
||||
_check_execution_logic(
|
||||
condition_dict=contract_condition_dict,
|
||||
execution_result=tuple(result),
|
||||
|
@ -1092,7 +1165,9 @@ def test_abi_nested_tuples_output_values(
|
|||
[1],
|
||||
get_random_checksum_address(), # missing tuple value
|
||||
]
|
||||
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
|
||||
with pytest.raises(
|
||||
InvalidConditionLingo, match="Invalid return value comparison type"
|
||||
):
|
||||
ContractCondition.from_json(json.dumps(contract_condition_dict))
|
||||
|
||||
contract_condition_dict["returnValueTest"]["value"] = [
|
||||
|
@ -1101,7 +1176,9 @@ def test_abi_nested_tuples_output_values(
|
|||
random_bytes_hex,
|
||||
get_random_checksum_address(), # incorrect tuple value for Timeframe
|
||||
]
|
||||
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
|
||||
with pytest.raises(
|
||||
InvalidConditionLingo, match="Invalid return value comparison type"
|
||||
):
|
||||
ContractCondition.from_json(json.dumps(contract_condition_dict))
|
||||
|
||||
contract_condition_dict["returnValueTest"]["value"] = [
|
||||
|
@ -1109,7 +1186,9 @@ def test_abi_nested_tuples_output_values(
|
|||
[1, random_bytes_hex, 3],
|
||||
get_random_checksum_address(), # too many values
|
||||
]
|
||||
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
|
||||
with pytest.raises(
|
||||
InvalidConditionLingo, match="Invalid return value comparison type"
|
||||
):
|
||||
ContractCondition.from_json(json.dumps(contract_condition_dict))
|
||||
|
||||
# process index 1 (bool)
|
||||
|
@ -1154,7 +1233,8 @@ def test_abi_nested_tuples_output_values(
|
|||
)
|
||||
|
||||
# test where context var has invalid expected type
|
||||
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
|
||||
# consequently this is not an invalid condition, but rather an incorrect context value
|
||||
with pytest.raises(ValueError, match="Mismatched comparator type"):
|
||||
_check_execution_logic(
|
||||
condition_dict=contract_condition_dict,
|
||||
execution_result=tuple(result),
|
||||
|
@ -1188,8 +1268,9 @@ def test_abi_nested_tuples_output_values(
|
|||
)
|
||||
|
||||
# test where context var has invalid expected type
|
||||
# consequently this is not an invalid condition, but rather an incorrect context value
|
||||
comparator_value[0][2] = 1.25 # should be an address
|
||||
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
|
||||
with pytest.raises(ValueError, match="Mismatched comparator type"):
|
||||
_check_execution_logic(
|
||||
condition_dict=contract_condition_dict,
|
||||
execution_result=tuple(result),
|
||||
|
|
|
@ -25,11 +25,11 @@ def test_jsonpath_field_valid():
|
|||
def test_jsonpath_field_invalid():
|
||||
field = JSONPathField()
|
||||
invalid_jsonpath = "invalid jsonpath"
|
||||
with pytest.raises(ValidationError) as excinfo:
|
||||
with pytest.raises(
|
||||
ValidationError,
|
||||
match=f"'{invalid_jsonpath}' is not a valid JSONPath expression",
|
||||
):
|
||||
field.deserialize(invalid_jsonpath)
|
||||
assert f"'{invalid_jsonpath}' is not a valid JSONPath expression" in str(
|
||||
excinfo.value
|
||||
)
|
||||
|
||||
|
||||
def test_json_api_condition_initialization():
|
||||
|
@ -44,24 +44,23 @@ def test_json_api_condition_initialization():
|
|||
|
||||
|
||||
def test_json_api_condition_invalid_type():
|
||||
with pytest.raises(InvalidCondition) as excinfo:
|
||||
JsonApiCondition(
|
||||
with pytest.raises(
|
||||
InvalidCondition, match="'condition_type' field - Must be equal to json-api"
|
||||
):
|
||||
_ = JsonApiCondition(
|
||||
endpoint="https://api.example.com/data",
|
||||
query="$.store.book[0].price",
|
||||
return_value_test=ReturnValueTest("==", 0),
|
||||
condition_type="INVALID_TYPE",
|
||||
)
|
||||
assert "must be instantiated with the json-api type" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_https_enforcement():
|
||||
with pytest.raises(InvalidCondition) as excinfo:
|
||||
JsonApiCondition(
|
||||
with pytest.raises(InvalidCondition, match="Not a valid URL"):
|
||||
_ = JsonApiCondition(
|
||||
endpoint="http://api.example.com/data",
|
||||
query="$.store.book[0].price",
|
||||
return_value_test=ReturnValueTest("==", 0),
|
||||
)
|
||||
assert "Not a valid URL" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_json_api_condition_with_primitive_response(mocker):
|
||||
|
@ -102,9 +101,8 @@ def test_json_api_condition_fetch_failure(mocker):
|
|||
query="$.store.book[0].price",
|
||||
return_value_test=ReturnValueTest("==", 1),
|
||||
)
|
||||
with pytest.raises(InvalidCondition) as excinfo:
|
||||
with pytest.raises(InvalidCondition, match="Failed to fetch endpoint"):
|
||||
condition.execution_call._fetch()
|
||||
assert "Failed to fetch endpoint" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_json_api_condition_verify(mocker):
|
||||
|
@ -160,9 +158,10 @@ def test_json_api_condition_verify_invalid_json(mocker):
|
|||
query="$.store.book[0].price",
|
||||
return_value_test=ReturnValueTest("==", 2),
|
||||
)
|
||||
with pytest.raises(ConditionEvaluationFailed) as excinfo:
|
||||
with pytest.raises(
|
||||
ConditionEvaluationFailed, match="Failed to parse JSON response"
|
||||
):
|
||||
condition.verify()
|
||||
assert "Failed to parse JSON response" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_non_json_response(mocker):
|
||||
|
@ -180,11 +179,11 @@ def test_non_json_response(mocker):
|
|||
return_value_test=ReturnValueTest("==", 18),
|
||||
)
|
||||
|
||||
with pytest.raises(ConditionEvaluationFailed) as excinfo:
|
||||
with pytest.raises(
|
||||
ConditionEvaluationFailed, match="Failed to parse JSON response"
|
||||
):
|
||||
condition.verify()
|
||||
|
||||
assert "Failed to parse JSON response" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_basic_json_api_condition_evaluation_with_parameters(mocker):
|
||||
mocked_get = mocker.patch(
|
||||
|
@ -243,7 +242,5 @@ def test_ambiguous_json_path_multiple_results(mocker):
|
|||
return_value_test=ReturnValueTest("==", 1),
|
||||
)
|
||||
|
||||
with pytest.raises(ConditionEvaluationFailed) as excinfo:
|
||||
with pytest.raises(ConditionEvaluationFailed, match="Ambiguous JSONPath query"):
|
||||
condition.verify()
|
||||
|
||||
assert "Ambiguous JSONPath query" in str(excinfo.value)
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import pytest
|
||||
|
||||
from nucypher.policy.conditions.evm import RPCCondition
|
||||
from nucypher.policy.conditions.exceptions import InvalidCondition
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
InvalidCondition,
|
||||
InvalidConditionLingo,
|
||||
)
|
||||
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
|
||||
from tests.constants import TESTERCHAIN_CHAIN_ID
|
||||
|
||||
|
@ -45,7 +48,7 @@ def test_invalid_rpc_condition():
|
|||
)
|
||||
|
||||
# unsupported chain id
|
||||
with pytest.raises(InvalidCondition, match="is not a permitted blockchain"):
|
||||
with pytest.raises(InvalidCondition, match="90210 is not a permitted blockchain"):
|
||||
_ = RPCCondition(
|
||||
method="eth_getBalance",
|
||||
chain=90210, # Beverly Hills Chain :)
|
||||
|
@ -54,10 +57,10 @@ def test_invalid_rpc_condition():
|
|||
)
|
||||
|
||||
# invalid chain type provided
|
||||
with pytest.raises(ValueError, match="must be the integer chain ID"):
|
||||
with pytest.raises(ValueError, match="invalid literal for int"):
|
||||
_ = RPCCondition(
|
||||
method="eth_getBalance",
|
||||
chain=str(TESTERCHAIN_CHAIN_ID), # should be int not str.
|
||||
chain="chainId", # should be int not str.
|
||||
return_value_test=ReturnValueTest("==", 0),
|
||||
parameters=["0xaDD9D957170dF6F33982001E4c22eCCdd5539118"],
|
||||
)
|
||||
|
@ -67,44 +70,44 @@ def test_rpc_condition_schema_validation(rpc_condition):
|
|||
condition_dict = rpc_condition.to_dict()
|
||||
|
||||
# no issues here
|
||||
RPCCondition.validate(condition_dict)
|
||||
RPCCondition.from_dict(condition_dict)
|
||||
|
||||
# no issues with optional name
|
||||
condition_dict["name"] = "my_rpc_condition"
|
||||
RPCCondition.validate(condition_dict)
|
||||
RPCCondition.from_dict(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# no chain defined
|
||||
condition_dict = rpc_condition.to_dict()
|
||||
del condition_dict["chain"]
|
||||
RPCCondition.validate(condition_dict)
|
||||
RPCCondition.from_dict(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# no method defined
|
||||
condition_dict = rpc_condition.to_dict()
|
||||
del condition_dict["method"]
|
||||
RPCCondition.validate(condition_dict)
|
||||
RPCCondition.from_dict(condition_dict)
|
||||
|
||||
# no issue with no parameters
|
||||
condition_dict = rpc_condition.to_dict()
|
||||
del condition_dict["parameters"]
|
||||
RPCCondition.validate(condition_dict)
|
||||
RPCCondition.from_dict(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# no returnValueTest defined
|
||||
condition_dict = rpc_condition.to_dict()
|
||||
del condition_dict["returnValueTest"]
|
||||
RPCCondition.validate(condition_dict)
|
||||
RPCCondition.from_dict(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# chain id not an integer
|
||||
condition_dict["chain"] = str(TESTERCHAIN_CHAIN_ID)
|
||||
RPCCondition.validate(condition_dict)
|
||||
RPCCondition.from_dict(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# chain id not a permitted chain
|
||||
condition_dict["chain"] = 90210 # Beverly Hills Chain :)
|
||||
RPCCondition.validate(condition_dict)
|
||||
RPCCondition.from_dict(condition_dict)
|
||||
|
||||
|
||||
def test_rpc_condition_repr(rpc_condition):
|
||||
|
|
|
@ -38,15 +38,15 @@ def mock_condition_variables(mocker):
|
|||
return var_1, var_2, var_3, var_4
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_invalid_sequential_condition(mock_condition_variables):
|
||||
var_1, *_ = mock_condition_variables
|
||||
def test_invalid_sequential_condition(rpc_condition, time_condition):
|
||||
var_1 = ConditionVariable("var1", time_condition)
|
||||
var_2 = ConditionVariable("var2", rpc_condition)
|
||||
|
||||
# invalid condition type
|
||||
with pytest.raises(InvalidCondition, match=ConditionType.SEQUENTIAL.value):
|
||||
_ = SequentialAccessControlCondition(
|
||||
condition_type=ConditionType.TIME.value,
|
||||
condition_variables=list(mock_condition_variables),
|
||||
condition_variables=[var_1, var_2],
|
||||
)
|
||||
|
||||
# no variables
|
||||
|
@ -62,18 +62,29 @@ def test_invalid_sequential_condition(mock_condition_variables):
|
|||
)
|
||||
|
||||
# too many variables
|
||||
too_many_variables = list(mock_condition_variables)
|
||||
too_many_variables.extend(mock_condition_variables) # duplicate list length
|
||||
too_many_variables = [var_1, var_2, var_1, var_2]
|
||||
too_many_variables.extend(too_many_variables) # duplicate list length
|
||||
assert len(too_many_variables) > SequentialAccessControlCondition.MAX_NUM_CONDITIONS
|
||||
with pytest.raises(InvalidCondition, match="Maximum of"):
|
||||
_ = SequentialAccessControlCondition(
|
||||
condition_variables=too_many_variables,
|
||||
)
|
||||
|
||||
# duplicate var names
|
||||
dupe_var = ConditionVariable(var_1.var_name, condition=var_2.condition)
|
||||
with pytest.raises(InvalidCondition, match="Duplicate"):
|
||||
_ = SequentialAccessControlCondition(
|
||||
condition_variables=[var_1, var_2, dupe_var],
|
||||
)
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_nested_sequential_condition_too_many_nested_levels(mock_condition_variables):
|
||||
var_1, var_2, var_3, var_4 = mock_condition_variables
|
||||
|
||||
def test_nested_sequential_condition_too_many_nested_levels(
|
||||
rpc_condition, time_condition
|
||||
):
|
||||
var_1 = ConditionVariable("var1", time_condition)
|
||||
var_2 = ConditionVariable("var2", rpc_condition)
|
||||
var_3 = ConditionVariable("var3", time_condition)
|
||||
var_4 = ConditionVariable("var4", rpc_condition)
|
||||
|
||||
with pytest.raises(
|
||||
InvalidCondition, match="nested levels of multi-conditions are allowed"
|
||||
|
@ -104,9 +115,13 @@ def test_nested_sequential_condition_too_many_nested_levels(mock_condition_varia
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_nested_compound_condition_too_many_nested_levels(mock_condition_variables):
|
||||
var_1, var_2, var_3, var_4 = mock_condition_variables
|
||||
def test_nested_compound_condition_too_many_nested_levels(
|
||||
rpc_condition, time_condition
|
||||
):
|
||||
var_1 = ConditionVariable("var1", time_condition)
|
||||
var_2 = ConditionVariable("var2", rpc_condition)
|
||||
var_3 = ConditionVariable("var3", time_condition)
|
||||
var_4 = ConditionVariable("var4", rpc_condition)
|
||||
|
||||
with pytest.raises(
|
||||
InvalidCondition, match="nested levels of multi-conditions are allowed"
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
import pytest
|
||||
|
||||
from nucypher.policy.conditions.exceptions import InvalidCondition
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
InvalidCondition,
|
||||
InvalidConditionLingo,
|
||||
)
|
||||
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
|
||||
from nucypher.policy.conditions.time import TimeCondition, TimeRPCCall
|
||||
from tests.constants import TESTERCHAIN_CHAIN_ID
|
||||
|
@ -37,38 +40,38 @@ def test_time_condition_schema_validation(time_condition):
|
|||
condition_dict = time_condition.to_dict()
|
||||
|
||||
# no issues here
|
||||
TimeCondition.validate(condition_dict)
|
||||
TimeCondition.from_dict(condition_dict)
|
||||
|
||||
# no issues with optional name
|
||||
condition_dict["name"] = "my_time_machine"
|
||||
TimeCondition.validate(condition_dict)
|
||||
TimeCondition.from_dict(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# no method
|
||||
condition_dict = time_condition.to_dict()
|
||||
del condition_dict["method"]
|
||||
TimeCondition.validate(condition_dict)
|
||||
TimeCondition.from_dict(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# no returnValueTest defined
|
||||
condition_dict = time_condition.to_dict()
|
||||
del condition_dict["returnValueTest"]
|
||||
TimeCondition.validate(condition_dict)
|
||||
TimeCondition.from_dict(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# invalid method name
|
||||
condition_dict["method"] = "my_blocktime"
|
||||
TimeCondition.validate(condition_dict)
|
||||
TimeCondition.from_dict(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# chain id not an integer
|
||||
condition_dict["chain"] = str(TESTERCHAIN_CHAIN_ID)
|
||||
TimeCondition.validate(condition_dict)
|
||||
TimeCondition.from_dict(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
# chain id not a permitted chain
|
||||
condition_dict["chain"] = 90210 # Beverly Hills Chain :)
|
||||
TimeCondition.validate(condition_dict)
|
||||
TimeCondition.from_dict(condition_dict)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
Loading…
Reference in New Issue