Merge pull request #3556 from derekpierre/cleanup-validation

Clean up/Simplify condition validation
pull/3563/head
Derek Pierre 2024-10-01 13:42:33 -04:00 committed by GitHub
commit 34842fd88c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 778 additions and 473 deletions

View File

View File

@ -1,7 +1,7 @@
import json
from abc import ABC, abstractmethod
from base64 import b64decode, b64encode
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, List, Optional, Tuple
from marshmallow import Schema, ValidationError, fields
@ -9,7 +9,10 @@ from nucypher.policy.conditions.exceptions import (
InvalidCondition,
InvalidConditionLingo,
)
from nucypher.policy.conditions.utils import CamelCaseSchema
from nucypher.policy.conditions.utils import (
CamelCaseSchema,
extract_single_error_message_from_schema_errors,
)
class _Serializable:
@ -56,62 +59,49 @@ class AccessControlCondition(_Serializable, ABC):
class Schema(CamelCaseSchema):
SKIP_VALUES = (None,)
name = fields.Str(required=False)
name = fields.Str(required=False, allow_none=True)
condition_type = NotImplemented
def __init__(self, condition_type: str, name: Optional[str] = None):
super().__init__()
if condition_type != self.CONDITION_TYPE:
raise InvalidCondition(
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
)
self.condition_type = condition_type
self.name = name
# validate inputs using marshmallow schema
self.validate(self.to_dict())
self._validate()
@abstractmethod
def verify(self, *args, **kwargs) -> Tuple[bool, Any]:
"""Returns the boolean result of the evaluation and the returned value in a two-tuple."""
raise NotImplementedError
@classmethod
def validate(cls, data: Dict) -> None:
errors = cls.Schema().validate(data=data)
def _validate(self, **kwargs):
errors = self.Schema().validate(data=self.to_dict())
if errors:
raise InvalidCondition(f"Invalid {cls.__name__}: {errors}")
error_message = extract_single_error_message_from_schema_errors(errors)
raise InvalidCondition(
f"Invalid {self.__class__.__name__}: {error_message}"
)
@classmethod
def from_dict(cls, data) -> "AccessControlCondition":
try:
return super().from_dict(data)
except ValidationError as e:
raise InvalidConditionLingo(f"Invalid condition grammar: {e}")
raise InvalidConditionLingo(f"Invalid condition grammar: {e}") from e
@classmethod
def from_json(cls, data) -> "AccessControlCondition":
try:
return super().from_json(data)
except ValidationError as e:
raise InvalidConditionLingo(f"Invalid condition grammar: {e}")
class ExecutionCall(ABC):
@abstractmethod
def execute(self, *args, **kwargs) -> Any:
raise NotImplementedError
raise InvalidConditionLingo(f"Invalid condition grammar: {e}") from e
class MultiConditionAccessControl(AccessControlCondition):
MAX_NUM_CONDITIONS = 5
MAX_MULTI_CONDITION_NESTED_LEVEL = 2
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._validate_multi_condition_nesting(conditions=self.conditions)
@property
@abstractmethod
def conditions(self) -> List[AccessControlCondition]:
@ -121,11 +111,13 @@ class MultiConditionAccessControl(AccessControlCondition):
def _validate_multi_condition_nesting(
cls,
conditions: List[AccessControlCondition],
field_name: str,
current_level: int = 1,
):
if len(conditions) > cls.MAX_NUM_CONDITIONS:
raise InvalidCondition(
f"Maximum of {cls.MAX_NUM_CONDITIONS} conditions are allowed"
raise ValidationError(
field_name=field_name,
message=f"Maximum of {cls.MAX_NUM_CONDITIONS} conditions are allowed",
)
for condition in conditions:
@ -134,10 +126,31 @@ class MultiConditionAccessControl(AccessControlCondition):
level = current_level + 1
if level > cls.MAX_MULTI_CONDITION_NESTED_LEVEL:
raise InvalidCondition(
f"Only {cls.MAX_MULTI_CONDITION_NESTED_LEVEL} nested levels of multi-conditions are allowed"
raise ValidationError(
field_name=field_name,
message=f"Only {cls.MAX_MULTI_CONDITION_NESTED_LEVEL} nested levels of multi-conditions are allowed",
)
cls._validate_multi_condition_nesting(
conditions=condition.conditions,
field_name=field_name,
current_level=level,
)
class ExecutionCall(_Serializable, ABC):
class InvalidExecutionCall(ValueError):
pass
class Schema(CamelCaseSchema):
pass
def __init__(self):
# validate call using marshmallow schema before creating
errors = self.Schema().validate(data=self.to_dict())
if errors:
error_message = extract_single_error_message_from_schema_errors(errors)
raise self.InvalidExecutionCall(f"{error_message}")
@abstractmethod
def execute(self, *args, **kwargs) -> Any:
raise NotImplementedError

View File

@ -10,14 +10,22 @@ from typing import (
from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address
from marshmallow import ValidationError, fields, post_load, validate, validates_schema
from marshmallow import (
ValidationError,
fields,
post_load,
validate,
validates,
validates_schema,
)
from marshmallow.validate import OneOf
from typing_extensions import override
from web3 import HTTPProvider, Web3
from web3.contract.contract import ContractFunction
from web3.middleware import geth_poa_middleware
from web3.providers import BaseProvider
from web3.types import ABIFunction
from nucypher.policy.conditions import STANDARD_ABI_CONTRACT_TYPES, STANDARD_ABIS
from nucypher.policy.conditions import STANDARD_ABI_CONTRACT_TYPES
from nucypher.policy.conditions.base import (
ExecutionCall,
)
@ -26,7 +34,6 @@ from nucypher.policy.conditions.context import (
resolve_parameter_context_variables,
)
from nucypher.policy.conditions.exceptions import (
InvalidCondition,
NoConnectionToChain,
RequiredContextVariable,
RPCExecutionFailed,
@ -38,11 +45,10 @@ from nucypher.policy.conditions.lingo import (
)
from nucypher.policy.conditions.utils import camel_case_to_snake
from nucypher.policy.conditions.validation import (
_align_comparator_value_with_abi,
_get_abi_types,
_validate_contract_call_abi,
_validate_multiple_output_types,
_validate_single_output_type,
align_comparator_value_with_abi,
get_unbound_contract_function,
validate_contract_function_expected_return_type,
validate_function_abi,
)
# TODO: Move this to a more appropriate location,
@ -61,53 +67,6 @@ _CONDITION_CHAINS = {
}
def _resolve_abi(
w3: Web3,
method: str,
standard_contract_type: Optional[str] = None,
function_abi: Optional[ABIFunction] = None,
) -> ABIFunction:
"""Resolves the contract an/or function ABI from a standard contract name"""
if not (function_abi or standard_contract_type):
raise InvalidCondition(
f"Ambiguous ABI - Supply either an ABI or a standard contract type ({STANDARD_ABI_CONTRACT_TYPES})."
)
if standard_contract_type:
try:
# Lookup the standard ABI given it's ERC standard name (standard contract type)
contract_abi = STANDARD_ABIS[standard_contract_type]
except KeyError:
raise InvalidCondition(
f"Invalid standard contract type {standard_contract_type}; Must be one of {STANDARD_ABI_CONTRACT_TYPES}"
)
try:
# Extract all function ABIs from the contract's ABI.
# Will raise a ValueError if there is not exactly one match.
function_abi = (
w3.eth.contract(abi=contract_abi).get_function_by_name(method).abi
)
except ValueError as e:
raise InvalidCondition(str(e))
return ABIFunction(function_abi)
def _validate_chain(chain: int) -> None:
if not isinstance(chain, int):
raise ValueError(
f'The "chain" field of a condition must be the '
f'integer chain ID (got "{chain}").'
)
if chain not in _CONDITION_CHAINS:
raise InvalidCondition(
f"chain ID {chain} is not a permitted "
f"blockchain for condition evaluation."
)
class RPCCall(ExecutionCall):
LOG = logging.Logger(__name__)
@ -116,28 +75,47 @@ class RPCCall(ExecutionCall):
"eth_getBalance": int,
} # TODO other allowed methods (tDEC #64)
class Schema(ExecutionCall.Schema):
chain = fields.Int(required=True, strict=True)
method = fields.Str(
required=True,
error_messages={
"required": "Undefined method name",
"null": "Undefined method name",
},
)
parameters = fields.List(
fields.Field, attribute="parameters", required=False, allow_none=True
)
@validates("chain")
def validate_chain(self, value):
if value not in _CONDITION_CHAINS:
raise ValidationError(
f"chain ID {value} is not a permitted blockchain for condition evaluation"
)
@validates("method")
def validate_method(self, value):
if value not in RPCCall.ALLOWED_METHODS:
raise ValidationError(
f"'{value}' is not a permitted RPC endpoint for condition evaluation."
)
@post_load
def make(self, data, **kwargs):
return RPCCall(**data)
def __init__(
self,
chain: int,
method: str,
parameters: Optional[List[Any]] = None,
):
# Validate input
_validate_chain(chain=chain)
self.chain = chain
self.method = self._validate_method(method=method)
self.parameters = parameters or None
def _validate_method(self, method):
if not method:
raise ValueError("Undefined method name")
if method not in self.ALLOWED_METHODS:
raise ValueError(
f"'{method}' is not a permitted RPC endpoint for condition evaluation."
)
return method
self.method = method
self.parameters = parameters
super().__init__()
def _get_web3_py_function(self, w3: Web3, rpc_method: str):
web3_py_method = camel_case_to_snake(rpc_method)
@ -226,17 +204,30 @@ class RPCCall(ExecutionCall):
class RPCCondition(ExecutionCallAccessControlCondition):
EXECUTION_CALL_TYPE = RPCCall
CONDITION_TYPE = ConditionType.RPC.value
class Schema(ExecutionCallAccessControlCondition.Schema):
class Schema(ExecutionCallAccessControlCondition.Schema, RPCCall.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.RPC.value), required=True
)
chain = fields.Int(
required=True, strict=True, validate=validate.OneOf(_CONDITION_CHAINS)
)
method = fields.Str(required=True)
parameters = fields.List(fields.Field, attribute="parameters", required=False)
@validates_schema()
def validate_expected_return_type(self, data, **kwargs):
method = data.get("method")
return_value_test = data.get("return_value_test")
expected_return_type = RPCCall.ALLOWED_METHODS[method]
comparator_value = return_value_test.value
if is_context_variable(comparator_value):
return
if not isinstance(return_value_test.value, expected_return_type):
raise ValidationError(
field_name="return_value_test",
message=f"Return value comparison for '{method}' call output "
f"should be '{expected_return_type}' and not '{type(comparator_value)}'.",
)
@post_load
def make(self, data, **kwargs):
@ -248,16 +239,25 @@ class RPCCondition(ExecutionCallAccessControlCondition):
def __init__(
self,
condition_type: str = CONDITION_TYPE,
chain: int,
method: str,
return_value_test: ReturnValueTest,
condition_type: str = ConditionType.RPC.value,
name: Optional[str] = None,
parameters: Optional[List[Any]] = None,
*args,
**kwargs,
):
super().__init__(condition_type=condition_type, *args, **kwargs)
self._validate_expected_return_type()
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
return RPCCall(*args, **kwargs)
super().__init__(
chain=chain,
method=method,
return_value_test=return_value_test,
condition_type=condition_type,
name=name,
parameters=parameters,
*args,
**kwargs,
)
@property
def method(self):
@ -271,18 +271,6 @@ class RPCCondition(ExecutionCallAccessControlCondition):
def parameters(self):
return self.execution_call.parameters
def _validate_expected_return_type(self):
expected_return_type = RPCCall.ALLOWED_METHODS[self.method]
comparator_value = self.return_value_test.value
if is_context_variable(comparator_value):
return
if not isinstance(self.return_value_test.value, expected_return_type):
raise InvalidCondition(
f"Return value comparison for '{self.method}' call output "
f"should be '{expected_return_type}' and not '{type(comparator_value)}'."
)
def _align_comparator_value_with_abi(
self, return_value_test: ReturnValueTest
) -> ReturnValueTest:
@ -297,6 +285,7 @@ class RPCCondition(ExecutionCallAccessControlCondition):
return_value_test = self._align_comparator_value_with_abi(
resolved_return_value_test
)
result = self.execution_call.execute(providers=providers, **context)
eval_result = return_value_test.eval(result) # test
@ -304,6 +293,79 @@ class RPCCondition(ExecutionCallAccessControlCondition):
class ContractCall(RPCCall):
class Schema(RPCCall.Schema):
contract_address = fields.Str(required=True)
standard_contract_type = fields.Str(
required=False,
validate=OneOf(
STANDARD_ABI_CONTRACT_TYPES,
error="Invalid standard contract type: {input}",
),
allow_none=True,
)
function_abi = fields.Dict(required=False, allow_none=True)
@post_load
def make(self, data, **kwargs):
return ContractCall(**data)
@validates("contract_address")
def validate_contract_address(self, value):
try:
to_checksum_address(value)
except ValueError:
raise ValidationError(f"Invalid checksum address: '{value}'")
@override
@validates("method")
def validate_method(self, value):
return
@validates("function_abi")
def validate_abi(self, value):
# needs to be done before schema validation
if value:
try:
validate_function_abi(value)
except ValueError as e:
raise ValidationError(
field_name="function_abi", message=str(e)
) from e
@validates_schema
def validate_standard_contract_type_or_function_abi(self, data, **kwargs):
method = data.get("method")
standard_contract_type = data.get("standard_contract_type")
function_abi = data.get("function_abi")
# validate xor of standard contract type and function abi
if not (bool(standard_contract_type) ^ bool(function_abi)):
raise ValidationError(
field_name="standard_contract_type",
message=f"Provide a standard contract type or function ABI; got ({standard_contract_type}, {function_abi}).",
)
# validate function abi with method name (not available for field validation)
if function_abi:
try:
validate_function_abi(function_abi, method_name=method)
except ValueError as e:
raise ValidationError(
field_name="function_abi", message=str(e)
) from e
# validate contract
contract_address = to_checksum_address(data.get("contract_address"))
try:
get_unbound_contract_function(
contract_address=contract_address,
method=method,
standard_contract_type=standard_contract_type,
function_abi=function_abi,
)
except ValueError as e:
raise ValidationError(str(e)) from e
def __init__(
self,
method: str,
@ -313,13 +375,6 @@ class ContractCall(RPCCall):
*args,
**kwargs,
):
if not method:
raise ValueError("Undefined method name")
_validate_contract_call_abi(
standard_contract_type, function_abi, method_name=method
)
# preprocessing
contract_address = to_checksum_address(contract_address)
self.contract_address = contract_address
@ -327,30 +382,14 @@ class ContractCall(RPCCall):
self.function_abi = function_abi
super().__init__(method=method, *args, **kwargs)
self.contract_function = self._get_unbound_contract_function()
def _validate_method(self, method):
return method
def _get_unbound_contract_function(self) -> ContractFunction:
"""Gets an unbound contract function to evaluate for this condition"""
w3 = Web3()
function_abi = _resolve_abi(
w3=w3,
standard_contract_type=self.standard_contract_type,
# contract function already validated - so should not raise an exception
self.contract_function = get_unbound_contract_function(
contract_address=self.contract_address,
method=self.method,
standard_contract_type=self.standard_contract_type,
function_abi=self.function_abi,
)
try:
contract = w3.eth.contract(
address=self.contract_address, abi=[function_abi]
)
contract_function = getattr(contract.functions, self.method)
return contract_function
except Exception as e:
raise ValueError(
f"Unable to find contract function, '{self.method}', for condition: {e}"
)
def _execute(self, w3: Web3, resolved_parameters: List[Any]) -> Any:
"""Execute onchain read and return result."""
@ -363,42 +402,63 @@ class ContractCall(RPCCall):
class ContractCondition(RPCCondition):
EXECUTION_CALL_TYPE = ContractCall
CONDITION_TYPE = ConditionType.CONTRACT.value
class Schema(RPCCondition.Schema):
class Schema(RPCCondition.Schema, ContractCall.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.CONTRACT.value), required=True
)
contract_address = fields.Str(required=True)
standard_contract_type = fields.Str(required=False)
function_abi = fields.Dict(required=False)
@validates_schema()
def validate_expected_return_type(self, data, **kwargs):
# validate that contract function is correct
try:
contract_function = get_unbound_contract_function(
contract_address=data.get("contract_address"),
method=data.get("method"),
standard_contract_type=data.get("standard_contract_type"),
function_abi=data.get("function_abi"),
)
except ValueError as e:
raise ValidationError(str(e)) from e
# validate return type based on contract function
return_value_test = data.get("return_value_test")
try:
validate_contract_function_expected_return_type(
contract_function=contract_function,
return_value_test=return_value_test,
)
except ValueError as e:
raise ValidationError(
field_name="return_value_test",
message=str(e),
) from e
@post_load
def make(self, data, **kwargs):
return ContractCondition(**data)
@validates_schema
def check_standard_contract_type_or_function_abi(self, data, **kwargs):
standard_contract_type = data.get("standard_contract_type")
function_abi = data.get("function_abi")
try:
_validate_contract_call_abi(
standard_contract_type, function_abi, method_name=data.get("method")
)
except ValueError as e:
raise ValidationError(str(e))
def __init__(
self,
condition_type: str = CONDITION_TYPE,
method: str,
contract_address: ChecksumAddress,
condition_type: str = ConditionType.CONTRACT.value,
standard_contract_type: Optional[str] = None,
function_abi: Optional[ABIFunction] = None,
*args,
**kwargs,
):
# call to super must be at the end for proper validation
super().__init__(condition_type=condition_type, *args, **kwargs)
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
return ContractCall(*args, **kwargs)
super().__init__(
method=method,
condition_type=condition_type,
contract_address=contract_address,
standard_contract_type=standard_contract_type,
function_abi=function_abi,
*args,
**kwargs,
)
@property
def function_abi(self):
@ -416,12 +476,6 @@ class ContractCondition(RPCCondition):
def contract_address(self):
return self.execution_call.contract_address
def _validate_expected_return_type(self) -> None:
_validate_contract_function_expected_return_type(
contract_function=self.contract_function,
return_value_test=self.return_value_test,
)
def __repr__(self) -> str:
r = (
f"{self.__class__.__name__}(function={self.method}, "
@ -433,29 +487,7 @@ class ContractCondition(RPCCondition):
def _align_comparator_value_with_abi(
self, return_value_test: ReturnValueTest
) -> ReturnValueTest:
return _align_comparator_value_with_abi(
return align_comparator_value_with_abi(
abi=self.contract_function.contract_abi[0],
return_value_test=return_value_test,
)
def _validate_contract_function_expected_return_type(
contract_function: ContractFunction, return_value_test: ReturnValueTest
) -> None:
output_abi_types = _get_abi_types(contract_function.contract_abi[0])
comparator_value = return_value_test.value
comparator_index = return_value_test.index
index_string = f"@index={comparator_index}" if comparator_index is not None else ""
failure_message = (
f"Invalid return value comparison type '{type(comparator_value)}' for "
f"'{contract_function.fn_name}'{index_string} based on ABI types {output_abi_types}"
)
if len(output_abi_types) == 1:
_validate_single_output_type(
output_abi_types[0], comparator_value, comparator_index, failure_message
)
else:
_validate_multiple_output_types(
output_abi_types, comparator_value, comparator_index, failure_message
)

View File

@ -2,10 +2,9 @@ import ast
import base64
import json
import operator as pyoperator
from abc import abstractmethod
from enum import Enum
from hashlib import md5
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Type
from hexbytes import HexBytes
from marshmallow import (
@ -103,25 +102,29 @@ class CompoundAccessControlCondition(MultiConditionAccessControl):
def _validate_operator_and_operands(
cls,
operator: str,
operands: List[Union[Dict, AccessControlCondition]],
exception_class: Union[Type[ValidationError], Type[InvalidCondition]],
operands: List[AccessControlCondition],
):
if operator not in cls.OPERATORS:
raise exception_class(f"{operator} is not a valid operator")
raise ValidationError(
field_name="operator", message=f"{operator} is not a valid operator"
)
num_operands = len(operands)
if operator == cls.NOT_OPERATOR:
if num_operands != 1:
raise exception_class(
f"Only 1 operand permitted for '{operator}' compound condition"
raise ValidationError(
field_name="operands",
message=f"Only 1 operand permitted for '{operator}' compound condition",
)
elif num_operands < 2:
raise exception_class(
f"Minimum of 2 operand needed for '{operator}' compound condition"
raise ValidationError(
field_name="operands",
message=f"Minimum of 2 operand needed for '{operator}' compound condition",
)
elif num_operands > cls.MAX_NUM_CONDITIONS:
raise exception_class(
f"Maximum of {cls.MAX_NUM_CONDITIONS} operands allowed for '{operator}' compound condition"
raise ValidationError(
field_name="operands",
message="Maximum of {cls.MAX_NUM_CONDITIONS} operands allowed for '{operator}' compound condition",
)
@ -141,7 +144,10 @@ class CompoundAccessControlCondition(MultiConditionAccessControl):
operator = data["operator"]
operands = data["operands"]
CompoundAccessControlCondition._validate_operator_and_operands(
operator, operands, ValidationError
operator, operands
)
CompoundAccessControlCondition._validate_multi_condition_nesting(
conditions=operands, field_name="operands"
)
@post_load
@ -161,15 +167,15 @@ class CompoundAccessControlCondition(MultiConditionAccessControl):
"operands": [CONDITION*]
}
"""
self._validate_operator_and_operands(operator, operands, InvalidCondition)
self.operator = operator
self.operands = operands
self.condition_type = condition_type
self.name = name
self.id = md5(bytes(self)).hexdigest()[:6]
super().__init__(condition_type=condition_type, name=name)
super().__init__(
condition_type=condition_type,
name=name,
)
self.id = md5(bytes(self)).hexdigest()[:6]
def __repr__(self):
return f"Operator={self.operator} (NumOperands={len(self.operands)}), id={self.id})"
@ -265,18 +271,30 @@ class SequentialAccessControlCondition(MultiConditionAccessControl):
@classmethod
def _validate_condition_variables(
cls,
condition_variables: List[Union[Dict, ConditionVariable]],
exception_class: Union[Type[ValidationError], Type[InvalidCondition]],
condition_variables: List[ConditionVariable],
):
num_condition_variables = len(condition_variables)
if num_condition_variables < 2:
raise exception_class("At least two conditions must be specified")
if num_condition_variables > cls.MAX_NUM_CONDITIONS:
raise exception_class(
f"Maximum of {cls.MAX_NUM_CONDITIONS} conditions are allowed"
if not condition_variables or len(condition_variables) < 2:
raise ValidationError(
field_name="condition_variables",
message="At least two conditions must be specified",
)
if len(condition_variables) > cls.MAX_NUM_CONDITIONS:
raise ValidationError(
field_name="condition_variables",
message=f"Maximum of {cls.MAX_NUM_CONDITIONS} conditions are allowed",
)
# check for duplicate var names
var_names = set()
for condition_variable in condition_variables:
if condition_variable.var_name in var_names:
raise ValidationError(
field_name="condition_variables",
message=f"Duplicate variable names are not allowed - {condition_variable.var_name}",
)
var_names.add(condition_variable.var_name)
class Schema(AccessControlCondition.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.SEQUENTIAL.value), required=True
@ -289,11 +307,12 @@ class SequentialAccessControlCondition(MultiConditionAccessControl):
class Meta:
ordered = True
@validates_schema
def validate_condition_variables(self, data, **kwargs):
condition_variables = data["condition_variables"]
SequentialAccessControlCondition._validate_condition_variables(
condition_variables, ValidationError
@validates("condition_variables")
def validate_condition_variables(self, value):
SequentialAccessControlCondition._validate_condition_variables(value)
conditions = [cv.condition for cv in value]
SequentialAccessControlCondition._validate_multi_condition_nesting(
conditions=conditions, field_name="condition_variables"
)
@post_load
@ -306,11 +325,11 @@ class SequentialAccessControlCondition(MultiConditionAccessControl):
condition_type: str = CONDITION_TYPE,
name: Optional[str] = None,
):
self._validate_condition_variables(
condition_variables=condition_variables, exception_class=InvalidCondition
)
self.condition_variables = condition_variables
super().__init__(condition_type=condition_type, name=name)
super().__init__(
condition_type=condition_type,
name=name,
)
def __repr__(self):
r = f"{self.__class__.__name__}(num_condition_variables={len(self.condition_variables)})"
@ -358,7 +377,9 @@ class ReturnValueTest:
value = fields.Raw(
allow_none=False, required=True
) # any valid type (excludes None)
index = fields.Int(strict=True, required=False, validate=Range(min=0))
index = fields.Int(
strict=True, required=False, validate=Range(min=0), allow_none=True
)
@post_load
def make(self, data, **kwargs):
@ -574,6 +595,8 @@ class ExecutionCallAccessControlCondition(AccessControlCondition):
Conditions that utilize underlying ExecutionCall objects.
"""
EXECUTION_CALL_TYPE = NotImplemented
class Schema(AccessControlCondition.Schema):
return_value_test = fields.Nested(
ReturnValueTest.ReturnValueTestSchema(), required=True
@ -588,19 +611,14 @@ class ExecutionCallAccessControlCondition(AccessControlCondition):
**kwargs,
):
self.return_value_test = return_value_test
try:
self.execution_call = self._create_execution_call(*args, **kwargs)
except ValueError as e:
raise InvalidCondition(str(e))
self.execution_call = self.EXECUTION_CALL_TYPE(*args, **kwargs)
except ExecutionCall.InvalidExecutionCall as e:
raise InvalidCondition(str(e)) from e
super().__init__(condition_type=condition_type, name=name)
@abstractmethod
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
"""
Returns the execution call that the condition executes.
"""
raise NotImplementedError
def verify(self, *args, **kwargs) -> Tuple[bool, Any]:
"""

View File

@ -14,6 +14,7 @@ from nucypher.policy.conditions.exceptions import (
from nucypher.policy.conditions.lingo import (
ConditionType,
ExecutionCallAccessControlCondition,
ReturnValueTest,
)
from nucypher.utilities.logging import Logger
@ -37,6 +38,15 @@ class JSONPathField(Field):
class JsonApiCall(ExecutionCall):
TIMEOUT = 5 # seconds
class Schema(ExecutionCall.Schema):
endpoint = Url(required=True, relative=False, schemes=["https"])
parameters = fields.Dict(required=False, allow_none=True)
query = JSONPathField(required=False, allow_none=True)
@post_load
def make(self, data, **kwargs):
return JsonApiCall(**data)
def __init__(
self,
endpoint: str,
@ -50,6 +60,8 @@ class JsonApiCall(ExecutionCall):
self.timeout = self.TIMEOUT
self.logger = Logger(__name__)
super().__init__()
def execute(self, *args, **kwargs) -> Any:
response = self._fetch()
data = self._deserialize_response(response)
@ -129,15 +141,13 @@ class JsonApiCondition(ExecutionCallAccessControlCondition):
The response will be deserialized as JSON and parsed using jsonpath.
"""
EXECUTION_CALL_TYPE = JsonApiCall
CONDITION_TYPE = ConditionType.JSONAPI.value
class Schema(ExecutionCallAccessControlCondition.Schema):
class Schema(ExecutionCallAccessControlCondition.Schema, JsonApiCall.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.JSONAPI.value), required=True
)
endpoint = Url(required=True, relative=False, schemes=["https"])
parameters = fields.Dict(required=False, allow_none=True)
query = JSONPathField(required=False, allow_none=True)
@post_load
def make(self, data, **kwargs):
@ -145,14 +155,21 @@ class JsonApiCondition(ExecutionCallAccessControlCondition):
def __init__(
self,
endpoint: str,
return_value_test: ReturnValueTest,
query: Optional[str] = None,
parameters: Optional[dict] = None,
condition_type: str = ConditionType.JSONAPI.value,
*args,
**kwargs,
name: Optional[str] = None,
):
super().__init__(condition_type=condition_type, *args, **kwargs)
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
return JsonApiCall(*args, **kwargs)
super().__init__(
endpoint=endpoint,
return_value_test=return_value_test,
query=query,
parameters=parameters,
condition_type=condition_type,
name=name,
)
@property
def endpoint(self):

View File

@ -1,36 +1,51 @@
from typing import Any, List, Optional
from marshmallow import fields, post_load, validate
from marshmallow.validate import Equal
from marshmallow import (
ValidationError,
fields,
post_load,
validate,
validates,
validates_schema,
)
from typing_extensions import override
from web3 import Web3
from nucypher.policy.conditions.base import ExecutionCall
from nucypher.policy.conditions.evm import RPCCall, RPCCondition
from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.lingo import ConditionType
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
class TimeRPCCall(RPCCall):
METHOD = "blocktime"
class Schema(RPCCall.Schema):
method = fields.Str(dump_default="blocktime", required=True)
@override
@validates("method")
def validate_method(self, value):
if value != TimeRPCCall.METHOD:
raise ValidationError(f"method name must be {TimeRPCCall.METHOD}.")
@validates("parameters")
def validate_no_parameters(self, value):
if value:
raise ValidationError(
f"{TimeRPCCall.METHOD}' does not take any parameters"
)
@post_load
def make(self, data, **kwargs):
return TimeRPCCall(**data)
def __init__(
self,
chain: int,
method: str = METHOD,
parameters: Optional[List[Any]] = None,
):
if parameters:
raise ValueError(f"{self.METHOD} does not take any parameters")
super().__init__(chain=chain, method=method, parameters=parameters)
def _validate_method(self, method):
if method != self.METHOD:
raise ValueError(
f"{self.__class__.__name__} must be instantiated with the {self.METHOD} method."
)
return method
def _execute(self, w3: Web3, resolved_parameters: List[Any]) -> Any:
"""Execute onchain read and return result."""
# TODO may need to rethink as part of #3051 (multicall work).
@ -39,15 +54,23 @@ class TimeRPCCall(RPCCall):
class TimeCondition(RPCCondition):
EXECUTION_CALL_TYPE = TimeRPCCall
CONDITION_TYPE = ConditionType.TIME.value
class Schema(RPCCondition.Schema):
class Schema(RPCCondition.Schema, TimeRPCCall.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.TIME.value), required=True
)
method = fields.Str(
dump_default="blocktime", required=True, validate=Equal("blocktime")
)
@validates_schema
def validate_expected_return_type(self, data, **kwargs):
return_value_test = data.get("return_value_test")
comparator_value = return_value_test.value
if not isinstance(comparator_value, int):
raise ValidationError(
field_name="return_value_test",
message=f"Invalid return value comparison type '{type(comparator_value)}'; must be an integer",
)
@post_load
def make(self, data, **kwargs):
@ -59,29 +82,21 @@ class TimeCondition(RPCCondition):
def __init__(
self,
return_value_test: ReturnValueTest,
chain: int,
method: str = TimeRPCCall.METHOD,
condition_type: str = CONDITION_TYPE,
*args,
**kwargs,
condition_type: str = ConditionType.TIME.value,
name: Optional[str] = None,
):
# call to super must be at the end for proper validation
super().__init__(
condition_type=condition_type,
return_value_test=return_value_test,
chain=chain,
method=method,
*args,
**kwargs,
condition_type=condition_type,
name=name,
)
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
return TimeRPCCall(*args, **kwargs)
def _validate_expected_return_type(self):
comparator_value = self.return_value_test.value
if not isinstance(comparator_value, int):
raise InvalidCondition(
f"Invalid return value comparison type '{type(comparator_value)}'; must be an integer"
)
@property
def timestamp(self):
return self.return_value_test.value

View File

@ -1,8 +1,9 @@
import re
from http import HTTPStatus
from typing import Dict, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple
from marshmallow import Schema, post_dump
from marshmallow.exceptions import SCHEMA
from web3.providers import BaseProvider
from nucypher.policy.conditions.exceptions import (
@ -138,3 +139,33 @@ def evaluate_condition_lingo(
if error:
log.info(error.message) # log error message
raise error
def extract_single_error_message_from_schema_errors(
errors: Dict[str, List[str]]
) -> str:
"""
Extract single error message from Schema().validate() errors result.
The result is only for a single error type, and only the first message string for that type.
If there are multiple error types, only one error type is used; the first field-specific (@validates)
error type encountered is prioritized over any schema-level-specific (@validates_schema) error.
"""
if not errors:
raise ValueError("Validation errors must be provided")
# extract error type - either field-specific (preferred) or schema-specific
error_key_to_use = None
for error_type in list(errors.keys()):
error_key_to_use = error_type
if error_key_to_use != SCHEMA:
# actual field
break
message = errors[error_key_to_use][0]
message_prefix = (
f"'{camel_case_to_snake(error_key_to_use)}' field - "
if error_key_to_use != SCHEMA
else ""
)
return f"{message_prefix}{message}"

View File

@ -7,42 +7,24 @@ from typing import (
cast,
)
from eth_typing import ChecksumAddress
from web3 import Web3
from web3.auto import w3
from web3.contract.contract import ContractFunction
from web3.types import ABIFunction
from nucypher.policy.conditions import STANDARD_ABI_CONTRACT_TYPES, STANDARD_ABIS
from nucypher.policy.conditions.context import is_context_variable
from nucypher.policy.conditions.exceptions import (
InvalidCondition,
)
from nucypher.policy.conditions.lingo import ReturnValueTest
def _validate_single_output_type(
expected_type: str,
comparator_value: Any,
comparator_index: Optional[int],
failure_message: str,
) -> None:
if comparator_index is not None and _is_tuple_type(expected_type):
type_entries = _get_tuple_type_entries(expected_type)
expected_type = type_entries[comparator_index]
_validate_value_type(expected_type, comparator_value, failure_message)
#
# Schema logic
#
def _get_abi_types(abi: ABIFunction) -> List[str]:
return [_collapse_if_tuple(cast(Dict[str, Any], arg)) for arg in abi["outputs"]]
def _validate_value_type(
expected_type: str, comparator_value: Any, failure_message: str
) -> None:
if is_context_variable(comparator_value):
# context variable types cannot be known until execution time.
return
if not w3.is_encodable(expected_type, comparator_value):
raise InvalidCondition(failure_message)
def _collapse_if_tuple(abi: Dict[str, Any]) -> str:
abi_type = abi["type"]
if not abi_type.startswith("tuple"):
@ -66,6 +48,28 @@ def _get_tuple_type_entries(tuple_type: str) -> List[str]:
return result
def _validate_value_type(
expected_type: str, comparator_value: Any, failure_message: str
) -> None:
if is_context_variable(comparator_value):
# context variable types cannot be known until execution time.
return
if not w3.is_encodable(expected_type, comparator_value):
raise ValueError(failure_message)
def _validate_single_output_type(
expected_type: str,
comparator_value: Any,
comparator_index: Optional[int],
failure_message: str,
) -> None:
if comparator_index is not None and _is_tuple_type(expected_type):
type_entries = _get_tuple_type_entries(expected_type)
expected_type = type_entries[comparator_index]
_validate_value_type(expected_type, comparator_value, failure_message)
def _validate_multiple_output_types(
output_abi_types: List[str],
comparator_value: Any,
@ -82,15 +86,46 @@ def _validate_multiple_output_types(
return
if not isinstance(comparator_value, Sequence):
raise InvalidCondition(failure_message)
raise ValueError(failure_message)
if len(output_abi_types) != len(comparator_value):
raise InvalidCondition(failure_message)
raise ValueError(failure_message)
for output_abi_type, component_value in zip(output_abi_types, comparator_value):
_validate_value_type(output_abi_type, component_value, failure_message)
def _resolve_abi(
w3: Web3,
method: str,
standard_contract_type: Optional[str] = None,
function_abi: Optional[ABIFunction] = None,
) -> ABIFunction:
"""Resolves the contract an/or function ABI from a standard contract name"""
if not (function_abi or standard_contract_type):
raise ValueError(
f"Ambiguous ABI - Supply either an ABI or a standard contract type ({STANDARD_ABI_CONTRACT_TYPES})."
)
if standard_contract_type:
try:
# Lookup the standard ABI given it's ERC standard name (standard contract type)
contract_abi = STANDARD_ABIS[standard_contract_type]
except KeyError:
raise ValueError(
f"Invalid standard contract type {standard_contract_type}; Must be one of {STANDARD_ABI_CONTRACT_TYPES}"
)
# Extract all function ABIs from the contract's ABI.
# Will raise a ValueError if there is not exactly one match.
function_abi = (
w3.eth.contract(abi=contract_abi).get_function_by_name(method).abi
)
return ABIFunction(function_abi)
def _align_comparator_value_single_output(
expected_type: str, comparator_value: Any, comparator_index: Optional[int]
) -> Any:
@ -99,7 +134,7 @@ def _align_comparator_value_single_output(
expected_type = type_entries[comparator_index]
if not w3.is_encodable(expected_type, comparator_value):
raise InvalidCondition(
raise ValueError(
f"Mismatched comparator type ({comparator_value} as {expected_type})"
)
return comparator_value
@ -112,7 +147,7 @@ def _align_comparator_value_multiple_output(
expected_type = output_abi_types[comparator_index]
# ensure alignment
if not w3.is_encodable(expected_type, comparator_value):
raise InvalidCondition(
raise ValueError(
f"Mismatched comparator type ({comparator_value} as {expected_type})"
)
@ -122,15 +157,20 @@ def _align_comparator_value_multiple_output(
for output_abi_type, component_value in zip(output_abi_types, comparator_value):
# ensure alignment
if not w3.is_encodable(output_abi_type, component_value):
raise InvalidCondition(
raise ValueError(
f"Mismatched comparator type ({component_value} as {output_abi_type})"
)
values.append(component_value)
return values
def _align_comparator_value_with_abi(
abi, return_value_test: ReturnValueTest
#
# Public functions.
#
def align_comparator_value_with_abi(
abi: ABIFunction, return_value_test: ReturnValueTest
) -> ReturnValueTest:
output_abi_types = _get_abi_types(abi)
comparator = return_value_test.comparator
@ -155,13 +195,19 @@ def _align_comparator_value_with_abi(
)
def _validate_function_abi(function_abi: Dict, method_name: str) -> None:
"""validates a dictionary as valid for use as a condition function ABI"""
def validate_function_abi(
function_abi: Dict, method_name: Optional[str] = None
) -> None:
"""
Validates a dictionary as valid for use as a condition function ABI.
Optionally validates the method_name
"""
abi = ABIFunction(function_abi)
if not abi.get("name"):
raise ValueError(f"Invalid ABI, no function name found {abi}")
if abi.get("name") != method_name:
if method_name and abi.get("name") != method_name:
raise ValueError(f"Mismatched ABI for contract function {method_name} - {abi}")
if abi.get("type") != "function":
raise ValueError(f"Invalid ABI type {abi}")
@ -171,14 +217,47 @@ def _validate_function_abi(function_abi: Dict, method_name: str) -> None:
raise ValueError(f"Invalid ABI stateMutability {abi}")
def _validate_contract_call_abi(
standard_contract_type: str,
function_abi: Dict,
method_name: str,
) -> None:
if not (bool(standard_contract_type) ^ bool(function_abi)):
def get_unbound_contract_function(
contract_address: ChecksumAddress,
method: str,
standard_contract_type: Optional[str] = None,
function_abi: Optional[ABIFunction] = None,
) -> ContractFunction:
"""Gets an unbound contract function to evaluate"""
w3 = Web3()
function_abi = _resolve_abi(
w3=w3,
standard_contract_type=standard_contract_type,
method=method,
function_abi=function_abi,
)
try:
contract = w3.eth.contract(address=contract_address, abi=[function_abi])
contract_function = getattr(contract.functions, method)
return contract_function
except Exception as e:
raise ValueError(
f"Provide 'standardContractType' or 'functionAbi'; got ({standard_contract_type}, {function_abi})."
f"Unable to find contract function, '{method}', for condition: {e}"
) from e
def validate_contract_function_expected_return_type(
contract_function: ContractFunction, return_value_test: ReturnValueTest
) -> None:
output_abi_types = _get_abi_types(contract_function.contract_abi[0])
comparator_value = return_value_test.value
comparator_index = return_value_test.index
index_string = f"@index={comparator_index}" if comparator_index is not None else ""
failure_message = (
f"Invalid return value comparison type '{type(comparator_value)}' for "
f"'{contract_function.fn_name}'{index_string} based on ABI types {output_abi_types}"
)
if len(output_abi_types) == 1:
_validate_single_output_type(
output_abi_types[0], comparator_value, comparator_index, failure_message
)
else:
_validate_multiple_output_types(
output_abi_types, comparator_value, comparator_index, failure_message
)
if function_abi:
_validate_function_abi(function_abi, method_name=method_name)

View File

@ -4,7 +4,10 @@ from unittest.mock import Mock
import pytest
from nucypher.policy.conditions.base import AccessControlCondition
from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.exceptions import (
InvalidCondition,
InvalidConditionLingo,
)
from nucypher.policy.conditions.lingo import (
AndCompoundCondition,
CompoundAccessControlCondition,
@ -121,39 +124,39 @@ def test_compound_condition_schema_validation(operator, time_condition, rpc_cond
compound_condition_dict = compound_condition.to_dict()
# no issues here
CompoundAccessControlCondition.validate(compound_condition_dict)
CompoundAccessControlCondition.from_dict(compound_condition_dict)
# no issues with optional name
compound_condition_dict["name"] = "my_contract_condition"
CompoundAccessControlCondition.validate(compound_condition_dict)
CompoundAccessControlCondition.from_dict(compound_condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# incorrect condition type
compound_condition_dict = compound_condition.to_dict()
compound_condition_dict["condition_type"] = ConditionType.RPC.value
CompoundAccessControlCondition.validate(compound_condition_dict)
CompoundAccessControlCondition.from_dict(compound_condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# invalid operator
compound_condition_dict = compound_condition.to_dict()
compound_condition_dict["operator"] = "5True"
CompoundAccessControlCondition.validate(compound_condition_dict)
CompoundAccessControlCondition.from_dict(compound_condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# no operator
compound_condition_dict = compound_condition.to_dict()
del compound_condition_dict["operator"]
CompoundAccessControlCondition.validate(compound_condition_dict)
CompoundAccessControlCondition.from_dict(compound_condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# no operands
compound_condition_dict = compound_condition.to_dict()
del compound_condition_dict["operands"]
CompoundAccessControlCondition.validate(compound_condition_dict)
CompoundAccessControlCondition.from_dict(compound_condition_dict)
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_and_condition_and_short_circuit(mocker, mock_conditions):
def test_and_condition_and_short_circuit(mock_conditions):
condition_1, condition_2, condition_3, condition_4 = mock_conditions
and_condition = AndCompoundCondition(
@ -286,10 +289,9 @@ def test_compound_condition(mock_conditions):
] # or-condition short-circuited because condition_1 was True
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_nested_compound_condition_too_many_nested_levels(mock_conditions):
condition_1, condition_2, condition_3, condition_4 = mock_conditions
def test_nested_compound_condition_too_many_nested_levels(
rpc_condition, time_condition
):
with pytest.raises(
InvalidCondition, match="nested levels of multi-conditions are allowed"
):
@ -297,24 +299,23 @@ def test_nested_compound_condition_too_many_nested_levels(mock_conditions):
operands=[
OrCompoundCondition(
operands=[
condition_1,
rpc_condition,
AndCompoundCondition(
operands=[
condition_2,
condition_3,
time_condition,
rpc_condition,
]
),
]
),
condition_4,
time_condition,
]
)
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_nested_sequential_condition_too_many_nested_levels(mock_conditions):
condition_1, condition_2, condition_3, condition_4 = mock_conditions
def test_nested_sequential_condition_too_many_nested_levels(
rpc_condition, time_condition
):
with pytest.raises(
InvalidCondition, match="nested levels of multi-conditions are allowed"
):
@ -322,16 +323,16 @@ def test_nested_sequential_condition_too_many_nested_levels(mock_conditions):
operands=[
OrCompoundCondition(
operands=[
condition_1,
rpc_condition,
SequentialAccessControlCondition(
condition_variables=[
ConditionVariable("var2", condition_2),
ConditionVariable("var3", condition_3),
ConditionVariable("var2", time_condition),
ConditionVariable("var3", rpc_condition),
]
),
]
),
condition_4,
time_condition,
]
)

View File

@ -55,6 +55,8 @@ class FakeExecutionContractCondition(ContractCondition):
def execute(self, providers: Dict, **context) -> Any:
return self.execution_return_value
EXECUTION_CALL_TYPE = FakeRPCCall
class Schema(ContractCondition.Schema):
@post_load
def make(self, data, **kwargs):
@ -63,9 +65,6 @@ class FakeExecutionContractCondition(ContractCondition):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _create_execution_call(self, *args, **kwargs) -> ContractCall:
return self.FakeRPCCall(*args, **kwargs)
def set_execution_return_value(self, value: Any):
self.execution_call.set_execution_return_value(value)
@ -142,7 +141,7 @@ def test_invalid_contract_condition():
# invalid condition type
with pytest.raises(
InvalidCondition,
match=f"must be instantiated with the {ConditionType.CONTRACT.value} type",
match=f"'condition_type' field - Must be equal to {ConditionType.CONTRACT.value}",
):
_ = ContractCondition(
condition_type=ConditionType.RPC.value,
@ -167,7 +166,7 @@ def test_invalid_contract_condition():
# no abi or contract type
with pytest.raises(
InvalidCondition, match="Provide 'standardContractType' or 'functionAbi'"
InvalidCondition, match="Provide a standard contract type or function ABI"
):
_ = ContractCondition(
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
@ -206,7 +205,9 @@ def test_invalid_contract_condition():
)
# method not in ABI
with pytest.raises(InvalidCondition):
with pytest.raises(
InvalidCondition, match="Could not find any function with matching name"
):
_ = ContractCondition(
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
method="getPolicy",
@ -220,14 +221,55 @@ def test_invalid_contract_condition():
# standard contract type and function ABI
with pytest.raises(
InvalidCondition, match="Provide 'standardContractType' or 'functionAbi'"
InvalidCondition, match="Provide a standard contract type or function ABI"
):
_ = ContractCondition(
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
method="balanceOf",
chain=TESTERCHAIN_CHAIN_ID,
standard_contract_type="ERC20",
function_abi={"rando": "ABI"},
function_abi={
"inputs": [
{"internalType": "bytes16", "name": "_policyID", "type": "bytes16"}
],
"name": "getPolicy",
"outputs": [
{
"components": [
{
"internalType": "address payable",
"name": "sponsor",
"type": "address",
},
{
"internalType": "uint32",
"name": "startTimestamp",
"type": "uint32",
},
{
"internalType": "uint32",
"name": "endTimestamp",
"type": "uint32",
},
{
"internalType": "uint16",
"name": "size",
"type": "uint16",
},
{
"internalType": "address",
"name": "owner",
"type": "address",
},
],
"internalType": "struct SubscriptionManager.Policy",
"name": "",
"type": "tuple",
}
],
"stateMutability": "view",
"type": "function",
},
return_value_test=ReturnValueTest("!=", 0),
parameters=[
":hrac",
@ -250,17 +292,17 @@ def test_contract_condition_schema_validation():
condition_dict = contract_condition.to_dict()
# no issues here
ContractCondition.validate(condition_dict)
ContractCondition.from_dict(condition_dict)
# no issues with optional name
condition_dict["name"] = "my_contract_condition"
ContractCondition.validate(condition_dict)
ContractCondition.from_dict(condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# no contract address defined
condition_dict = contract_condition.to_dict()
del condition_dict["contractAddress"]
ContractCondition.validate(condition_dict)
ContractCondition.from_dict(condition_dict)
balanceOf_abi = {
"constant": True,
@ -272,29 +314,29 @@ def test_contract_condition_schema_validation():
"type": "function",
}
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# no function abi or standard contract type
condition_dict = contract_condition.to_dict()
del condition_dict["standardContractType"]
ContractCondition.validate(condition_dict)
ContractCondition.from_dict(condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# provide both function abi and standard contract type
condition_dict = contract_condition.to_dict()
condition_dict["functionAbi"] = balanceOf_abi
ContractCondition.validate(condition_dict)
ContractCondition.from_dict(condition_dict)
# remove standardContractType but specify function abi; no issues with that
condition_dict = contract_condition.to_dict()
del condition_dict["standardContractType"]
condition_dict["functionAbi"] = balanceOf_abi
ContractCondition.validate(condition_dict)
ContractCondition.from_dict(condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# no returnValueTest defined
condition_dict = contract_condition.to_dict()
del condition_dict["returnValueTest"]
ContractCondition.validate(condition_dict)
ContractCondition.from_dict(condition_dict)
def test_contract_condition_repr(contract_condition_dict):
@ -389,7 +431,9 @@ def test_abi_bool_output(contract_condition_dict):
assert isinstance(contract_condition.return_value_test.value, bool)
# invalid type fails
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
with pytest.raises(
InvalidConditionLingo, match="Invalid return value comparison type"
):
contract_condition_dict["returnValueTest"]["value"] = 23
ContractCondition.from_json(json.dumps(contract_condition_dict))
@ -419,7 +463,8 @@ def test_abi_bool_output(contract_condition_dict):
)
# test where context var has invalid expected type(s), so only detected at decryption time
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
# consequently this is not an invalid condition, but rather an incorrect context value
with pytest.raises(ValueError, match="Mismatched comparator type"):
_check_execution_logic(
condition_dict=contract_condition_dict,
execution_result=True,
@ -438,7 +483,9 @@ def test_abi_uint_output(contract_condition_dict):
assert isinstance(contract_condition.return_value_test.value, int)
# invalid type fails
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
with pytest.raises(
InvalidConditionLingo, match="Invalid return value comparison type"
):
contract_condition_dict["returnValueTest"]["value"] = True
ContractCondition.from_json(json.dumps(contract_condition_dict))
@ -468,7 +515,8 @@ def test_abi_uint_output(contract_condition_dict):
)
# test where context var has invalid expected type(s), so only detected at decryption time
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
# consequently this is not an invalid condition, but rather an incorrect context value
with pytest.raises(ValueError, match="Mismatched comparator type"):
_check_execution_logic(
condition_dict=contract_condition_dict,
execution_result=123456789,
@ -487,7 +535,9 @@ def test_abi_int_output(contract_condition_dict):
assert isinstance(contract_condition.return_value_test.value, int)
# invalid type fails
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
with pytest.raises(
InvalidConditionLingo, match="Invalid return value comparison type"
):
contract_condition_dict["returnValueTest"]["value"] = [1, 2, 3]
ContractCondition.from_json(json.dumps(contract_condition_dict))
@ -517,7 +567,8 @@ def test_abi_int_output(contract_condition_dict):
)
# test where context var has invalid expected type(s), so only detected at decryption time
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
# consequently this is not an invalid condition, but rather an incorrect context value
with pytest.raises(ValueError, match="Mismatched comparator type"):
_check_execution_logic(
condition_dict=contract_condition_dict,
execution_result=-123456789,
@ -538,7 +589,9 @@ def test_abi_address_output(contract_condition_dict, get_random_checksum_address
assert isinstance(contract_condition.return_value_test.value, str)
# invalid type fails
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
with pytest.raises(
InvalidConditionLingo, match="Invalid return value comparison type"
):
contract_condition_dict["returnValueTest"]["value"] = 1.25
ContractCondition.from_json(json.dumps(contract_condition_dict))
@ -569,7 +622,8 @@ def test_abi_address_output(contract_condition_dict, get_random_checksum_address
)
# test where context var has invalid expected type(s), so only detected at decryption time
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
# consequently this is not an invalid condition, but rather an incorrect context value
with pytest.raises(ValueError, match="Mismatched comparator type"):
_check_execution_logic(
condition_dict=contract_condition_dict,
execution_result=checksum_address,
@ -607,7 +661,9 @@ def test_abi_bytes_output(bytes_test_scenario, contract_condition_dict):
assert isinstance(contract_condition.return_value_test.value, str)
# invalid type fails
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
with pytest.raises(
InvalidConditionLingo, match="Invalid return value comparison type"
):
contract_condition_dict["returnValueTest"]["value"] = 1.25
ContractCondition.from_json(json.dumps(contract_condition_dict))
@ -637,7 +693,8 @@ def test_abi_bytes_output(bytes_test_scenario, contract_condition_dict):
)
# test where context var has invalid expected type(s), so only detected at decryption time
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
# consequently this is not an invalid condition, but rather an incorrect context value
with pytest.raises(ValueError, match="Mismatched comparator type"):
_check_execution_logic(
condition_dict=contract_condition_dict,
execution_result=call_result_in_bytes,
@ -667,27 +724,37 @@ def test_abi_tuple_output(contract_condition_dict):
assert isinstance(contract_condition.return_value_test.value, Sequence)
# 1. invalid type
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
with pytest.raises(
InvalidConditionLingo, match="Invalid return value comparison type"
):
contract_condition_dict["returnValueTest"]["value"] = 1
ContractCondition.from_json(json.dumps(contract_condition_dict))
# 2. invalid number of values
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
with pytest.raises(
InvalidConditionLingo, match="Invalid return value comparison type"
):
contract_condition_dict["returnValueTest"]["value"] = [1, 2]
ContractCondition.from_json(json.dumps(contract_condition_dict))
# 3a. Unmatched type
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
with pytest.raises(
InvalidConditionLingo, match="Invalid return value comparison type"
):
contract_condition_dict["returnValueTest"]["value"] = [True, 2, 3]
ContractCondition.from_json(json.dumps(contract_condition_dict))
# 3b. Unmatched type
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
with pytest.raises(
InvalidConditionLingo, match="Invalid return value comparison type"
):
contract_condition_dict["returnValueTest"]["value"] = [1, False, 3]
ContractCondition.from_json(json.dumps(contract_condition_dict))
# 3c. Unmatched type
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
with pytest.raises(
InvalidConditionLingo, match="Invalid return value comparison type"
):
contract_condition_dict["returnValueTest"]["value"] = [1, 2, 3.14159]
ContractCondition.from_json(json.dumps(contract_condition_dict))
@ -725,7 +792,8 @@ def test_abi_tuple_output(contract_condition_dict):
)
# test where context var has invalid expected type(s) - boolean is unexpected in index 1
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
# consequently this is not an invalid condition, but rather an incorrect context value
with pytest.raises(ValueError, match="Mismatched comparator type"):
_check_execution_logic(
condition_dict=contract_condition_dict,
execution_result=(1, 2, 3, random_bytes),
@ -794,7 +862,9 @@ def test_abi_tuple_output_with_index(
assert isinstance(contract_condition.return_value_test.value, str)
# invalid type at index
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
with pytest.raises(
InvalidConditionLingo, match="Invalid return value comparison type"
):
contract_condition_dict["returnValueTest"]["index"] = 0
contract_condition_dict["returnValueTest"][
"value"
@ -827,7 +897,8 @@ def test_abi_tuple_output_with_index(
)
# using index, test where context var has invalid expected type - unexpected type in index 2
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
# consequently this is not an invalid condition, but rather an incorrect context value
with pytest.raises(ValueError, match="Mismatched comparator type"):
_check_execution_logic(
condition_dict=contract_condition_dict,
execution_result=tuple(result),
@ -963,7 +1034,8 @@ def test_abi_multiple_output_values(
)
# test where context var has invalid expected type
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
# consequently this is not an invalid condition, but rather an incorrect context value
with pytest.raises(ValueError, match="Mismatched comparator type"):
_check_execution_logic(
condition_dict=contract_condition_dict,
execution_result=tuple(result),
@ -998,8 +1070,9 @@ def test_abi_multiple_output_values(
)
# test where context var has invalid expected type
# consequently this is not an invalid condition, but rather an incorrect context value
comparator_value[0][0] = True # should be address but setting to bool
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
with pytest.raises(ValueError, match="Mismatched comparator type"):
_check_execution_logic(
condition_dict=contract_condition_dict,
execution_result=tuple(result),
@ -1092,7 +1165,9 @@ def test_abi_nested_tuples_output_values(
[1],
get_random_checksum_address(), # missing tuple value
]
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
with pytest.raises(
InvalidConditionLingo, match="Invalid return value comparison type"
):
ContractCondition.from_json(json.dumps(contract_condition_dict))
contract_condition_dict["returnValueTest"]["value"] = [
@ -1101,7 +1176,9 @@ def test_abi_nested_tuples_output_values(
random_bytes_hex,
get_random_checksum_address(), # incorrect tuple value for Timeframe
]
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
with pytest.raises(
InvalidConditionLingo, match="Invalid return value comparison type"
):
ContractCondition.from_json(json.dumps(contract_condition_dict))
contract_condition_dict["returnValueTest"]["value"] = [
@ -1109,7 +1186,9 @@ def test_abi_nested_tuples_output_values(
[1, random_bytes_hex, 3],
get_random_checksum_address(), # too many values
]
with pytest.raises(InvalidCondition, match="Invalid return value comparison type"):
with pytest.raises(
InvalidConditionLingo, match="Invalid return value comparison type"
):
ContractCondition.from_json(json.dumps(contract_condition_dict))
# process index 1 (bool)
@ -1154,7 +1233,8 @@ def test_abi_nested_tuples_output_values(
)
# test where context var has invalid expected type
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
# consequently this is not an invalid condition, but rather an incorrect context value
with pytest.raises(ValueError, match="Mismatched comparator type"):
_check_execution_logic(
condition_dict=contract_condition_dict,
execution_result=tuple(result),
@ -1188,8 +1268,9 @@ def test_abi_nested_tuples_output_values(
)
# test where context var has invalid expected type
# consequently this is not an invalid condition, but rather an incorrect context value
comparator_value[0][2] = 1.25 # should be an address
with pytest.raises(InvalidCondition, match="Mismatched comparator type"):
with pytest.raises(ValueError, match="Mismatched comparator type"):
_check_execution_logic(
condition_dict=contract_condition_dict,
execution_result=tuple(result),

View File

@ -25,11 +25,11 @@ def test_jsonpath_field_valid():
def test_jsonpath_field_invalid():
field = JSONPathField()
invalid_jsonpath = "invalid jsonpath"
with pytest.raises(ValidationError) as excinfo:
with pytest.raises(
ValidationError,
match=f"'{invalid_jsonpath}' is not a valid JSONPath expression",
):
field.deserialize(invalid_jsonpath)
assert f"'{invalid_jsonpath}' is not a valid JSONPath expression" in str(
excinfo.value
)
def test_json_api_condition_initialization():
@ -44,24 +44,23 @@ def test_json_api_condition_initialization():
def test_json_api_condition_invalid_type():
with pytest.raises(InvalidCondition) as excinfo:
JsonApiCondition(
with pytest.raises(
InvalidCondition, match="'condition_type' field - Must be equal to json-api"
):
_ = JsonApiCondition(
endpoint="https://api.example.com/data",
query="$.store.book[0].price",
return_value_test=ReturnValueTest("==", 0),
condition_type="INVALID_TYPE",
)
assert "must be instantiated with the json-api type" in str(excinfo.value)
def test_https_enforcement():
with pytest.raises(InvalidCondition) as excinfo:
JsonApiCondition(
with pytest.raises(InvalidCondition, match="Not a valid URL"):
_ = JsonApiCondition(
endpoint="http://api.example.com/data",
query="$.store.book[0].price",
return_value_test=ReturnValueTest("==", 0),
)
assert "Not a valid URL" in str(excinfo.value)
def test_json_api_condition_with_primitive_response(mocker):
@ -102,9 +101,8 @@ def test_json_api_condition_fetch_failure(mocker):
query="$.store.book[0].price",
return_value_test=ReturnValueTest("==", 1),
)
with pytest.raises(InvalidCondition) as excinfo:
with pytest.raises(InvalidCondition, match="Failed to fetch endpoint"):
condition.execution_call._fetch()
assert "Failed to fetch endpoint" in str(excinfo.value)
def test_json_api_condition_verify(mocker):
@ -160,9 +158,10 @@ def test_json_api_condition_verify_invalid_json(mocker):
query="$.store.book[0].price",
return_value_test=ReturnValueTest("==", 2),
)
with pytest.raises(ConditionEvaluationFailed) as excinfo:
with pytest.raises(
ConditionEvaluationFailed, match="Failed to parse JSON response"
):
condition.verify()
assert "Failed to parse JSON response" in str(excinfo.value)
def test_non_json_response(mocker):
@ -180,11 +179,11 @@ def test_non_json_response(mocker):
return_value_test=ReturnValueTest("==", 18),
)
with pytest.raises(ConditionEvaluationFailed) as excinfo:
with pytest.raises(
ConditionEvaluationFailed, match="Failed to parse JSON response"
):
condition.verify()
assert "Failed to parse JSON response" in str(excinfo.value)
def test_basic_json_api_condition_evaluation_with_parameters(mocker):
mocked_get = mocker.patch(
@ -243,7 +242,5 @@ def test_ambiguous_json_path_multiple_results(mocker):
return_value_test=ReturnValueTest("==", 1),
)
with pytest.raises(ConditionEvaluationFailed) as excinfo:
with pytest.raises(ConditionEvaluationFailed, match="Ambiguous JSONPath query"):
condition.verify()
assert "Ambiguous JSONPath query" in str(excinfo.value)

View File

@ -1,7 +1,10 @@
import pytest
from nucypher.policy.conditions.evm import RPCCondition
from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.exceptions import (
InvalidCondition,
InvalidConditionLingo,
)
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
from tests.constants import TESTERCHAIN_CHAIN_ID
@ -45,7 +48,7 @@ def test_invalid_rpc_condition():
)
# unsupported chain id
with pytest.raises(InvalidCondition, match="is not a permitted blockchain"):
with pytest.raises(InvalidCondition, match="90210 is not a permitted blockchain"):
_ = RPCCondition(
method="eth_getBalance",
chain=90210, # Beverly Hills Chain :)
@ -54,10 +57,10 @@ def test_invalid_rpc_condition():
)
# invalid chain type provided
with pytest.raises(ValueError, match="must be the integer chain ID"):
with pytest.raises(ValueError, match="invalid literal for int"):
_ = RPCCondition(
method="eth_getBalance",
chain=str(TESTERCHAIN_CHAIN_ID), # should be int not str.
chain="chainId", # should be int not str.
return_value_test=ReturnValueTest("==", 0),
parameters=["0xaDD9D957170dF6F33982001E4c22eCCdd5539118"],
)
@ -67,44 +70,44 @@ def test_rpc_condition_schema_validation(rpc_condition):
condition_dict = rpc_condition.to_dict()
# no issues here
RPCCondition.validate(condition_dict)
RPCCondition.from_dict(condition_dict)
# no issues with optional name
condition_dict["name"] = "my_rpc_condition"
RPCCondition.validate(condition_dict)
RPCCondition.from_dict(condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# no chain defined
condition_dict = rpc_condition.to_dict()
del condition_dict["chain"]
RPCCondition.validate(condition_dict)
RPCCondition.from_dict(condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# no method defined
condition_dict = rpc_condition.to_dict()
del condition_dict["method"]
RPCCondition.validate(condition_dict)
RPCCondition.from_dict(condition_dict)
# no issue with no parameters
condition_dict = rpc_condition.to_dict()
del condition_dict["parameters"]
RPCCondition.validate(condition_dict)
RPCCondition.from_dict(condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# no returnValueTest defined
condition_dict = rpc_condition.to_dict()
del condition_dict["returnValueTest"]
RPCCondition.validate(condition_dict)
RPCCondition.from_dict(condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# chain id not an integer
condition_dict["chain"] = str(TESTERCHAIN_CHAIN_ID)
RPCCondition.validate(condition_dict)
RPCCondition.from_dict(condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# chain id not a permitted chain
condition_dict["chain"] = 90210 # Beverly Hills Chain :)
RPCCondition.validate(condition_dict)
RPCCondition.from_dict(condition_dict)
def test_rpc_condition_repr(rpc_condition):

View File

@ -38,15 +38,15 @@ def mock_condition_variables(mocker):
return var_1, var_2, var_3, var_4
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_invalid_sequential_condition(mock_condition_variables):
var_1, *_ = mock_condition_variables
def test_invalid_sequential_condition(rpc_condition, time_condition):
var_1 = ConditionVariable("var1", time_condition)
var_2 = ConditionVariable("var2", rpc_condition)
# invalid condition type
with pytest.raises(InvalidCondition, match=ConditionType.SEQUENTIAL.value):
_ = SequentialAccessControlCondition(
condition_type=ConditionType.TIME.value,
condition_variables=list(mock_condition_variables),
condition_variables=[var_1, var_2],
)
# no variables
@ -62,18 +62,29 @@ def test_invalid_sequential_condition(mock_condition_variables):
)
# too many variables
too_many_variables = list(mock_condition_variables)
too_many_variables.extend(mock_condition_variables) # duplicate list length
too_many_variables = [var_1, var_2, var_1, var_2]
too_many_variables.extend(too_many_variables) # duplicate list length
assert len(too_many_variables) > SequentialAccessControlCondition.MAX_NUM_CONDITIONS
with pytest.raises(InvalidCondition, match="Maximum of"):
_ = SequentialAccessControlCondition(
condition_variables=too_many_variables,
)
# duplicate var names
dupe_var = ConditionVariable(var_1.var_name, condition=var_2.condition)
with pytest.raises(InvalidCondition, match="Duplicate"):
_ = SequentialAccessControlCondition(
condition_variables=[var_1, var_2, dupe_var],
)
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_nested_sequential_condition_too_many_nested_levels(mock_condition_variables):
var_1, var_2, var_3, var_4 = mock_condition_variables
def test_nested_sequential_condition_too_many_nested_levels(
rpc_condition, time_condition
):
var_1 = ConditionVariable("var1", time_condition)
var_2 = ConditionVariable("var2", rpc_condition)
var_3 = ConditionVariable("var3", time_condition)
var_4 = ConditionVariable("var4", rpc_condition)
with pytest.raises(
InvalidCondition, match="nested levels of multi-conditions are allowed"
@ -104,9 +115,13 @@ def test_nested_sequential_condition_too_many_nested_levels(mock_condition_varia
)
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_nested_compound_condition_too_many_nested_levels(mock_condition_variables):
var_1, var_2, var_3, var_4 = mock_condition_variables
def test_nested_compound_condition_too_many_nested_levels(
rpc_condition, time_condition
):
var_1 = ConditionVariable("var1", time_condition)
var_2 = ConditionVariable("var2", rpc_condition)
var_3 = ConditionVariable("var3", time_condition)
var_4 = ConditionVariable("var4", rpc_condition)
with pytest.raises(
InvalidCondition, match="nested levels of multi-conditions are allowed"

View File

@ -1,6 +1,9 @@
import pytest
from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.exceptions import (
InvalidCondition,
InvalidConditionLingo,
)
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
from nucypher.policy.conditions.time import TimeCondition, TimeRPCCall
from tests.constants import TESTERCHAIN_CHAIN_ID
@ -37,38 +40,38 @@ def test_time_condition_schema_validation(time_condition):
condition_dict = time_condition.to_dict()
# no issues here
TimeCondition.validate(condition_dict)
TimeCondition.from_dict(condition_dict)
# no issues with optional name
condition_dict["name"] = "my_time_machine"
TimeCondition.validate(condition_dict)
TimeCondition.from_dict(condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# no method
condition_dict = time_condition.to_dict()
del condition_dict["method"]
TimeCondition.validate(condition_dict)
TimeCondition.from_dict(condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# no returnValueTest defined
condition_dict = time_condition.to_dict()
del condition_dict["returnValueTest"]
TimeCondition.validate(condition_dict)
TimeCondition.from_dict(condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# invalid method name
condition_dict["method"] = "my_blocktime"
TimeCondition.validate(condition_dict)
TimeCondition.from_dict(condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# chain id not an integer
condition_dict["chain"] = str(TESTERCHAIN_CHAIN_ID)
TimeCondition.validate(condition_dict)
TimeCondition.from_dict(condition_dict)
with pytest.raises(InvalidCondition):
with pytest.raises(InvalidConditionLingo):
# chain id not a permitted chain
condition_dict["chain"] = 90210 # Beverly Hills Chain :)
TimeCondition.validate(condition_dict)
TimeCondition.from_dict(condition_dict)
@pytest.mark.parametrize(