Merge pull request #3500 from derekpierre/sequential-condition

Sequential Conditions
pull/3557/head
Derek Pierre 2024-09-23 11:28:30 -04:00 committed by GitHub
commit 6a54f82f0f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 1252 additions and 555 deletions

View File

@ -0,0 +1 @@
Support for executing multiple conditions sequentially, where the outcome of one condition can be used as input for another.

View File

@ -1,14 +1,15 @@
import json
from abc import ABC, abstractmethod
from base64 import b64decode, b64encode
from typing import Any, Dict, Tuple
from typing import Any, Dict, List, Optional, Tuple
from marshmallow import Schema, ValidationError
from marshmallow import Schema, ValidationError, fields
from nucypher.policy.conditions.exceptions import (
InvalidCondition,
InvalidConditionLingo,
)
from nucypher.policy.conditions.utils import CamelCaseSchema
class _Serializable:
@ -51,24 +52,30 @@ class _Serializable:
class AccessControlCondition(_Serializable, ABC):
CONDITION_TYPE = NotImplemented
class Schema(Schema):
name = NotImplemented
def __init__(self):
class Schema(CamelCaseSchema):
SKIP_VALUES = (None,)
name = fields.Str(required=False)
condition_type = NotImplemented
def __init__(self, condition_type: str, name: Optional[str] = None):
super().__init__()
if condition_type != self.CONDITION_TYPE:
raise InvalidCondition(
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
)
self.condition_type = condition_type
self.name = name
# validate inputs using marshmallow schema
schema = self.Schema()
errors = schema.validate(self.to_dict())
if errors:
raise InvalidConditionLingo(errors)
self.validate(self.to_dict())
@abstractmethod
def verify(self, *args, **kwargs) -> Tuple[bool, Any]:
"""Returns the boolean result of the evaluation and the returned value in a two-tuple."""
return NotImplemented
raise NotImplementedError
@classmethod
def validate(cls, data: Dict) -> None:
@ -89,3 +96,48 @@ class AccessControlCondition(_Serializable, ABC):
return super().from_json(data)
except ValidationError as e:
raise InvalidConditionLingo(f"Invalid condition grammar: {e}")
class ExecutionCall(ABC):
@abstractmethod
def execute(self, *args, **kwargs) -> Any:
raise NotImplementedError
class MultiConditionAccessControl(AccessControlCondition):
MAX_NUM_CONDITIONS = 5
MAX_MULTI_CONDITION_NESTED_LEVEL = 2
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._validate_multi_condition_nesting(conditions=self.conditions)
@property
@abstractmethod
def conditions(self) -> List[AccessControlCondition]:
raise NotImplementedError
@classmethod
def _validate_multi_condition_nesting(
cls,
conditions: List[AccessControlCondition],
current_level: int = 1,
):
if len(conditions) > cls.MAX_NUM_CONDITIONS:
raise InvalidCondition(
f"Maximum of {cls.MAX_NUM_CONDITIONS} conditions are allowed"
)
for condition in conditions:
if not isinstance(condition, MultiConditionAccessControl):
continue
level = current_level + 1
if level > cls.MAX_MULTI_CONDITION_NESTED_LEVEL:
raise InvalidCondition(
f"Only {cls.MAX_MULTI_CONDITION_NESTED_LEVEL} nested levels of multi-conditions are allowed"
)
cls._validate_multi_condition_nesting(
conditions=condition.conditions,
current_level=level,
)

View File

@ -1,6 +1,6 @@
import re
from functools import partial
from typing import Any, List, Union
from typing import Any, List, Optional, Union
from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address
@ -125,9 +125,11 @@ def _resolve_context_variable(param: Union[Any, List[Any]], **context):
return param
def resolve_any_context_variables(parameters: List[Any], return_value_test, **context):
processed_parameters = [
_resolve_context_variable(param, **context) for param in parameters
]
processed_return_value_test = return_value_test.with_resolved_context(**context)
return processed_parameters, processed_return_value_test
def resolve_parameter_context_variables(parameters: Optional[List[Any]], **context):
if not parameters:
processed_parameters = [] # produce empty list
else:
processed_parameters = [
_resolve_context_variable(param, **context) for param in parameters
]
return processed_parameters

View File

@ -11,7 +11,6 @@ from typing import (
from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address
from marshmallow import ValidationError, fields, post_load, validate, validates_schema
from marshmallow.validate import OneOf
from web3 import HTTPProvider, Web3
from web3.contract.contract import ContractFunction
from web3.middleware import geth_poa_middleware
@ -19,22 +18,29 @@ from web3.providers import BaseProvider
from web3.types import ABIFunction
from nucypher.policy.conditions import STANDARD_ABI_CONTRACT_TYPES, STANDARD_ABIS
from nucypher.policy.conditions.base import AccessControlCondition
from nucypher.policy.conditions.base import (
ExecutionCall,
)
from nucypher.policy.conditions.context import (
is_context_variable,
resolve_any_context_variables,
resolve_parameter_context_variables,
)
from nucypher.policy.conditions.exceptions import (
InvalidCondition,
NoConnectionToChain,
RequiredContextVariable,
RPCExecutionFailed,
)
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
from nucypher.policy.conditions.utils import CamelCaseSchema, camel_case_to_snake
from nucypher.policy.conditions.lingo import (
ConditionType,
ExecutionCallAccessControlCondition,
ReturnValueTest,
)
from nucypher.policy.conditions.utils import camel_case_to_snake
from nucypher.policy.conditions.validation import (
_align_comparator_value_with_abi,
_get_abi_types,
_validate_condition_abi,
_validate_contract_call_abi,
_validate_multiple_output_types,
_validate_single_output_type,
)
@ -102,94 +108,70 @@ def _validate_chain(chain: int) -> None:
)
class RPCCondition(AccessControlCondition):
ETH_PREFIX = "eth_"
class RPCCall(ExecutionCall):
LOG = logging.Logger(__name__)
ALLOWED_METHODS = {
# RPC
"eth_getBalance": int,
} # TODO other allowed methods (tDEC #64)
LOG = logging.Logger(__name__)
CONDITION_TYPE = ConditionType.RPC.value
class Schema(CamelCaseSchema):
SKIP_VALUES = (None,)
name = fields.Str(required=False)
condition_type = fields.Str(
validate=validate.Equal(ConditionType.RPC.value), required=True
)
chain = fields.Int(
required=True, strict=True, validate=OneOf(_CONDITION_CHAINS)
)
method = fields.Str(required=True)
parameters = fields.List(fields.Field, attribute="parameters", required=False)
return_value_test = fields.Nested(
ReturnValueTest.ReturnValueTestSchema(), required=True
)
@post_load
def make(self, data, **kwargs):
return RPCCondition(**data)
def __repr__(self) -> str:
r = f"{self.__class__.__name__}(function={self.method}, chain={self.chain})"
return r
def __init__(
self,
chain: int,
method: str,
return_value_test: ReturnValueTest,
condition_type: str = CONDITION_TYPE,
name: Optional[str] = None,
parameters: Optional[List[Any]] = None,
):
# Validate input
# TODO: Additional validation (function is valid for ABI, RVT validity, standard contract name validity, etc.)
_validate_chain(chain=chain)
# internal
if condition_type != self.CONDITION_TYPE:
raise InvalidCondition(
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
)
self.condition_type = condition_type
self.name = name
self.chain = chain
self.provider: Optional[BaseProvider] = None # set in _configure_provider
self.method = self._validate_method(method=method)
# test
# should not be set to None - we do list unpacking so cannot be None; use empty list
self.parameters = parameters or []
self.return_value_test = return_value_test # output
self._validate_expected_return_type()
super().__init__()
self.parameters = parameters or None
def _validate_method(self, method):
if not method:
raise InvalidCondition("Undefined method name")
raise ValueError("Undefined method name")
if method not in self.ALLOWED_METHODS:
raise InvalidCondition(
raise ValueError(
f"'{method}' is not a permitted RPC endpoint for condition evaluation."
)
return method
def _validate_expected_return_type(self):
expected_return_type = self.ALLOWED_METHODS[self.method]
comparator_value = self.return_value_test.value
if is_context_variable(comparator_value):
return
def _get_web3_py_function(self, w3: Web3, rpc_method: str):
web3_py_method = camel_case_to_snake(rpc_method)
rpc_function = getattr(
w3.eth, web3_py_method
) # bind contract function (only exposes the eth API)
return rpc_function
if not isinstance(self.return_value_test.value, expected_return_type):
raise InvalidCondition(
f"Return value comparison for '{self.method}' call output "
f"should be '{expected_return_type}' and not '{type(comparator_value)}'."
def _configure_w3(self, provider: BaseProvider) -> Web3:
# Instantiate a local web3 instance
w3 = Web3(provider)
# inject web3 middleware to handle POA chain extra_data field.
w3.middleware_onion.inject(geth_poa_middleware, layer=0, name="poa")
return w3
def _check_chain_id(self, w3: Web3) -> None:
"""
Validates that the actual web3 provider is *actually*
connected to the condition's chain ID by reading its RPC endpoint.
"""
provider_chain = w3.eth.chain_id
if provider_chain != self.chain:
raise NoConnectionToChain(
chain=self.chain,
message=f"This rpc call can only be evaluated on chain ID {self.chain} but the provider's "
f"connection is to chain ID {provider_chain}",
)
def _configure_provider(self, provider: BaseProvider):
"""Binds the condition's contract function to a blockchain provider for evaluation"""
w3 = self._configure_w3(provider=provider)
self._check_chain_id(w3)
return w3
def _next_endpoint(
self, providers: Dict[int, Set[HTTPProvider]]
) -> Iterator[HTTPProvider]:
@ -208,47 +190,99 @@ class RPCCondition(AccessControlCondition):
# each endpoint here to see if it's alive and only yield it if it is.
yield provider
def _configure_w3(self, provider: BaseProvider) -> Web3:
# Instantiate a local web3 instance
self.provider = provider
w3 = Web3(provider)
# inject web3 middleware to handle POA chain extra_data field.
w3.middleware_onion.inject(geth_poa_middleware, layer=0, name="poa")
def execute(self, providers: Dict[int, Set[HTTPProvider]], **context) -> Any:
resolved_parameters = resolve_parameter_context_variables(
self.parameters, **context
)
return w3
def _check_chain_id(self) -> None:
"""
Validates that the actual web3 provider is *actually*
connected to the condition's chain ID by reading its RPC endpoint.
"""
provider_chain = self.w3.eth.chain_id
if provider_chain != self.chain:
raise InvalidCondition(
f"This condition can only be evaluated on chain ID {self.chain} but the provider's "
f"connection is to chain ID {provider_chain}"
endpoints = self._next_endpoint(providers=providers)
latest_error = ""
for provider in endpoints:
w3 = self._configure_provider(provider)
try:
result = self._execute(w3, resolved_parameters)
break
except RequiredContextVariable:
raise
except Exception as e:
latest_error = f"RPC call '{self.method}' failed: {e}"
self.LOG.warn(f"{latest_error}, attempting to try next endpoint.")
# Something went wrong. Try the next endpoint.
continue
else:
# Fuck.
raise RPCExecutionFailed(
f"RPC call '{self.method}' failed; latest error - {latest_error}"
)
def _configure_provider(self, provider: BaseProvider):
"""Binds the condition's contract function to a blockchain provider for evaluation"""
self.w3 = self._configure_w3(provider=provider)
self._check_chain_id()
return provider
return result
def _get_web3_py_function(self, rpc_method: str):
web3_py_method = camel_case_to_snake(rpc_method)
rpc_function = getattr(
self.w3.eth, web3_py_method
) # bind contract function (only exposes the eth API)
return rpc_function
def _execute_call(self, parameters: List[Any]) -> Any:
def _execute(self, w3: Web3, resolved_parameters: List[Any]) -> Any:
"""Execute onchain read and return result."""
rpc_endpoint_, rpc_method = self.method.split("_", 1)
rpc_function = self._get_web3_py_function(rpc_method)
rpc_result = rpc_function(*parameters) # RPC read
rpc_function = self._get_web3_py_function(w3, rpc_method)
rpc_result = rpc_function(*resolved_parameters) # RPC read
return rpc_result
class RPCCondition(ExecutionCallAccessControlCondition):
CONDITION_TYPE = ConditionType.RPC.value
class Schema(ExecutionCallAccessControlCondition.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.RPC.value), required=True
)
chain = fields.Int(
required=True, strict=True, validate=validate.OneOf(_CONDITION_CHAINS)
)
method = fields.Str(required=True)
parameters = fields.List(fields.Field, attribute="parameters", required=False)
@post_load
def make(self, data, **kwargs):
return RPCCondition(**data)
def __repr__(self) -> str:
r = f"{self.__class__.__name__}(function={self.method}, chain={self.chain})"
return r
def __init__(
self,
condition_type: str = CONDITION_TYPE,
*args,
**kwargs,
):
super().__init__(condition_type=condition_type, *args, **kwargs)
self._validate_expected_return_type()
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
return RPCCall(*args, **kwargs)
@property
def method(self):
return self.execution_call.method
@property
def chain(self):
return self.execution_call.chain
@property
def parameters(self):
return self.execution_call.parameters
def _validate_expected_return_type(self):
expected_return_type = RPCCall.ALLOWED_METHODS[self.method]
comparator_value = self.return_value_test.value
if is_context_variable(comparator_value):
return
if not isinstance(self.return_value_test.value, expected_return_type):
raise InvalidCondition(
f"Return value comparison for '{self.method}' call output "
f"should be '{expected_return_type}' and not '{type(comparator_value)}'."
)
def _align_comparator_value_with_abi(
self, return_value_test: ReturnValueTest
) -> ReturnValueTest:
@ -257,35 +291,77 @@ class RPCCondition(AccessControlCondition):
def verify(
self, providers: Dict[int, Set[HTTPProvider]], **context
) -> Tuple[bool, Any]:
"""
Verifies the onchain condition is met by performing a
read operation and evaluating the return value test.
"""
parameters, return_value_test = resolve_any_context_variables(
self.parameters, self.return_value_test, **context
resolved_return_value_test = self.return_value_test.with_resolved_context(
**context
)
return_value_test = self._align_comparator_value_with_abi(return_value_test)
endpoints = self._next_endpoint(providers=providers)
for provider in endpoints:
self._configure_provider(provider=provider)
try:
result = self._execute_call(parameters=parameters)
break
except Exception as e:
self.LOG.warn(
f"RPC call '{self.method}' failed: {e}, attempting to try next endpoint."
)
# Something went wrong. Try the next endpoint.
continue
else:
# Fuck.
raise RPCExecutionFailed(f"Contract call '{self.method}' failed.")
return_value_test = self._align_comparator_value_with_abi(
resolved_return_value_test
)
result = self.execution_call.execute(providers=providers, **context)
eval_result = return_value_test.eval(result) # test
return eval_result, result
class ContractCall(RPCCall):
def __init__(
self,
method: str,
contract_address: ChecksumAddress,
standard_contract_type: Optional[str] = None,
function_abi: Optional[ABIFunction] = None,
*args,
**kwargs,
):
if not method:
raise ValueError("Undefined method name")
_validate_contract_call_abi(
standard_contract_type, function_abi, method_name=method
)
# preprocessing
contract_address = to_checksum_address(contract_address)
self.contract_address = contract_address
self.standard_contract_type = standard_contract_type
self.function_abi = function_abi
super().__init__(method=method, *args, **kwargs)
self.contract_function = self._get_unbound_contract_function()
def _validate_method(self, method):
return method
def _get_unbound_contract_function(self) -> ContractFunction:
"""Gets an unbound contract function to evaluate for this condition"""
w3 = Web3()
function_abi = _resolve_abi(
w3=w3,
standard_contract_type=self.standard_contract_type,
method=self.method,
function_abi=self.function_abi,
)
try:
contract = w3.eth.contract(
address=self.contract_address, abi=[function_abi]
)
contract_function = getattr(contract.functions, self.method)
return contract_function
except Exception as e:
raise ValueError(
f"Unable to find contract function, '{self.method}', for condition: {e}"
)
def _execute(self, w3: Web3, resolved_parameters: List[Any]) -> Any:
"""Execute onchain read and return result."""
self.contract_function.w3 = w3
bound_contract_function = self.contract_function(
*resolved_parameters
) # bind contract function
contract_result = bound_contract_function.call() # onchain read
return contract_result
class ContractCondition(RPCCondition):
CONDITION_TYPE = ConditionType.CONTRACT.value
@ -293,8 +369,8 @@ class ContractCondition(RPCCondition):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.CONTRACT.value), required=True
)
standard_contract_type = fields.Str(required=False)
contract_address = fields.Str(required=True)
standard_contract_type = fields.Str(required=False)
function_abi = fields.Dict(required=False)
@post_load
@ -306,7 +382,7 @@ class ContractCondition(RPCCondition):
standard_contract_type = data.get("standard_contract_type")
function_abi = data.get("function_abi")
try:
_validate_condition_abi(
_validate_contract_call_abi(
standard_contract_type, function_abi, method_name=data.get("method")
)
except ValueError as e:
@ -314,64 +390,37 @@ class ContractCondition(RPCCondition):
def __init__(
self,
method: str,
contract_address: ChecksumAddress,
condition_type: str = CONDITION_TYPE,
standard_contract_type: Optional[str] = None,
function_abi: Optional[ABIFunction] = None,
*args,
**kwargs,
):
if not method:
raise InvalidCondition("Undefined method name")
try:
_validate_condition_abi(
standard_contract_type, function_abi, method_name=method
)
except ValueError as e:
raise InvalidCondition(str(e))
self.method = method
self.w3 = Web3() # used to instantiate contract function without a provider
# preprocessing
contract_address = to_checksum_address(contract_address)
# spec
self.contract_address = contract_address
self.condition_type = condition_type
self.standard_contract_type = standard_contract_type
self.function_abi = function_abi
self.contract_function = self._get_unbound_contract_function()
# call to super must be at the end for proper validation
super().__init__(condition_type=condition_type, method=method, *args, **kwargs)
super().__init__(condition_type=condition_type, *args, **kwargs)
def _validate_method(self, method):
# overrides validate method used by rpc superclass
return method
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
return ContractCall(*args, **kwargs)
@property
def function_abi(self):
return self.execution_call.function_abi
@property
def standard_contract_type(self):
return self.execution_call.standard_contract_type
@property
def contract_function(self):
return self.execution_call.contract_function
@property
def contract_address(self):
return self.execution_call.contract_address
def _validate_expected_return_type(self) -> None:
output_abi_types = _get_abi_types(self.contract_function.contract_abi[0])
comparator_value = self.return_value_test.value
comparator_index = self.return_value_test.index
index_string = (
f"@index={comparator_index}" if comparator_index is not None else ""
_validate_contract_function_expected_return_type(
contract_function=self.contract_function,
return_value_test=self.return_value_test,
)
failure_message = (
f"Invalid return value comparison type '{type(comparator_value)}' for "
f"'{self.contract_function.fn_name}'{index_string} based on ABI types {output_abi_types}"
)
if len(output_abi_types) == 1:
_validate_single_output_type(
output_abi_types[0], comparator_value, comparator_index, failure_message
)
else:
_validate_multiple_output_types(
output_abi_types, comparator_value, comparator_index, failure_message
)
def __repr__(self) -> str:
r = (
@ -381,37 +430,6 @@ class ContractCondition(RPCCondition):
)
return r
def _configure_provider(self, *args, **kwargs):
super()._configure_provider(*args, **kwargs)
self.contract_function.w3 = self.w3
def _get_unbound_contract_function(self) -> ContractFunction:
"""Gets an unbound contract function to evaluate for this condition"""
function_abi = _resolve_abi(
w3=self.w3,
standard_contract_type=self.standard_contract_type,
method=self.method,
function_abi=self.function_abi,
)
try:
contract = self.w3.eth.contract(
address=self.contract_address, abi=[function_abi]
)
contract_function = getattr(contract.functions, self.method)
return contract_function
except Exception as e:
raise InvalidCondition(
f"Unable to find contract function, '{self.method}', for condition: {e}"
)
def _execute_call(self, parameters: List[Any]) -> Any:
"""Execute onchain read and return result."""
bound_contract_function = self.contract_function(
*parameters
) # bind contract function
contract_result = bound_contract_function.call() # onchain read
return contract_result
def _align_comparator_value_with_abi(
self, return_value_test: ReturnValueTest
) -> ReturnValueTest:
@ -419,3 +437,25 @@ class ContractCondition(RPCCondition):
abi=self.contract_function.contract_abi[0],
return_value_test=return_value_test,
)
def _validate_contract_function_expected_return_type(
contract_function: ContractFunction, return_value_test: ReturnValueTest
) -> None:
output_abi_types = _get_abi_types(contract_function.contract_abi[0])
comparator_value = return_value_test.value
comparator_index = return_value_test.index
index_string = f"@index={comparator_index}" if comparator_index is not None else ""
failure_message = (
f"Invalid return value comparison type '{type(comparator_value)}' for "
f"'{contract_function.fn_name}'{index_string} based on ABI types {output_abi_types}"
)
if len(output_abi_types) == 1:
_validate_single_output_type(
output_abi_types[0], comparator_value, comparator_index, failure_message
)
else:
_validate_multiple_output_types(
output_abi_types, comparator_value, comparator_index, failure_message
)

View File

@ -7,9 +7,9 @@ class InvalidConditionLingo(Exception):
class NoConnectionToChain(RuntimeError):
"""Raised when a node does not have an associated provider for a chain."""
def __init__(self, chain: int):
def __init__(self, chain: int, message: str = None):
self.chain = chain
message = f"No connection to chain ID {chain}"
message = message or f"No connection to chain ID {chain}"
super().__init__(message)

View File

@ -2,9 +2,10 @@ import ast
import base64
import json
import operator as pyoperator
from abc import abstractmethod
from enum import Enum
from hashlib import md5
from typing import Any, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
from hexbytes import HexBytes
from marshmallow import (
@ -19,8 +20,14 @@ from marshmallow import (
)
from marshmallow.validate import OneOf, Range
from packaging.version import parse as parse_version
from web3 import HTTPProvider
from nucypher.policy.conditions.base import AccessControlCondition, _Serializable
from nucypher.policy.conditions.base import (
AccessControlCondition,
ExecutionCall,
MultiConditionAccessControl,
_Serializable,
)
from nucypher.policy.conditions.context import (
_resolve_context_variable,
is_context_variable,
@ -52,19 +59,8 @@ class _ConditionField(fields.Dict):
instance = condition_class.from_dict(condition_data)
return instance
#
# CONDITION = BASE_CONDITION | COMPOUND_CONDITION
#
# BASE_CONDITION = {
# // ..
# }
#
# COMPOUND_CONDITION = {
# "operator": OPERATOR,
# "operands": [CONDITION*]
# }
# CONDITION = TIME | CONTRACT | RPC | JSON_API | COMPOUND | SEQUENTIAL
class ConditionType(Enum):
"""
Defines the types of conditions that can be evaluated.
@ -75,13 +71,27 @@ class ConditionType(Enum):
RPC = "rpc"
JSONAPI = "json-api"
COMPOUND = "compound"
SEQUENTIAL = "sequential"
@classmethod
def values(cls) -> List[str]:
return [condition.value for condition in cls]
class CompoundAccessControlCondition(AccessControlCondition):
class CompoundAccessControlCondition(MultiConditionAccessControl):
"""
A combination of two or more conditions connected by logical operators such as AND, OR, NOT.
CompoundCondition grammar:
OPERATOR = AND | OR | NOT
COMPOUND_CONDITION = {
"name": ... (Optional)
"conditionType": "compound",
"operator": OPERATOR,
"operands": [CONDITION*]
}
"""
AND_OPERATOR = "and"
OR_OPERATOR = "or"
NOT_OPERATOR = "not"
@ -93,28 +103,32 @@ class CompoundAccessControlCondition(AccessControlCondition):
def _validate_operator_and_operands(
cls,
operator: str,
operands: List,
operands: List[Union[Dict, AccessControlCondition]],
exception_class: Union[Type[ValidationError], Type[InvalidCondition]],
):
if operator not in cls.OPERATORS:
raise exception_class(f"{operator} is not a valid operator")
num_operands = len(operands)
if operator == cls.NOT_OPERATOR:
if len(operands) != 1:
if num_operands != 1:
raise exception_class(
f"Only 1 operand permitted for '{operator}' compound condition"
)
elif len(operands) < 2:
elif num_operands < 2:
raise exception_class(
f"Minimum of 2 operand needed for '{operator}' compound condition"
)
elif num_operands > cls.MAX_NUM_CONDITIONS:
raise exception_class(
f"Maximum of {cls.MAX_NUM_CONDITIONS} operands allowed for '{operator}' compound condition"
)
class Schema(CamelCaseSchema):
SKIP_VALUES = (None,)
class Schema(AccessControlCondition.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.COMPOUND.value), required=True
)
name = fields.Str(required=False)
operator = fields.Str(required=True)
operands = fields.List(_ConditionField, required=True)
@ -147,20 +161,16 @@ class CompoundAccessControlCondition(AccessControlCondition):
"operands": [CONDITION*]
}
"""
if condition_type != self.CONDITION_TYPE:
raise InvalidCondition(
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
)
self._validate_operator_and_operands(operator, operands, InvalidCondition)
self.condition_type = condition_type
self.operator = operator
self.operands = operands
self.condition_type = condition_type
self.name = name
self.id = md5(bytes(self)).hexdigest()[:6]
super().__init__(condition_type=condition_type, name=name)
def __repr__(self):
return f"Operator={self.operator} (NumOperands={len(self.operands)}), id={self.id})"
@ -186,6 +196,10 @@ class CompoundAccessControlCondition(AccessControlCondition):
return overall_result, values
@property
def conditions(self):
return self.operands
class OrCompoundCondition(CompoundAccessControlCondition):
def __init__(self, operands: List[AccessControlCondition]):
@ -212,6 +226,126 @@ _COMPARATOR_FUNCTIONS = {
}
class ConditionVariable(_Serializable):
class Schema(CamelCaseSchema):
var_name = fields.Str(required=True) # TODO: should this be required?
condition = _ConditionField(required=True)
@post_load
def make(self, data, **kwargs):
return ConditionVariable(**data)
def __init__(self, var_name: str, condition: AccessControlCondition):
self.var_name = var_name
self.condition = condition
class SequentialAccessControlCondition(MultiConditionAccessControl):
"""
A series of conditions that are evaluated in a specific order, where the result of one
condition can be used in subsequent conditions.
SequentialCondition grammar:
CONDITION_VARIABLE = {
"varName": STR,
"condition": {
CONDITION
}
}
SEQUENTIAL_CONDITION = {
"name": ... (Optional)
"conditionType": "sequential",
"conditionVariables": [CONDITION_VARIABLE*]
}
"""
CONDITION_TYPE = ConditionType.SEQUENTIAL.value
@classmethod
def _validate_condition_variables(
cls,
condition_variables: List[Union[Dict, ConditionVariable]],
exception_class: Union[Type[ValidationError], Type[InvalidCondition]],
):
num_condition_variables = len(condition_variables)
if num_condition_variables < 2:
raise exception_class("At least two conditions must be specified")
if num_condition_variables > cls.MAX_NUM_CONDITIONS:
raise exception_class(
f"Maximum of {cls.MAX_NUM_CONDITIONS} conditions are allowed"
)
class Schema(AccessControlCondition.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.SEQUENTIAL.value), required=True
)
condition_variables = fields.List(
fields.Nested(ConditionVariable.Schema(), required=True)
)
# maintain field declaration ordering
class Meta:
ordered = True
@validates_schema
def validate_condition_variables(self, data, **kwargs):
condition_variables = data["condition_variables"]
SequentialAccessControlCondition._validate_condition_variables(
condition_variables, ValidationError
)
@post_load
def make(self, data, **kwargs):
return SequentialAccessControlCondition(**data)
def __init__(
self,
condition_variables: List[ConditionVariable],
condition_type: str = CONDITION_TYPE,
name: Optional[str] = None,
):
self._validate_condition_variables(
condition_variables=condition_variables, exception_class=InvalidCondition
)
self.condition_variables = condition_variables
super().__init__(condition_type=condition_type, name=name)
def __repr__(self):
r = f"{self.__class__.__name__}(num_condition_variables={len(self.condition_variables)})"
return r
# TODO - think about not dereferencing context but using a dict;
# may allows more freedom for params
def verify(
self, providers: Dict[int, Set[HTTPProvider]], **context
) -> Tuple[bool, Any]:
values = []
latest_success = False
inner_context = dict(context) # don't modify passed in context - use a copy
# resolve variables
for condition_variable in self.condition_variables:
latest_success, result = condition_variable.condition.verify(
providers=providers, **inner_context
)
values.append(result)
if not latest_success:
# short circuit due to failed condition
break
inner_context[f":{condition_variable.var_name}"] = result
return latest_success, values
@property
def conditions(self):
return [
condition_variable.condition
for condition_variable in self.condition_variables
]
class ReturnValueTest:
class InvalidExpression(ValueError):
pass
@ -357,16 +491,6 @@ class ConditionLingo(_Serializable):
"""
def __init__(self, condition: AccessControlCondition, version: str = VERSION):
"""
CONDITION = BASE_CONDITION | COMPOUND_CONDITION
BASE_CONDITION = {
// ..
}
COMPOUND_CONDITION = {
"operator": OPERATOR,
"operands": [CONDITION*]
}
"""
self.condition = condition
self.check_version_compatibility(version)
self.version = version
@ -412,8 +536,7 @@ class ConditionLingo(_Serializable):
cls, condition: ConditionDict, version: int = None
) -> Type[AccessControlCondition]:
"""
TODO: This feels like a jenky way to resolve data types from JSON blobs, but it works.
Inspects a given bloc of JSON and attempts to resolve it's intended datatype within the
Inspects a given block of JSON and attempts to resolve it's intended datatype within the
conditions expression framework.
"""
from nucypher.policy.conditions.evm import ContractCondition, RPCCondition
@ -429,12 +552,13 @@ class ConditionLingo(_Serializable):
RPCCondition,
CompoundAccessControlCondition,
JsonApiCondition,
SequentialAccessControlCondition,
):
if condition.CONDITION_TYPE == condition_type:
return condition
raise InvalidConditionLingo(
f"Cannot resolve condition lingo with condition type {condition_type}"
f"Cannot resolve condition lingo, {condition}, with condition type {condition_type}"
)
@classmethod
@ -443,3 +567,50 @@ class ConditionLingo(_Serializable):
raise InvalidConditionLingo(
f"Version provided, {version}, is incompatible with current version {cls.VERSION}"
)
class ExecutionCallAccessControlCondition(AccessControlCondition):
"""
Conditions that utilize underlying ExecutionCall objects.
"""
class Schema(AccessControlCondition.Schema):
return_value_test = fields.Nested(
ReturnValueTest.ReturnValueTestSchema(), required=True
)
def __init__(
self,
condition_type: str,
return_value_test: ReturnValueTest,
name: Optional[str] = None,
*args,
**kwargs,
):
self.return_value_test = return_value_test
try:
self.execution_call = self._create_execution_call(*args, **kwargs)
except ValueError as e:
raise InvalidCondition(str(e))
super().__init__(condition_type=condition_type, name=name)
@abstractmethod
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
"""
Returns the execution call that the condition executes.
"""
raise NotImplementedError
def verify(self, *args, **kwargs) -> Tuple[bool, Any]:
"""
Verifies the condition is met by performing execution call and
evaluating the return value test.
"""
result = self.execution_call.execute(*args, **kwargs)
resolved_return_value_test = self.return_value_test.with_resolved_context(
**kwargs
)
eval_result = resolved_return_value_test.eval(result) # test
return eval_result, result

View File

@ -6,13 +6,15 @@ from jsonpath_ng.ext import parse
from marshmallow import fields, post_load, validate
from marshmallow.fields import Field, Url
from nucypher.policy.conditions.base import AccessControlCondition
from nucypher.policy.conditions.base import ExecutionCall
from nucypher.policy.conditions.exceptions import (
ConditionEvaluationFailed,
InvalidCondition,
)
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
from nucypher.policy.conditions.utils import CamelCaseSchema
from nucypher.policy.conditions.lingo import (
ConditionType,
ExecutionCallAccessControlCondition,
)
from nucypher.utilities.logging import Logger
@ -32,58 +34,29 @@ class JSONPathField(Field):
return value
class JsonApiCondition(AccessControlCondition):
"""
A JSON API condition is a condition that can be evaluated by reading from a JSON
HTTPS endpoint. The response must return an HTTP 200 with valid JSON in the response body.
The response will be deserialized as JSON and parsed using jsonpath.
"""
CONDITION_TYPE = ConditionType.JSONAPI.value
LOGGER = Logger("nucypher.policy.conditions.JsonApiCondition")
class JsonApiCall(ExecutionCall):
TIMEOUT = 5 # seconds
class Schema(CamelCaseSchema):
name = fields.Str(required=False)
condition_type = fields.Str(
validate=validate.Equal(ConditionType.JSONAPI.value), required=True
)
parameters = fields.Dict(required=False, allow_none=True)
endpoint = Url(required=True, relative=False, schemes=["https"])
query = JSONPathField(required=False, allow_none=True)
return_value_test = fields.Nested(
ReturnValueTest.ReturnValueTestSchema(), required=True
)
@post_load
def make(self, data, **kwargs):
return JsonApiCondition(**data)
def __init__(
self,
endpoint: str,
return_value_test: ReturnValueTest,
query: Optional[str] = None,
parameters: Optional[dict] = None,
condition_type: str = ConditionType.JSONAPI.value,
query: Optional[str] = None,
):
if condition_type != self.CONDITION_TYPE:
raise InvalidCondition(
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
)
self.condition_type = condition_type
self.endpoint = endpoint
self.parameters = parameters or {}
self.query = query
self.return_value_test = return_value_test
self.timeout = self.TIMEOUT
self.logger = self.LOGGER
self.logger = Logger(__name__)
super().__init__()
def execute(self, *args, **kwargs) -> Any:
response = self._fetch()
data = self._deserialize_response(response)
result = self._query_response(data)
return result
def fetch(self) -> requests.Response:
def _fetch(self) -> requests.Response:
"""Fetches data from the endpoint."""
try:
response = requests.get(
@ -111,7 +84,7 @@ class JsonApiCondition(AccessControlCondition):
return response
def deserialize_response(self, response: requests.Response) -> Any:
def _deserialize_response(self, response: requests.Response) -> Any:
"""Deserializes the JSON response from the endpoint."""
try:
data = response.json()
@ -122,7 +95,7 @@ class JsonApiCondition(AccessControlCondition):
)
return data
def query_response(self, data: Any) -> Any:
def _query_response(self, data: Any) -> Any:
if not self.query:
return data # primitive value
@ -148,6 +121,55 @@ class JsonApiCondition(AccessControlCondition):
return result
class JsonApiCondition(ExecutionCallAccessControlCondition):
"""
A JSON API condition is a condition that can be evaluated by reading from a JSON
HTTPS endpoint. The response must return an HTTP 200 with valid JSON in the response body.
The response will be deserialized as JSON and parsed using jsonpath.
"""
CONDITION_TYPE = ConditionType.JSONAPI.value
class Schema(ExecutionCallAccessControlCondition.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.JSONAPI.value), required=True
)
endpoint = Url(required=True, relative=False, schemes=["https"])
parameters = fields.Dict(required=False, allow_none=True)
query = JSONPathField(required=False, allow_none=True)
@post_load
def make(self, data, **kwargs):
return JsonApiCondition(**data)
def __init__(
self,
condition_type: str = ConditionType.JSONAPI.value,
*args,
**kwargs,
):
super().__init__(condition_type=condition_type, *args, **kwargs)
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
return JsonApiCall(*args, **kwargs)
@property
def endpoint(self):
return self.execution_call.endpoint
@property
def query(self):
return self.execution_call.query
@property
def parameters(self):
return self.execution_call.parameters
@property
def timeout(self):
return self.execution_call.timeout
@staticmethod
def _process_result_for_eval(result: Any):
# strings that are not already quoted will cause a problem for literal_eval
@ -170,14 +192,11 @@ class JsonApiCondition(AccessControlCondition):
and evaluating the return value test with the result. Parses the endpoint's JSON response using
JSONPath.
"""
response = self.fetch()
data = self.deserialize_response(response)
result = self.query_response(data)
result = self.execution_call.execute(**context)
result_for_eval = self._process_result_for_eval(result)
resolved_return_value_test = self.return_value_test.with_resolved_context(
**context
)
result_for_eval = self._process_result_for_eval(result)
eval_result = resolved_return_value_test.eval(result_for_eval) # test
return eval_result, result

View File

@ -1,33 +1,53 @@
from typing import Any, List, Optional
from marshmallow import fields, post_load, validate
from marshmallow.validate import Equal, OneOf
from marshmallow.validate import Equal
from web3 import Web3
from nucypher.policy.conditions.evm import _CONDITION_CHAINS, RPCCondition
from nucypher.policy.conditions.base import ExecutionCall
from nucypher.policy.conditions.evm import RPCCall, RPCCondition
from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
from nucypher.policy.conditions.utils import CamelCaseSchema
from nucypher.policy.conditions.lingo import ConditionType
class TimeRPCCall(RPCCall):
METHOD = "blocktime"
def __init__(
self,
chain: int,
method: str = METHOD,
parameters: Optional[List[Any]] = None,
):
if parameters:
raise ValueError(f"{self.METHOD} does not take any parameters")
super().__init__(chain=chain, method=method, parameters=parameters)
def _validate_method(self, method):
if method != self.METHOD:
raise ValueError(
f"{self.__class__.__name__} must be instantiated with the {self.METHOD} method."
)
return method
def _execute(self, w3: Web3, resolved_parameters: List[Any]) -> Any:
"""Execute onchain read and return result."""
# TODO may need to rethink as part of #3051 (multicall work).
latest_block = w3.eth.get_block("latest")
return latest_block.timestamp
class TimeCondition(RPCCondition):
METHOD = "blocktime"
CONDITION_TYPE = ConditionType.TIME.value
class Schema(CamelCaseSchema):
SKIP_VALUES = (None,)
class Schema(RPCCondition.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.TIME.value), required=True
)
name = fields.Str(required=False)
chain = fields.Int(
required=True, strict=True, validate=OneOf(_CONDITION_CHAINS)
)
method = fields.Str(
dump_default="blocktime", required=True, validate=Equal("blocktime")
)
return_value_test = fields.Nested(
ReturnValueTest.ReturnValueTestSchema(), required=True
)
@post_load
def make(self, data, **kwargs):
@ -39,28 +59,21 @@ class TimeCondition(RPCCondition):
def __init__(
self,
return_value_test: ReturnValueTest,
chain: int,
method: str = METHOD,
method: str = TimeRPCCall.METHOD,
condition_type: str = CONDITION_TYPE,
name: Optional[str] = None,
*args,
**kwargs,
):
if method != self.METHOD:
raise InvalidCondition(
f"{self.__class__.__name__} must be instantiated with the {self.METHOD} method."
)
# call to super must be at the end for proper validation
super().__init__(
chain=chain,
method=method,
return_value_test=return_value_test,
name=name,
condition_type=condition_type,
method=method,
*args,
**kwargs,
)
def _validate_method(self, method):
return method
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
return TimeRPCCall(*args, **kwargs)
def _validate_expected_return_type(self):
comparator_value = self.return_value_test.value
@ -72,9 +85,3 @@ class TimeCondition(RPCCondition):
@property
def timestamp(self):
return self.return_value_test.value
def _execute_call(self, parameters: List[Any]) -> Any:
"""Execute onchain read and return result."""
# TODO may need to rethink as part of #3051 (multicall work).
latest_block = self.w3.eth.get_block("latest")
return latest_block.timestamp

View File

@ -34,16 +34,20 @@ class ReturnValueTestDict(TypedDict):
key: NotRequired[Union[str, int]]
# Conditions
class _AccessControlCondition(TypedDict):
name: NotRequired[str]
class RPCConditionDict(_AccessControlCondition):
conditionType: str
class BaseExecConditionDict(_AccessControlCondition):
returnValueTest: ReturnValueTestDict
class RPCConditionDict(BaseExecConditionDict):
chain: int
method: str
parameters: NotRequired[List[Any]]
returnValueTest: ReturnValueTestDict
class TimeConditionDict(RPCConditionDict):
@ -56,17 +60,43 @@ class ContractConditionDict(RPCConditionDict):
functionAbi: NotRequired[ABIFunction]
class JsonApiConditionDict(BaseExecConditionDict):
endpoint: str
query: NotRequired[str]
parameters: NotRequired[Dict]
#
# CompoundCondition represents:
# {
# "operator": ["and" | "or"]
# "operands": List[AccessControlCondition | CompoundCondition]
# "operator": ["and" | "or" | "not"]
# "operands": List[AccessControlCondition]
# }
#
class CompoundConditionDict(_AccessControlCondition):
operator: Literal["and", "or", "not"]
operands: List["ConditionDict"]
#
class CompoundConditionDict(TypedDict):
conditionType: str
operator: Literal["and", "or"]
operands: List["Lingo"]
# ConditionVariable represents:
# {
# varName: str
# condition: AccessControlCondition
# }
#
class ConditionVariableDict(TypedDict):
varName: str
condition: "ConditionDict"
#
# SequentialCondition represents:
# {
# "conditionVariables": List[ConditionVariable]
# }
#
class SequentialConditionDict(_AccessControlCondition):
conditionVariables = List[ConditionVariableDict]
#
@ -75,8 +105,15 @@ class CompoundConditionDict(TypedDict):
# - RPCCondition
# - ContractCondition
# - CompoundConditionDict
# - JsonApiConditionDict
# - SequentialConditionDict
ConditionDict = Union[
TimeConditionDict, RPCConditionDict, ContractConditionDict, CompoundConditionDict
TimeConditionDict,
RPCConditionDict,
ContractConditionDict,
CompoundConditionDict,
JsonApiConditionDict,
SequentialConditionDict,
]

View File

@ -155,7 +155,7 @@ def _align_comparator_value_with_abi(
)
def _validate_condition_function_abi(function_abi: Dict, method_name: str) -> None:
def _validate_function_abi(function_abi: Dict, method_name: str) -> None:
"""validates a dictionary as valid for use as a condition function ABI"""
abi = ABIFunction(function_abi)
@ -171,7 +171,7 @@ def _validate_condition_function_abi(function_abi: Dict, method_name: str) -> No
raise ValueError(f"Invalid ABI stateMutability {abi}")
def _validate_condition_abi(
def _validate_contract_call_abi(
standard_contract_type: str,
function_abi: Dict,
method_name: str,
@ -181,4 +181,4 @@ def _validate_condition_abi(
f"Provide 'standardContractType' or 'functionAbi'; got ({standard_contract_type}, {function_abi})."
)
if function_abi:
_validate_condition_function_abi(function_abi, method_name=method_name)
_validate_function_abi(function_abi, method_name=method_name)

View File

@ -15,9 +15,11 @@ from nucypher.characters.lawful import Enrico, Ursula
from nucypher.policy.conditions.evm import ContractCondition, RPCCondition
from nucypher.policy.conditions.lingo import (
ConditionLingo,
ConditionVariable,
NotCompoundCondition,
OrCompoundCondition,
ReturnValueTest,
SequentialAccessControlCondition,
)
from nucypher.policy.conditions.time import TimeCondition
from tests.constants import TEST_ETH_PROVIDER_URI, TESTERCHAIN_CHAIN_ID
@ -93,7 +95,14 @@ def condition(test_registry):
)
not_not_condition = NotCompoundCondition(
operand=NotCompoundCondition(operand=and_condition)
operand=NotCompoundCondition(operand=rpc_condition)
)
sequential_condition = SequentialAccessControlCondition(
condition_variables=[
ConditionVariable("rpc", rpc_condition),
ConditionVariable("contract", contract_condition),
]
)
conditions = [
@ -103,6 +112,7 @@ def condition(test_registry):
or_condition,
and_condition,
not_not_condition,
sequential_condition,
]
condition_to_use = random.choice(conditions)

View File

@ -8,7 +8,6 @@ from nucypher.blockchain.eth.agents import (
from nucypher.policy.conditions.context import USER_ADDRESS_CONTEXT
from nucypher.policy.conditions.evm import ContractCondition
from nucypher.policy.conditions.lingo import (
AndCompoundCondition,
ConditionLingo,
OrCompoundCondition,
ReturnValueTest,
@ -21,7 +20,6 @@ def condition_providers(testerchain):
providers = {testerchain.client.chain_id: {testerchain.provider}}
return providers
@pytest.fixture()
def compound_lingo(
erc721_evm_condition_balanceof,
@ -35,12 +33,8 @@ def compound_lingo(
operands=[
erc721_evm_condition_balanceof,
time_condition,
AndCompoundCondition(
operands=[
rpc_condition,
erc20_evm_condition_balanceof,
]
),
rpc_condition,
erc20_evm_condition_balanceof,
]
)
)

View File

@ -22,7 +22,6 @@ from nucypher.policy.conditions.evm import (
RPCCondition,
)
from nucypher.policy.conditions.exceptions import (
InvalidCondition,
NoConnectionToChain,
RequiredContextVariable,
RPCExecutionFailed,
@ -84,10 +83,10 @@ def test_rpc_condition_evaluation_invalid_provider_for_chain(
):
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
new_chain = 23
rpc_condition.chain = new_chain
rpc_condition.execution_call.chain = new_chain
condition_providers = {new_chain: {testerchain.provider}}
with pytest.raises(
InvalidCondition, match=f"can only be evaluated on chain ID {new_chain}"
NoConnectionToChain, match=f"can only be evaluated on chain ID {new_chain}"
):
_ = rpc_condition.verify(providers=condition_providers, **context)
@ -156,10 +155,8 @@ def test_rpc_condition_evaluation_multiple_providers_no_valid_fallback(
}
mocker.patch.object(
rpc_condition, "_check_chain_id", return_value=None
) # skip chain check
mocker.patch.object(rpc_condition, "_configure_w3", my_configure_w3)
rpc_condition.execution_call, "_configure_provider", my_configure_w3
)
with pytest.raises(RPCExecutionFailed):
_ = rpc_condition.verify(providers=condition_providers, **context)
@ -186,9 +183,8 @@ def test_rpc_condition_evaluation_multiple_providers_valid_fallback(
}
mocker.patch.object(
rpc_condition, "_check_chain_id", return_value=None
) # skip chain check
mocker.patch.object(rpc_condition, "_configure_w3", my_configure_w3)
rpc_condition.execution_call, "_configure_provider", my_configure_w3
)
condition_result, call_result = rpc_condition.verify(
providers=condition_providers, **context

View File

@ -1,9 +1,16 @@
from collections import defaultdict
import pytest
from web3 import Web3
from nucypher.policy.conditions.evm import RPCCondition
from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
from nucypher.policy.conditions.evm import RPCCall, RPCCondition
from nucypher.policy.conditions.lingo import (
CompoundAccessControlCondition,
ConditionLingo,
ConditionType,
ReturnValueTest,
)
from nucypher.policy.conditions.time import TimeCondition, TimeRPCCall
from nucypher.utilities.logging import GlobalLoggerSettings
from tests.utils.policy import make_message_kits
@ -14,22 +21,28 @@ def make_multichain_evm_conditions(bob, chain_ids):
"""This is a helper function to make a set of conditions that are valid on multiple chains."""
operands = list()
for chain_id in chain_ids:
operand = [
{
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": chain_id,
},
{
"conditionType": ConditionType.RPC.value,
"chain": chain_id,
"method": "eth_getBalance",
"parameters": [bob.checksum_address, "latest"],
"returnValueTest": {"comparator": ">=", "value": 10000000000000},
},
]
operands.extend(operand)
compound_and_condition = CompoundAccessControlCondition(
operator="and",
operands=[
TimeCondition(
chain=chain_id,
return_value_test=ReturnValueTest(
comparator=">",
value=0,
),
),
RPCCondition(
chain=chain_id,
method="eth_getBalance",
parameters=[bob.checksum_address, "latest"],
return_value_test=ReturnValueTest(
comparator=">=",
value=10000000000000,
),
),
],
)
operands.append(compound_and_condition.to_dict())
_conditions = {
"version": ConditionLingo.VERSION,
@ -69,7 +82,7 @@ def test_single_retrieve_with_multichain_conditions(
def test_single_decryption_request_with_faulty_rpc_endpoint(
enacted_policy, bob, multichain_ursulas, conditions, mock_rpc_condition
monkeymodule, testerchain, enacted_policy, bob, multichain_ursulas, conditions
):
bob.remember_node(multichain_ursulas[0])
bob.start_learning_loop()
@ -80,30 +93,43 @@ def test_single_decryption_request_with_faulty_rpc_endpoint(
alice_verifying_key=enacted_policy.publisher_verifying_key,
)
def _mock_configure_provider(*args, **kwargs):
rpc_call_type = args[0]
if isinstance(rpc_call_type, TimeRPCCall):
# time condition call - only RPCCall is made faulty
return testerchain.w3
# rpc condition call
provider = args[1]
w3 = Web3(provider)
return w3
monkeymodule.setattr(RPCCall, "_configure_provider", _mock_configure_provider)
calls = defaultdict(int)
original_execute_call = RPCCondition._execute_call
original_execute_call = RPCCall._execute
def faulty_execute_call(*args, **kwargs):
"""Intercept the call to the RPC endpoint and raise an exception on the second call."""
nonlocal calls
rpc_call = args[0]
calls[rpc_call.chain] += 1
if (
calls[rpc_call.chain] == 2
and "tester://multichain.0" in rpc_call.provider.endpoint_uri
):
rpc_call_object = args[0]
resolved_parameters = args[2]
calls[rpc_call_object.chain] += 1
if calls[rpc_call_object.chain] % 2 == 0:
# simulate a network error
raise ConnectionError("Something went wrong with the network")
elif calls[rpc_call.chain] == 3:
# check the provider is the fallback
this_uri = rpc_call.provider.endpoint_uri
assert "fallback" in this_uri
return original_execute_call(*args, **kwargs)
RPCCondition._execute_call = faulty_execute_call
# replace w3 object with fake provider, with proper w3 object for actual execution
return original_execute_call(
rpc_call_object, testerchain.w3, resolved_parameters
)
RPCCall._execute = faulty_execute_call
cleartexts = bob.retrieve_and_decrypt(
message_kits=message_kits,
**policy_info_kwargs,
)
assert cleartexts == messages
RPCCondition._execute_call = original_execute_call
RPCCall._execute = original_execute_call

View File

@ -14,7 +14,7 @@ from nucypher.blockchain.eth.agents import (
)
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
from nucypher.blockchain.eth.registry import ContractRegistry, RegistrySourceManager
from nucypher.policy.conditions.evm import RPCCondition
from nucypher.policy.conditions.evm import RPCCall
from nucypher.utilities.logging import Logger
from tests.constants import (
BONUS_TOKENS_FOR_TESTS,
@ -430,16 +430,11 @@ def taco_child_application_agent(testerchain, test_registry):
#
@pytest.fixture(scope="module")
def mock_rpc_condition(module_mocker, testerchain, monkeymodule):
def configure_mock(condition, provider, *args, **kwargs):
condition.provider = provider
def mock_rpc_condition(testerchain, monkeymodule):
def configure_mock(*args, **kwargs):
return testerchain.w3
monkeymodule.setattr(RPCCondition, "_configure_w3", configure_mock)
configure_spy = module_mocker.spy(RPCCondition, "_configure_w3")
chain_id_check_mock = module_mocker.patch.object(RPCCondition, "_check_chain_id")
return configure_spy, chain_id_check_mock
monkeymodule.setattr(RPCCall, "_configure_provider", configure_mock)
@pytest.fixture(scope="module")

View File

@ -2,6 +2,7 @@ import json
import pytest
from nucypher.policy.conditions.base import AccessControlCondition
from nucypher.policy.conditions.context import USER_ADDRESS_CONTEXT
from nucypher.policy.conditions.evm import ContractCondition
from nucypher.policy.conditions.lingo import (
@ -94,3 +95,8 @@ def erc721_evm_condition(test_registry):
]
)
return condition
@pytest.fixture(scope="function")
def mock_skip_schema_validation(mocker):
mocker.patch.object(AccessControlCondition.Schema, "validate", return_value=None)

View File

@ -1,3 +1,4 @@
import random
from unittest.mock import Mock
import pytest
@ -8,8 +9,10 @@ from nucypher.policy.conditions.lingo import (
AndCompoundCondition,
CompoundAccessControlCondition,
ConditionType,
ConditionVariable,
NotCompoundCondition,
OrCompoundCondition,
SequentialAccessControlCondition,
)
@ -63,7 +66,10 @@ def test_invalid_compound_condition(time_condition, rpc_condition):
# no operands
with pytest.raises(InvalidCondition):
_ = CompoundAccessControlCondition(operator=operator, operands=[])
_ = CompoundAccessControlCondition(
operator=random.choice(CompoundAccessControlCondition.OPERATORS),
operands=[],
)
# > 1 operand for not operator
with pytest.raises(InvalidCondition):
@ -86,6 +92,21 @@ def test_invalid_compound_condition(time_condition, rpc_condition):
operands=[rpc_condition],
)
# exceeds max operands
operands = list()
for i in range(CompoundAccessControlCondition.MAX_NUM_CONDITIONS + 1):
operands.append(rpc_condition)
with pytest.raises(InvalidCondition):
_ = CompoundAccessControlCondition(
operator=CompoundAccessControlCondition.OR_OPERATOR,
operands=operands,
)
with pytest.raises(InvalidCondition):
_ = CompoundAccessControlCondition(
operator=CompoundAccessControlCondition.AND_OPERATOR,
operands=operands,
)
@pytest.mark.parametrize("operator", CompoundAccessControlCondition.OPERATORS)
def test_compound_condition_schema_validation(operator, time_condition, rpc_condition):
@ -131,7 +152,8 @@ def test_compound_condition_schema_validation(operator, time_condition, rpc_cond
CompoundAccessControlCondition.validate(compound_condition_dict)
def test_and_condition_and_short_circuit(mock_conditions):
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_and_condition_and_short_circuit(mocker, mock_conditions):
condition_1, condition_2, condition_3, condition_4 = mock_conditions
and_condition = AndCompoundCondition(
@ -144,14 +166,14 @@ def test_and_condition_and_short_circuit(mock_conditions):
)
# ensure that all conditions evaluated when all return True
result, value = and_condition.verify()
result, value = and_condition.verify(providers={})
assert result is True
assert len(value) == 4, "all conditions evaluated"
assert value == [1, 2, 3, 4]
# ensure that short circuit happens when 1st condition is false
condition_1.verify.return_value = (False, 1)
result, value = and_condition.verify()
result, value = and_condition.verify(providers={})
assert result is False
assert len(value) == 1, "only one condition evaluated"
assert value == [1]
@ -159,12 +181,13 @@ def test_and_condition_and_short_circuit(mock_conditions):
# short circuit occurs for 3rd entry
condition_1.verify.return_value = (True, 1)
condition_3.verify.return_value = (False, 3)
result, value = and_condition.verify()
result, value = and_condition.verify(providers={})
assert result is False
assert len(value) == 3, "3-of-4 conditions evaluated"
assert value == [1, 2, 3]
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_or_condition_and_short_circuit(mock_conditions):
condition_1, condition_2, condition_3, condition_4 = mock_conditions
@ -179,7 +202,7 @@ def test_or_condition_and_short_circuit(mock_conditions):
# ensure that only first condition evaluated when first is True
condition_1.verify.return_value = (True, 1) # short circuit here
result, value = or_condition.verify()
result, value = or_condition.verify(providers={})
assert result is True
assert len(value) == 1, "only first condition needs to be evaluated"
assert value == [1]
@ -189,7 +212,7 @@ def test_or_condition_and_short_circuit(mock_conditions):
condition_2.verify.return_value = (False, 2)
condition_3.verify.return_value = (True, 3) # short circuit here
result, value = or_condition.verify()
result, value = or_condition.verify(providers={})
assert result is True
assert len(value) == 3, "third condition causes short circuit"
assert value == [1, 2, 3]
@ -200,12 +223,13 @@ def test_or_condition_and_short_circuit(mock_conditions):
condition_3.verify.return_value = (False, 3)
condition_4.verify.return_value = (False, 4)
result, value = or_condition.verify()
result, value = or_condition.verify(providers={})
assert result is False
assert len(value) == 4, "all conditions evaluated"
assert value == [1, 2, 3, 4]
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_compound_condition(mock_conditions):
condition_1, condition_2, condition_3, condition_4 = mock_conditions
@ -223,7 +247,7 @@ def test_compound_condition(mock_conditions):
)
# all conditions are True
result, value = compound_condition.verify()
result, value = compound_condition.verify(providers={})
assert result is True
assert len(value) == 2, "or_condition and condition_4"
assert value == [[1], 4]
@ -232,7 +256,7 @@ def test_compound_condition(mock_conditions):
condition_1.verify.return_value = (False, 1)
condition_2.verify.return_value = (False, 2)
condition_3.verify.return_value = (False, 3)
result, value = compound_condition.verify()
result, value = compound_condition.verify(providers={})
assert result is False
assert len(value) == 1, "or_condition"
assert value == [
@ -243,7 +267,7 @@ def test_compound_condition(mock_conditions):
condition_1.verify.return_value = (True, 1)
condition_4.verify.return_value = (False, 4)
result, value = compound_condition.verify()
result, value = compound_condition.verify(providers={})
assert result is False
assert len(value) == 2, "or_condition and condition_4"
assert value == [
@ -253,7 +277,7 @@ def test_compound_condition(mock_conditions):
# condition_4 is now true
condition_4.verify.return_value = (True, 4)
result, value = compound_condition.verify()
result, value = compound_condition.verify(providers={})
assert result is True
assert len(value) == 2, "or_condition and condition_4"
assert value == [
@ -262,6 +286,57 @@ def test_compound_condition(mock_conditions):
] # or-condition short-circuited because condition_1 was True
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_nested_compound_condition_too_many_nested_levels(mock_conditions):
condition_1, condition_2, condition_3, condition_4 = mock_conditions
with pytest.raises(
InvalidCondition, match="nested levels of multi-conditions are allowed"
):
_ = AndCompoundCondition(
operands=[
OrCompoundCondition(
operands=[
condition_1,
AndCompoundCondition(
operands=[
condition_2,
condition_3,
]
),
]
),
condition_4,
]
)
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_nested_sequential_condition_too_many_nested_levels(mock_conditions):
condition_1, condition_2, condition_3, condition_4 = mock_conditions
with pytest.raises(
InvalidCondition, match="nested levels of multi-conditions are allowed"
):
_ = AndCompoundCondition(
operands=[
OrCompoundCondition(
operands=[
condition_1,
SequentialAccessControlCondition(
condition_variables=[
ConditionVariable("var2", condition_2),
ConditionVariable("var3", condition_3),
]
),
]
),
condition_4,
]
)
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_nested_compound_condition(mock_conditions):
condition_1, condition_2, condition_3, condition_4 = mock_conditions
@ -270,12 +345,8 @@ def test_nested_compound_condition(mock_conditions):
OrCompoundCondition(
operands=[
condition_1,
AndCompoundCondition(
operands=[
condition_2,
condition_3,
]
),
condition_2,
condition_3,
]
),
condition_4,
@ -283,30 +354,43 @@ def test_nested_compound_condition(mock_conditions):
)
# all conditions are True
result, value = nested_compound_condition.verify()
result, value = nested_compound_condition.verify(providers={})
assert result is True
assert len(value) == 2, "or_condition and condition_4"
assert len(value) == 2, "or_condition (condition_1) and condition_4"
assert value == [[1], 4] # or short-circuited since condition_1 is True
# set condition_1 to False so nested and-condition must be evaluated
# set condition_1 to False so condition_2 must be evaluated
condition_1.verify.return_value = (False, 1)
result, value = nested_compound_condition.verify()
result, value = nested_compound_condition.verify(providers={})
assert result is True
assert len(value) == 2, "or_condition and condition_4"
assert len(value) == 2, "or_condition (condition_2) and condition_4"
assert value == [
[1, [2, 3]],
[1, 2],
4,
] # nested and-condition was evaluated and evaluated to True
] # or short-circuited since condition_2 is True
# set condition_3 to False so condition_3 must be evaluated
condition_2.verify.return_value = (False, 2)
result, value = nested_compound_condition.verify(providers={})
assert result is True
assert len(value) == 2, "or_condition (condition_3) and condition_4"
assert value == [
[1, 2, 3],
4,
] # or short-circuited since condition_3 is True
# set condition_4 to False so that overall result flips to False
# (even though condition_3 is still True)
condition_4.verify.return_value = (False, 4)
result, value = nested_compound_condition.verify()
result, value = nested_compound_condition.verify(providers={})
assert result is False
assert len(value) == 2, "or_condition and condition_4"
assert value == [[1, [2, 3]], 4]
assert value == [[1, 2, 3], 4]
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_not_compound_condition(mock_conditions):
condition_1, condition_2, condition_3, condition_4 = mock_conditions
@ -316,12 +400,12 @@ def test_not_compound_condition(mock_conditions):
# simple `not`
#
condition_1.verify.return_value = (True, 1)
result, value = not_condition.verify()
result, value = not_condition.verify(providers={})
assert result is False
assert value == 1
condition_1.verify.return_value = (False, 2)
result, value = not_condition.verify()
result, value = not_condition.verify(providers={})
assert result is True
assert value == 2
@ -342,8 +426,8 @@ def test_not_compound_condition(mock_conditions):
]
)
not_condition = NotCompoundCondition(operand=or_condition)
or_result, or_value = or_condition.verify()
result, value = not_condition.verify()
or_result, or_value = or_condition.verify(providers={})
result, value = not_condition.verify(providers={})
assert result is False
assert result is (not or_result)
assert value == or_value
@ -352,8 +436,8 @@ def test_not_compound_condition(mock_conditions):
condition_1.verify.return_value = (False, 1)
condition_2.verify.return_value = (False, 2)
condition_3.verify.return_value = (False, 3)
or_result, or_value = or_condition.verify()
result, value = not_condition.verify()
or_result, or_value = or_condition.verify(providers={})
result, value = not_condition.verify(providers={})
assert result is True
assert result is (not or_result)
assert value == or_value
@ -362,8 +446,8 @@ def test_not_compound_condition(mock_conditions):
condition_1.verify.return_value = (False, 1)
condition_2.verify.return_value = (False, 2)
condition_3.verify.return_value = (True, 3)
or_result, or_value = or_condition.verify()
result, value = not_condition.verify()
or_result, or_value = or_condition.verify(providers={})
result, value = not_condition.verify(providers={})
assert result is False
assert result is (not or_result)
assert value == or_value
@ -386,8 +470,8 @@ def test_not_compound_condition(mock_conditions):
)
not_condition = NotCompoundCondition(operand=and_condition)
and_result, and_value = and_condition.verify()
result, value = not_condition.verify()
and_result, and_value = and_condition.verify(providers={})
result, value = not_condition.verify(providers={})
assert result is False
assert result is (not and_result)
assert value == and_value
@ -396,8 +480,8 @@ def test_not_compound_condition(mock_conditions):
condition_1.verify.return_value = (False, 1)
condition_2.verify.return_value = (False, 2)
condition_3.verify.return_value = (False, 3)
and_result, and_value = and_condition.verify()
result, value = not_condition.verify()
and_result, and_value = and_condition.verify(providers={})
result, value = not_condition.verify(providers={})
assert result is True
assert result is (not and_result)
assert value == and_value
@ -406,59 +490,8 @@ def test_not_compound_condition(mock_conditions):
condition_1.verify.return_value = (False, 1)
condition_2.verify.return_value = (True, 2)
condition_3.verify.return_value = (False, 3)
and_result, and_value = and_condition.verify()
result, value = not_condition.verify()
and_result, and_value = and_condition.verify(providers={})
result, value = not_condition.verify(providers={})
assert result is True
assert result is (not and_result)
assert value == and_value
#
# Complex nested `or` and `and` (reused nested compound condition in previous test)
#
nested_compound_condition = AndCompoundCondition(
operands=[
OrCompoundCondition(
operands=[
condition_1,
AndCompoundCondition(
operands=[
condition_2,
condition_3,
]
),
]
),
condition_4,
]
)
not_condition = NotCompoundCondition(operand=nested_compound_condition)
# reset all conditions to True
condition_1.verify.return_value = (True, 1)
condition_2.verify.return_value = (True, 2)
condition_3.verify.return_value = (True, 3)
condition_4.verify.return_value = (True, 4)
nested_result, nested_value = nested_compound_condition.verify()
result, value = not_condition.verify()
assert result is False
assert result is (not nested_result)
assert value == nested_value
# set condition_1 to False so nested and-condition must be evaluated
condition_1.verify.return_value = (False, 1)
nested_result, nested_value = nested_compound_condition.verify()
result, value = not_condition.verify()
assert result is False
assert result is (not nested_result)
assert value == nested_value
# set condition_4 to False so that overall result flips to False, so `not` is now True
condition_4.verify.return_value = (False, 4)
nested_result, nested_value = nested_compound_condition.verify()
result, value = not_condition.verify()
assert result is True
assert result is (not nested_result)
assert value == nested_value

View File

@ -33,16 +33,15 @@ def lingo_with_compound_conditions(get_random_checksum_address):
"operands": [
{
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
"returnValueTest": {"value": 0, "comparator": ">"},
},
{
"conditionType": ConditionType.CONTRACT.value,
"chain": TESTERCHAIN_CHAIN_ID,
"method": "isPolicyActive",
"parameters": [":hrac"],
"returnValueTest": {"comparator": "==", "value": True},
"contractAddress": get_random_checksum_address(),
"functionAbi": {
"type": "function",
@ -59,38 +58,114 @@ def lingo_with_compound_conditions(get_random_checksum_address):
{"name": "", "type": "bool", "internalType": "bool"}
],
},
"returnValueTest": {"comparator": "==", "value": True},
},
# sequential condition
{
"conditionType": ConditionType.COMPOUND.value,
"operator": "or",
"operands": [
"conditionType": ConditionType.SEQUENTIAL.value,
"conditionVariables": [
{
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
"varName": "timeValue",
"condition": {
# Time
"conditionType": ConditionType.TIME.value,
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
"returnValueTest": {
"value": 0,
"comparator": ">",
},
},
},
{
"conditionType": ConditionType.RPC.value,
"chain": TESTERCHAIN_CHAIN_ID,
"method": "eth_getBalance",
"parameters": [get_random_checksum_address(), "latest"],
"returnValueTest": {
"comparator": ">=",
"value": 10000000000000,
"varName": "rpcValue",
"condition": {
# RPC
"conditionType": ConditionType.RPC.value,
"chain": TESTERCHAIN_CHAIN_ID,
"method": "eth_getBalance",
"parameters": [
get_random_checksum_address(),
"latest",
],
"returnValueTest": {
"comparator": ">=",
"value": 10000000000000,
},
},
},
{
"varName": "contractValue",
"condition": {
# Contract
"conditionType": ConditionType.CONTRACT.value,
"chain": TESTERCHAIN_CHAIN_ID,
"method": "isPolicyActive",
"parameters": [":hrac"],
"contractAddress": get_random_checksum_address(),
"functionAbi": {
"type": "function",
"name": "isPolicyActive",
"stateMutability": "view",
"inputs": [
{
"name": "_policyID",
"type": "bytes16",
"internalType": "bytes16",
}
],
"outputs": [
{
"name": "",
"type": "bool",
"internalType": "bool",
}
],
},
"returnValueTest": {
"comparator": "==",
"value": True,
},
},
},
{
"varName": "jsonValue",
"condition": {
# JSON API
"conditionType": ConditionType.JSONAPI.value,
"endpoint": "https://api.example.com/data",
"query": "$.store.book[0].price",
"parameters": {
"ids": "ethereum",
"vs_currencies": "usd",
},
"returnValueTest": {
"comparator": "==",
"value": 2,
},
},
},
],
},
{
"conditionType": ConditionType.RPC.value,
"chain": TESTERCHAIN_CHAIN_ID,
"method": "eth_getBalance",
"parameters": [get_random_checksum_address(), "latest"],
"returnValueTest": {
"comparator": ">=",
"value": 10000000000000,
},
},
{
"conditionType": ConditionType.COMPOUND.value,
"operator": "not",
"operands": [
{
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
"returnValueTest": {"value": 0, "comparator": ">"},
},
],
},
@ -98,7 +173,6 @@ def lingo_with_compound_conditions(get_random_checksum_address):
},
}
def test_invalid_condition():
# no version or condition
data = dict()
@ -127,6 +201,9 @@ def test_invalid_condition():
with pytest.raises(InvalidConditionLingo):
ConditionLingo.from_json(json.dumps(data))
def test_invalid_compound_condition():
# invalid operator
invalid_operator = {
"version": ConditionLingo.VERSION,
@ -275,7 +352,7 @@ def test_condition_lingo_to_from_json(lingo_with_compound_conditions):
assert clingo_from_json.to_dict() == lingo_with_compound_conditions
def test_condition_lingo_repr(lingo_with_compound_conditions):
def test_compound_condition_lingo_repr(lingo_with_compound_conditions):
clingo = ConditionLingo.from_dict(lingo_with_compound_conditions)
clingo_string = f"{clingo}"
assert f"{clingo.__class__.__name__}" in clingo_string

View File

@ -11,7 +11,7 @@ from nucypher.policy.conditions.context import (
_resolve_user_address,
get_context_value,
is_context_variable,
resolve_any_context_variables,
resolve_parameter_context_variables,
)
from nucypher.policy.conditions.exceptions import (
ContextVariableVerificationFailed,
@ -86,9 +86,8 @@ def test_resolve_any_context_variables():
params, resolved_params = params_with_resolution
value, resolved_value = value_with_resolution
return_value_test = ReturnValueTest(comparator="==", value=value)
resolved_parameters, resolved_return_value = resolve_any_context_variables(
[params], return_value_test, **CONTEXT
)
resolved_parameters = resolve_parameter_context_variables([params], **CONTEXT)
resolved_return_value = return_value_test.with_resolved_context(**CONTEXT)
assert resolved_parameters == [resolved_params]
assert resolved_return_value.comparator == return_value_test.comparator
assert resolved_return_value.index == return_value_test.index

View File

@ -11,7 +11,7 @@ from hexbytes import HexBytes
from marshmallow import post_load
from web3.providers import BaseProvider
from nucypher.policy.conditions.evm import ContractCondition
from nucypher.policy.conditions.evm import ContractCall, ContractCondition
from nucypher.policy.conditions.exceptions import (
InvalidCondition,
InvalidConditionLingo,
@ -44,6 +44,17 @@ CONTRACT_CONDITION = {
class FakeExecutionContractCondition(ContractCondition):
class FakeRPCCall(ContractCall):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.execution_return_value = None
def set_execution_return_value(self, value: Any):
self.execution_return_value = value
def execute(self, providers: Dict, **context) -> Any:
return self.execution_return_value
class Schema(ContractCondition.Schema):
@post_load
def make(self, data, **kwargs):
@ -51,16 +62,12 @@ class FakeExecutionContractCondition(ContractCondition):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.execution_return_value = None
def _create_execution_call(self, *args, **kwargs) -> ContractCall:
return self.FakeRPCCall(*args, **kwargs)
def set_execution_return_value(self, value: Any):
self.execution_return_value = value
def _execute_call(self, parameters: List[Any]) -> Any:
return self.execution_return_value
def _configure_provider(self, provider: BaseProvider):
return
self.execution_call.set_execution_return_value(value)
@pytest.fixture(scope="function")

View File

@ -7,7 +7,6 @@ from marshmallow import ValidationError
from nucypher.policy.conditions.exceptions import (
ConditionEvaluationFailed,
InvalidCondition,
InvalidConditionLingo,
)
from nucypher.policy.conditions.lingo import ConditionLingo, ReturnValueTest
from nucypher.policy.conditions.offchain import (
@ -56,7 +55,7 @@ def test_json_api_condition_invalid_type():
def test_https_enforcement():
with pytest.raises(InvalidConditionLingo) as excinfo:
with pytest.raises(InvalidCondition) as excinfo:
JsonApiCondition(
endpoint="http://api.example.com/data",
query="$.store.book[0].price",
@ -88,7 +87,7 @@ def test_json_api_condition_fetch(mocker):
query="$.store.book[0].title",
return_value_test=ReturnValueTest("==", "'Test Title'"),
)
response = condition.fetch()
response = condition.execution_call._fetch()
assert response.status_code == 200
assert response.json() == {"store": {"book": [{"title": "Test Title"}]}}
@ -104,7 +103,7 @@ def test_json_api_condition_fetch_failure(mocker):
return_value_test=ReturnValueTest("==", 1),
)
with pytest.raises(InvalidCondition) as excinfo:
condition.fetch()
condition.execution_call._fetch()
assert "Failed to fetch endpoint" in str(excinfo.value)
@ -224,7 +223,7 @@ def test_json_api_condition_from_lingo_expression():
},
}
cls = ConditionLingo.resolve_condition_class(lingo_dict, version=1.0)
cls = ConditionLingo.resolve_condition_class(lingo_dict, version=1)
assert cls == JsonApiCondition
lingo_json = json.dumps(lingo_dict)

View File

@ -0,0 +1,226 @@
import pytest
from web3.exceptions import Web3Exception
from nucypher.policy.conditions.base import (
AccessControlCondition,
)
from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.lingo import (
ConditionType,
ConditionVariable,
OrCompoundCondition,
SequentialAccessControlCondition,
)
@pytest.fixture(scope="function")
def mock_condition_variables(mocker):
cond_1 = mocker.Mock(spec=AccessControlCondition)
cond_1.verify.return_value = (True, 1)
cond_1.to_dict.return_value = {"value": 1}
var_1 = ConditionVariable(var_name="var1", condition=cond_1)
cond_2 = mocker.Mock(spec=AccessControlCondition)
cond_2.verify.return_value = (True, 2)
cond_2.to_dict.return_value = {"value": 2}
var_2 = ConditionVariable(var_name="var2", condition=cond_2)
cond_3 = mocker.Mock(spec=AccessControlCondition)
cond_3.verify.return_value = (True, 3)
cond_3.to_dict.return_value = {"value": 3}
var_3 = ConditionVariable(var_name="var3", condition=cond_3)
cond_4 = mocker.Mock(spec=AccessControlCondition)
cond_4.verify.return_value = (True, 4)
cond_4.to_dict.return_value = {"value": 4}
var_4 = ConditionVariable(var_name="var4", condition=cond_4)
return var_1, var_2, var_3, var_4
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_invalid_sequential_condition(mock_condition_variables):
var_1, *_ = mock_condition_variables
# invalid condition type
with pytest.raises(InvalidCondition, match=ConditionType.SEQUENTIAL.value):
_ = SequentialAccessControlCondition(
condition_type=ConditionType.TIME.value,
condition_variables=list(mock_condition_variables),
)
# no variables
with pytest.raises(InvalidCondition, match="At least two conditions"):
_ = SequentialAccessControlCondition(
condition_variables=[],
)
# only one variable
with pytest.raises(InvalidCondition, match="At least two conditions"):
_ = SequentialAccessControlCondition(
condition_variables=[var_1],
)
# too many variables
too_many_variables = list(mock_condition_variables)
too_many_variables.extend(mock_condition_variables) # duplicate list length
assert len(too_many_variables) > SequentialAccessControlCondition.MAX_NUM_CONDITIONS
with pytest.raises(InvalidCondition, match="Maximum of"):
_ = SequentialAccessControlCondition(
condition_variables=too_many_variables,
)
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_nested_sequential_condition_too_many_nested_levels(mock_condition_variables):
var_1, var_2, var_3, var_4 = mock_condition_variables
with pytest.raises(
InvalidCondition, match="nested levels of multi-conditions are allowed"
):
_ = (
SequentialAccessControlCondition(
condition_variables=[
var_1,
ConditionVariable(
"seq_1",
SequentialAccessControlCondition(
condition_variables=[
var_2,
ConditionVariable(
"seq_2",
SequentialAccessControlCondition(
condition_variables=[
var_3,
var_4,
],
),
),
],
),
),
]
),
)
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_nested_compound_condition_too_many_nested_levels(mock_condition_variables):
var_1, var_2, var_3, var_4 = mock_condition_variables
with pytest.raises(
InvalidCondition, match="nested levels of multi-conditions are allowed"
):
_ = SequentialAccessControlCondition(
condition_variables=[
ConditionVariable(
"var1",
OrCompoundCondition(
operands=[
var_1.condition,
SequentialAccessControlCondition(
condition_variables=[
var_2,
var_3,
]
),
]
),
),
var_4,
],
)
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_sequential_condition(mock_condition_variables):
var_1, var_2, var_3, var_4 = mock_condition_variables
var_1.condition.verify.return_value = (True, 1)
var_2.condition.verify = lambda providers, **context: (
True,
context[f":{var_1.var_name}"] * 2,
)
var_3.condition.verify = lambda providers, **context: (
True,
context[f":{var_2.var_name}"] * 3,
)
var_4.condition.verify = lambda providers, **context: (
True,
context[f":{var_3.var_name}"] * 4,
)
sequential_condition = SequentialAccessControlCondition(
condition_variables=[var_1, var_2, var_3, var_4],
)
original_context = dict()
result, value = sequential_condition.verify(providers={}, **original_context)
assert result is True
assert value == [1, 1 * 2, 1 * 2 * 3, 1 * 2 * 3 * 4]
# only a copy of the context is modified internally
assert len(original_context) == 0, "original context remains unchanged"
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_sequential_condition_all_prior_vars_passed_to_subsequent_calls(
mock_condition_variables,
):
var_1, var_2, var_3, var_4 = mock_condition_variables
var_1.condition.verify.return_value = (True, 1)
var_2.condition.verify = lambda providers, **context: (
True,
context[f":{var_1.var_name}"] + 1,
)
var_3.condition.verify = lambda providers, **context: (
True,
context[f":{var_1.var_name}"] + context[f":{var_2.var_name}"] + 1,
)
var_4.condition.verify = lambda providers, **context: (
True,
context[f":{var_1.var_name}"]
+ context[f":{var_2.var_name}"]
+ context[f":{var_3.var_name}"]
+ 1,
)
sequential_condition = SequentialAccessControlCondition(
condition_variables=[var_1, var_2, var_3, var_4],
)
expected_var_1_value = 1
expected_var_2_value = expected_var_1_value + 1
expected_var_3_value = expected_var_1_value + expected_var_2_value + 1
original_context = dict()
result, value = sequential_condition.verify(providers={}, **original_context)
assert result is True
assert value == [
expected_var_1_value,
expected_var_2_value,
expected_var_3_value,
(expected_var_1_value + expected_var_2_value + expected_var_3_value + 1),
]
# only a copy of the context is modified internally
assert len(original_context) == 0, "original context remains unchanged"
@pytest.mark.usefixtures("mock_skip_schema_validation")
def test_sequential_condition_a_call_fails(mock_condition_variables):
var_1, var_2, var_3, var_4 = mock_condition_variables
var_4.condition.verify.side_effect = Web3Exception
sequential_condition = SequentialAccessControlCondition(
condition_variables=[var_1, var_2, var_3, var_4],
)
with pytest.raises(Web3Exception):
_ = sequential_condition.verify(providers={})

View File

@ -2,7 +2,7 @@ import pytest
from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
from nucypher.policy.conditions.time import TimeCondition
from nucypher.policy.conditions.time import TimeCondition, TimeRPCCall
from tests.constants import TESTERCHAIN_CHAIN_ID
@ -13,7 +13,7 @@ def test_invalid_time_condition():
condition_type=ConditionType.COMPOUND.value,
return_value_test=ReturnValueTest(">", 0),
chain=TESTERCHAIN_CHAIN_ID,
method=TimeCondition.METHOD,
method=TimeRPCCall.METHOD,
)
# invalid method
@ -29,7 +29,7 @@ def test_invalid_time_condition():
_ = TimeCondition(
return_value_test=ReturnValueTest(">", 0),
chain=90210, # Beverly Hills Chain :)
method=TimeCondition.METHOD,
method=TimeRPCCall.METHOD,
)