mirror of https://github.com/nucypher/nucypher.git
commit
6a54f82f0f
|
@ -0,0 +1 @@
|
|||
Support for executing multiple conditions sequentially, where the outcome of one condition can be used as input for another.
|
|
@ -1,14 +1,15 @@
|
|||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from base64 import b64decode, b64encode
|
||||
from typing import Any, Dict, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from marshmallow import Schema, ValidationError
|
||||
from marshmallow import Schema, ValidationError, fields
|
||||
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
InvalidCondition,
|
||||
InvalidConditionLingo,
|
||||
)
|
||||
from nucypher.policy.conditions.utils import CamelCaseSchema
|
||||
|
||||
|
||||
class _Serializable:
|
||||
|
@ -51,24 +52,30 @@ class _Serializable:
|
|||
|
||||
|
||||
class AccessControlCondition(_Serializable, ABC):
|
||||
CONDITION_TYPE = NotImplemented
|
||||
|
||||
class Schema(Schema):
|
||||
name = NotImplemented
|
||||
|
||||
def __init__(self):
|
||||
class Schema(CamelCaseSchema):
|
||||
SKIP_VALUES = (None,)
|
||||
name = fields.Str(required=False)
|
||||
condition_type = NotImplemented
|
||||
|
||||
def __init__(self, condition_type: str, name: Optional[str] = None):
|
||||
super().__init__()
|
||||
|
||||
if condition_type != self.CONDITION_TYPE:
|
||||
raise InvalidCondition(
|
||||
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
|
||||
)
|
||||
self.condition_type = condition_type
|
||||
self.name = name
|
||||
|
||||
# validate inputs using marshmallow schema
|
||||
schema = self.Schema()
|
||||
errors = schema.validate(self.to_dict())
|
||||
if errors:
|
||||
raise InvalidConditionLingo(errors)
|
||||
self.validate(self.to_dict())
|
||||
|
||||
@abstractmethod
|
||||
def verify(self, *args, **kwargs) -> Tuple[bool, Any]:
|
||||
"""Returns the boolean result of the evaluation and the returned value in a two-tuple."""
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def validate(cls, data: Dict) -> None:
|
||||
|
@ -89,3 +96,48 @@ class AccessControlCondition(_Serializable, ABC):
|
|||
return super().from_json(data)
|
||||
except ValidationError as e:
|
||||
raise InvalidConditionLingo(f"Invalid condition grammar: {e}")
|
||||
|
||||
|
||||
class ExecutionCall(ABC):
|
||||
@abstractmethod
|
||||
def execute(self, *args, **kwargs) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MultiConditionAccessControl(AccessControlCondition):
|
||||
MAX_NUM_CONDITIONS = 5
|
||||
MAX_MULTI_CONDITION_NESTED_LEVEL = 2
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._validate_multi_condition_nesting(conditions=self.conditions)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def conditions(self) -> List[AccessControlCondition]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _validate_multi_condition_nesting(
|
||||
cls,
|
||||
conditions: List[AccessControlCondition],
|
||||
current_level: int = 1,
|
||||
):
|
||||
if len(conditions) > cls.MAX_NUM_CONDITIONS:
|
||||
raise InvalidCondition(
|
||||
f"Maximum of {cls.MAX_NUM_CONDITIONS} conditions are allowed"
|
||||
)
|
||||
|
||||
for condition in conditions:
|
||||
if not isinstance(condition, MultiConditionAccessControl):
|
||||
continue
|
||||
|
||||
level = current_level + 1
|
||||
if level > cls.MAX_MULTI_CONDITION_NESTED_LEVEL:
|
||||
raise InvalidCondition(
|
||||
f"Only {cls.MAX_MULTI_CONDITION_NESTED_LEVEL} nested levels of multi-conditions are allowed"
|
||||
)
|
||||
cls._validate_multi_condition_nesting(
|
||||
conditions=condition.conditions,
|
||||
current_level=level,
|
||||
)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import re
|
||||
from functools import partial
|
||||
from typing import Any, List, Union
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from eth_typing import ChecksumAddress
|
||||
from eth_utils import to_checksum_address
|
||||
|
@ -125,9 +125,11 @@ def _resolve_context_variable(param: Union[Any, List[Any]], **context):
|
|||
return param
|
||||
|
||||
|
||||
def resolve_any_context_variables(parameters: List[Any], return_value_test, **context):
|
||||
processed_parameters = [
|
||||
_resolve_context_variable(param, **context) for param in parameters
|
||||
]
|
||||
processed_return_value_test = return_value_test.with_resolved_context(**context)
|
||||
return processed_parameters, processed_return_value_test
|
||||
def resolve_parameter_context_variables(parameters: Optional[List[Any]], **context):
|
||||
if not parameters:
|
||||
processed_parameters = [] # produce empty list
|
||||
else:
|
||||
processed_parameters = [
|
||||
_resolve_context_variable(param, **context) for param in parameters
|
||||
]
|
||||
return processed_parameters
|
||||
|
|
|
@ -11,7 +11,6 @@ from typing import (
|
|||
from eth_typing import ChecksumAddress
|
||||
from eth_utils import to_checksum_address
|
||||
from marshmallow import ValidationError, fields, post_load, validate, validates_schema
|
||||
from marshmallow.validate import OneOf
|
||||
from web3 import HTTPProvider, Web3
|
||||
from web3.contract.contract import ContractFunction
|
||||
from web3.middleware import geth_poa_middleware
|
||||
|
@ -19,22 +18,29 @@ from web3.providers import BaseProvider
|
|||
from web3.types import ABIFunction
|
||||
|
||||
from nucypher.policy.conditions import STANDARD_ABI_CONTRACT_TYPES, STANDARD_ABIS
|
||||
from nucypher.policy.conditions.base import AccessControlCondition
|
||||
from nucypher.policy.conditions.base import (
|
||||
ExecutionCall,
|
||||
)
|
||||
from nucypher.policy.conditions.context import (
|
||||
is_context_variable,
|
||||
resolve_any_context_variables,
|
||||
resolve_parameter_context_variables,
|
||||
)
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
InvalidCondition,
|
||||
NoConnectionToChain,
|
||||
RequiredContextVariable,
|
||||
RPCExecutionFailed,
|
||||
)
|
||||
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
|
||||
from nucypher.policy.conditions.utils import CamelCaseSchema, camel_case_to_snake
|
||||
from nucypher.policy.conditions.lingo import (
|
||||
ConditionType,
|
||||
ExecutionCallAccessControlCondition,
|
||||
ReturnValueTest,
|
||||
)
|
||||
from nucypher.policy.conditions.utils import camel_case_to_snake
|
||||
from nucypher.policy.conditions.validation import (
|
||||
_align_comparator_value_with_abi,
|
||||
_get_abi_types,
|
||||
_validate_condition_abi,
|
||||
_validate_contract_call_abi,
|
||||
_validate_multiple_output_types,
|
||||
_validate_single_output_type,
|
||||
)
|
||||
|
@ -102,94 +108,70 @@ def _validate_chain(chain: int) -> None:
|
|||
)
|
||||
|
||||
|
||||
class RPCCondition(AccessControlCondition):
|
||||
ETH_PREFIX = "eth_"
|
||||
class RPCCall(ExecutionCall):
|
||||
LOG = logging.Logger(__name__)
|
||||
|
||||
ALLOWED_METHODS = {
|
||||
# RPC
|
||||
"eth_getBalance": int,
|
||||
} # TODO other allowed methods (tDEC #64)
|
||||
LOG = logging.Logger(__name__)
|
||||
CONDITION_TYPE = ConditionType.RPC.value
|
||||
|
||||
class Schema(CamelCaseSchema):
|
||||
SKIP_VALUES = (None,)
|
||||
name = fields.Str(required=False)
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.RPC.value), required=True
|
||||
)
|
||||
chain = fields.Int(
|
||||
required=True, strict=True, validate=OneOf(_CONDITION_CHAINS)
|
||||
)
|
||||
method = fields.Str(required=True)
|
||||
parameters = fields.List(fields.Field, attribute="parameters", required=False)
|
||||
return_value_test = fields.Nested(
|
||||
ReturnValueTest.ReturnValueTestSchema(), required=True
|
||||
)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return RPCCondition(**data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
r = f"{self.__class__.__name__}(function={self.method}, chain={self.chain})"
|
||||
return r
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chain: int,
|
||||
method: str,
|
||||
return_value_test: ReturnValueTest,
|
||||
condition_type: str = CONDITION_TYPE,
|
||||
name: Optional[str] = None,
|
||||
parameters: Optional[List[Any]] = None,
|
||||
):
|
||||
# Validate input
|
||||
# TODO: Additional validation (function is valid for ABI, RVT validity, standard contract name validity, etc.)
|
||||
_validate_chain(chain=chain)
|
||||
|
||||
# internal
|
||||
if condition_type != self.CONDITION_TYPE:
|
||||
raise InvalidCondition(
|
||||
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
|
||||
)
|
||||
|
||||
self.condition_type = condition_type
|
||||
self.name = name
|
||||
self.chain = chain
|
||||
self.provider: Optional[BaseProvider] = None # set in _configure_provider
|
||||
self.method = self._validate_method(method=method)
|
||||
|
||||
# test
|
||||
# should not be set to None - we do list unpacking so cannot be None; use empty list
|
||||
self.parameters = parameters or []
|
||||
self.return_value_test = return_value_test # output
|
||||
|
||||
self._validate_expected_return_type()
|
||||
|
||||
super().__init__()
|
||||
self.parameters = parameters or None
|
||||
|
||||
def _validate_method(self, method):
|
||||
if not method:
|
||||
raise InvalidCondition("Undefined method name")
|
||||
raise ValueError("Undefined method name")
|
||||
|
||||
if method not in self.ALLOWED_METHODS:
|
||||
raise InvalidCondition(
|
||||
raise ValueError(
|
||||
f"'{method}' is not a permitted RPC endpoint for condition evaluation."
|
||||
)
|
||||
return method
|
||||
|
||||
def _validate_expected_return_type(self):
|
||||
expected_return_type = self.ALLOWED_METHODS[self.method]
|
||||
comparator_value = self.return_value_test.value
|
||||
if is_context_variable(comparator_value):
|
||||
return
|
||||
def _get_web3_py_function(self, w3: Web3, rpc_method: str):
|
||||
web3_py_method = camel_case_to_snake(rpc_method)
|
||||
rpc_function = getattr(
|
||||
w3.eth, web3_py_method
|
||||
) # bind contract function (only exposes the eth API)
|
||||
return rpc_function
|
||||
|
||||
if not isinstance(self.return_value_test.value, expected_return_type):
|
||||
raise InvalidCondition(
|
||||
f"Return value comparison for '{self.method}' call output "
|
||||
f"should be '{expected_return_type}' and not '{type(comparator_value)}'."
|
||||
def _configure_w3(self, provider: BaseProvider) -> Web3:
|
||||
# Instantiate a local web3 instance
|
||||
w3 = Web3(provider)
|
||||
# inject web3 middleware to handle POA chain extra_data field.
|
||||
w3.middleware_onion.inject(geth_poa_middleware, layer=0, name="poa")
|
||||
return w3
|
||||
|
||||
def _check_chain_id(self, w3: Web3) -> None:
|
||||
"""
|
||||
Validates that the actual web3 provider is *actually*
|
||||
connected to the condition's chain ID by reading its RPC endpoint.
|
||||
"""
|
||||
provider_chain = w3.eth.chain_id
|
||||
if provider_chain != self.chain:
|
||||
raise NoConnectionToChain(
|
||||
chain=self.chain,
|
||||
message=f"This rpc call can only be evaluated on chain ID {self.chain} but the provider's "
|
||||
f"connection is to chain ID {provider_chain}",
|
||||
)
|
||||
|
||||
def _configure_provider(self, provider: BaseProvider):
|
||||
"""Binds the condition's contract function to a blockchain provider for evaluation"""
|
||||
w3 = self._configure_w3(provider=provider)
|
||||
self._check_chain_id(w3)
|
||||
return w3
|
||||
|
||||
def _next_endpoint(
|
||||
self, providers: Dict[int, Set[HTTPProvider]]
|
||||
) -> Iterator[HTTPProvider]:
|
||||
|
@ -208,47 +190,99 @@ class RPCCondition(AccessControlCondition):
|
|||
# each endpoint here to see if it's alive and only yield it if it is.
|
||||
yield provider
|
||||
|
||||
def _configure_w3(self, provider: BaseProvider) -> Web3:
|
||||
# Instantiate a local web3 instance
|
||||
self.provider = provider
|
||||
w3 = Web3(provider)
|
||||
# inject web3 middleware to handle POA chain extra_data field.
|
||||
w3.middleware_onion.inject(geth_poa_middleware, layer=0, name="poa")
|
||||
def execute(self, providers: Dict[int, Set[HTTPProvider]], **context) -> Any:
|
||||
resolved_parameters = resolve_parameter_context_variables(
|
||||
self.parameters, **context
|
||||
)
|
||||
|
||||
return w3
|
||||
|
||||
def _check_chain_id(self) -> None:
|
||||
"""
|
||||
Validates that the actual web3 provider is *actually*
|
||||
connected to the condition's chain ID by reading its RPC endpoint.
|
||||
"""
|
||||
provider_chain = self.w3.eth.chain_id
|
||||
if provider_chain != self.chain:
|
||||
raise InvalidCondition(
|
||||
f"This condition can only be evaluated on chain ID {self.chain} but the provider's "
|
||||
f"connection is to chain ID {provider_chain}"
|
||||
endpoints = self._next_endpoint(providers=providers)
|
||||
latest_error = ""
|
||||
for provider in endpoints:
|
||||
w3 = self._configure_provider(provider)
|
||||
try:
|
||||
result = self._execute(w3, resolved_parameters)
|
||||
break
|
||||
except RequiredContextVariable:
|
||||
raise
|
||||
except Exception as e:
|
||||
latest_error = f"RPC call '{self.method}' failed: {e}"
|
||||
self.LOG.warn(f"{latest_error}, attempting to try next endpoint.")
|
||||
# Something went wrong. Try the next endpoint.
|
||||
continue
|
||||
else:
|
||||
# Fuck.
|
||||
raise RPCExecutionFailed(
|
||||
f"RPC call '{self.method}' failed; latest error - {latest_error}"
|
||||
)
|
||||
|
||||
def _configure_provider(self, provider: BaseProvider):
|
||||
"""Binds the condition's contract function to a blockchain provider for evaluation"""
|
||||
self.w3 = self._configure_w3(provider=provider)
|
||||
self._check_chain_id()
|
||||
return provider
|
||||
return result
|
||||
|
||||
def _get_web3_py_function(self, rpc_method: str):
|
||||
web3_py_method = camel_case_to_snake(rpc_method)
|
||||
rpc_function = getattr(
|
||||
self.w3.eth, web3_py_method
|
||||
) # bind contract function (only exposes the eth API)
|
||||
return rpc_function
|
||||
|
||||
def _execute_call(self, parameters: List[Any]) -> Any:
|
||||
def _execute(self, w3: Web3, resolved_parameters: List[Any]) -> Any:
|
||||
"""Execute onchain read and return result."""
|
||||
rpc_endpoint_, rpc_method = self.method.split("_", 1)
|
||||
rpc_function = self._get_web3_py_function(rpc_method)
|
||||
rpc_result = rpc_function(*parameters) # RPC read
|
||||
rpc_function = self._get_web3_py_function(w3, rpc_method)
|
||||
rpc_result = rpc_function(*resolved_parameters) # RPC read
|
||||
return rpc_result
|
||||
|
||||
|
||||
class RPCCondition(ExecutionCallAccessControlCondition):
|
||||
CONDITION_TYPE = ConditionType.RPC.value
|
||||
|
||||
class Schema(ExecutionCallAccessControlCondition.Schema):
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.RPC.value), required=True
|
||||
)
|
||||
chain = fields.Int(
|
||||
required=True, strict=True, validate=validate.OneOf(_CONDITION_CHAINS)
|
||||
)
|
||||
method = fields.Str(required=True)
|
||||
parameters = fields.List(fields.Field, attribute="parameters", required=False)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return RPCCondition(**data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
r = f"{self.__class__.__name__}(function={self.method}, chain={self.chain})"
|
||||
return r
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
condition_type: str = CONDITION_TYPE,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(condition_type=condition_type, *args, **kwargs)
|
||||
|
||||
self._validate_expected_return_type()
|
||||
|
||||
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
|
||||
return RPCCall(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def method(self):
|
||||
return self.execution_call.method
|
||||
|
||||
@property
|
||||
def chain(self):
|
||||
return self.execution_call.chain
|
||||
|
||||
@property
|
||||
def parameters(self):
|
||||
return self.execution_call.parameters
|
||||
|
||||
def _validate_expected_return_type(self):
|
||||
expected_return_type = RPCCall.ALLOWED_METHODS[self.method]
|
||||
comparator_value = self.return_value_test.value
|
||||
if is_context_variable(comparator_value):
|
||||
return
|
||||
|
||||
if not isinstance(self.return_value_test.value, expected_return_type):
|
||||
raise InvalidCondition(
|
||||
f"Return value comparison for '{self.method}' call output "
|
||||
f"should be '{expected_return_type}' and not '{type(comparator_value)}'."
|
||||
)
|
||||
|
||||
def _align_comparator_value_with_abi(
|
||||
self, return_value_test: ReturnValueTest
|
||||
) -> ReturnValueTest:
|
||||
|
@ -257,35 +291,77 @@ class RPCCondition(AccessControlCondition):
|
|||
def verify(
|
||||
self, providers: Dict[int, Set[HTTPProvider]], **context
|
||||
) -> Tuple[bool, Any]:
|
||||
"""
|
||||
Verifies the onchain condition is met by performing a
|
||||
read operation and evaluating the return value test.
|
||||
"""
|
||||
parameters, return_value_test = resolve_any_context_variables(
|
||||
self.parameters, self.return_value_test, **context
|
||||
resolved_return_value_test = self.return_value_test.with_resolved_context(
|
||||
**context
|
||||
)
|
||||
return_value_test = self._align_comparator_value_with_abi(return_value_test)
|
||||
|
||||
endpoints = self._next_endpoint(providers=providers)
|
||||
for provider in endpoints:
|
||||
self._configure_provider(provider=provider)
|
||||
try:
|
||||
result = self._execute_call(parameters=parameters)
|
||||
break
|
||||
except Exception as e:
|
||||
self.LOG.warn(
|
||||
f"RPC call '{self.method}' failed: {e}, attempting to try next endpoint."
|
||||
)
|
||||
# Something went wrong. Try the next endpoint.
|
||||
continue
|
||||
else:
|
||||
# Fuck.
|
||||
raise RPCExecutionFailed(f"Contract call '{self.method}' failed.")
|
||||
return_value_test = self._align_comparator_value_with_abi(
|
||||
resolved_return_value_test
|
||||
)
|
||||
result = self.execution_call.execute(providers=providers, **context)
|
||||
|
||||
eval_result = return_value_test.eval(result) # test
|
||||
return eval_result, result
|
||||
|
||||
|
||||
class ContractCall(RPCCall):
|
||||
def __init__(
|
||||
self,
|
||||
method: str,
|
||||
contract_address: ChecksumAddress,
|
||||
standard_contract_type: Optional[str] = None,
|
||||
function_abi: Optional[ABIFunction] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not method:
|
||||
raise ValueError("Undefined method name")
|
||||
|
||||
_validate_contract_call_abi(
|
||||
standard_contract_type, function_abi, method_name=method
|
||||
)
|
||||
|
||||
# preprocessing
|
||||
contract_address = to_checksum_address(contract_address)
|
||||
self.contract_address = contract_address
|
||||
self.standard_contract_type = standard_contract_type
|
||||
self.function_abi = function_abi
|
||||
|
||||
super().__init__(method=method, *args, **kwargs)
|
||||
self.contract_function = self._get_unbound_contract_function()
|
||||
|
||||
def _validate_method(self, method):
|
||||
return method
|
||||
|
||||
def _get_unbound_contract_function(self) -> ContractFunction:
|
||||
"""Gets an unbound contract function to evaluate for this condition"""
|
||||
w3 = Web3()
|
||||
function_abi = _resolve_abi(
|
||||
w3=w3,
|
||||
standard_contract_type=self.standard_contract_type,
|
||||
method=self.method,
|
||||
function_abi=self.function_abi,
|
||||
)
|
||||
try:
|
||||
contract = w3.eth.contract(
|
||||
address=self.contract_address, abi=[function_abi]
|
||||
)
|
||||
contract_function = getattr(contract.functions, self.method)
|
||||
return contract_function
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Unable to find contract function, '{self.method}', for condition: {e}"
|
||||
)
|
||||
|
||||
def _execute(self, w3: Web3, resolved_parameters: List[Any]) -> Any:
|
||||
"""Execute onchain read and return result."""
|
||||
self.contract_function.w3 = w3
|
||||
bound_contract_function = self.contract_function(
|
||||
*resolved_parameters
|
||||
) # bind contract function
|
||||
contract_result = bound_contract_function.call() # onchain read
|
||||
return contract_result
|
||||
|
||||
|
||||
class ContractCondition(RPCCondition):
|
||||
CONDITION_TYPE = ConditionType.CONTRACT.value
|
||||
|
||||
|
@ -293,8 +369,8 @@ class ContractCondition(RPCCondition):
|
|||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.CONTRACT.value), required=True
|
||||
)
|
||||
standard_contract_type = fields.Str(required=False)
|
||||
contract_address = fields.Str(required=True)
|
||||
standard_contract_type = fields.Str(required=False)
|
||||
function_abi = fields.Dict(required=False)
|
||||
|
||||
@post_load
|
||||
|
@ -306,7 +382,7 @@ class ContractCondition(RPCCondition):
|
|||
standard_contract_type = data.get("standard_contract_type")
|
||||
function_abi = data.get("function_abi")
|
||||
try:
|
||||
_validate_condition_abi(
|
||||
_validate_contract_call_abi(
|
||||
standard_contract_type, function_abi, method_name=data.get("method")
|
||||
)
|
||||
except ValueError as e:
|
||||
|
@ -314,64 +390,37 @@ class ContractCondition(RPCCondition):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
method: str,
|
||||
contract_address: ChecksumAddress,
|
||||
condition_type: str = CONDITION_TYPE,
|
||||
standard_contract_type: Optional[str] = None,
|
||||
function_abi: Optional[ABIFunction] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not method:
|
||||
raise InvalidCondition("Undefined method name")
|
||||
try:
|
||||
_validate_condition_abi(
|
||||
standard_contract_type, function_abi, method_name=method
|
||||
)
|
||||
except ValueError as e:
|
||||
raise InvalidCondition(str(e))
|
||||
|
||||
self.method = method
|
||||
self.w3 = Web3() # used to instantiate contract function without a provider
|
||||
|
||||
# preprocessing
|
||||
contract_address = to_checksum_address(contract_address)
|
||||
|
||||
# spec
|
||||
self.contract_address = contract_address
|
||||
self.condition_type = condition_type
|
||||
self.standard_contract_type = standard_contract_type
|
||||
self.function_abi = function_abi
|
||||
|
||||
self.contract_function = self._get_unbound_contract_function()
|
||||
|
||||
# call to super must be at the end for proper validation
|
||||
super().__init__(condition_type=condition_type, method=method, *args, **kwargs)
|
||||
super().__init__(condition_type=condition_type, *args, **kwargs)
|
||||
|
||||
def _validate_method(self, method):
|
||||
# overrides validate method used by rpc superclass
|
||||
return method
|
||||
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
|
||||
return ContractCall(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def function_abi(self):
|
||||
return self.execution_call.function_abi
|
||||
|
||||
@property
|
||||
def standard_contract_type(self):
|
||||
return self.execution_call.standard_contract_type
|
||||
|
||||
@property
|
||||
def contract_function(self):
|
||||
return self.execution_call.contract_function
|
||||
|
||||
@property
|
||||
def contract_address(self):
|
||||
return self.execution_call.contract_address
|
||||
|
||||
def _validate_expected_return_type(self) -> None:
|
||||
output_abi_types = _get_abi_types(self.contract_function.contract_abi[0])
|
||||
comparator_value = self.return_value_test.value
|
||||
comparator_index = self.return_value_test.index
|
||||
index_string = (
|
||||
f"@index={comparator_index}" if comparator_index is not None else ""
|
||||
_validate_contract_function_expected_return_type(
|
||||
contract_function=self.contract_function,
|
||||
return_value_test=self.return_value_test,
|
||||
)
|
||||
failure_message = (
|
||||
f"Invalid return value comparison type '{type(comparator_value)}' for "
|
||||
f"'{self.contract_function.fn_name}'{index_string} based on ABI types {output_abi_types}"
|
||||
)
|
||||
|
||||
if len(output_abi_types) == 1:
|
||||
_validate_single_output_type(
|
||||
output_abi_types[0], comparator_value, comparator_index, failure_message
|
||||
)
|
||||
else:
|
||||
_validate_multiple_output_types(
|
||||
output_abi_types, comparator_value, comparator_index, failure_message
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
r = (
|
||||
|
@ -381,37 +430,6 @@ class ContractCondition(RPCCondition):
|
|||
)
|
||||
return r
|
||||
|
||||
def _configure_provider(self, *args, **kwargs):
|
||||
super()._configure_provider(*args, **kwargs)
|
||||
self.contract_function.w3 = self.w3
|
||||
|
||||
def _get_unbound_contract_function(self) -> ContractFunction:
|
||||
"""Gets an unbound contract function to evaluate for this condition"""
|
||||
function_abi = _resolve_abi(
|
||||
w3=self.w3,
|
||||
standard_contract_type=self.standard_contract_type,
|
||||
method=self.method,
|
||||
function_abi=self.function_abi,
|
||||
)
|
||||
try:
|
||||
contract = self.w3.eth.contract(
|
||||
address=self.contract_address, abi=[function_abi]
|
||||
)
|
||||
contract_function = getattr(contract.functions, self.method)
|
||||
return contract_function
|
||||
except Exception as e:
|
||||
raise InvalidCondition(
|
||||
f"Unable to find contract function, '{self.method}', for condition: {e}"
|
||||
)
|
||||
|
||||
def _execute_call(self, parameters: List[Any]) -> Any:
|
||||
"""Execute onchain read and return result."""
|
||||
bound_contract_function = self.contract_function(
|
||||
*parameters
|
||||
) # bind contract function
|
||||
contract_result = bound_contract_function.call() # onchain read
|
||||
return contract_result
|
||||
|
||||
def _align_comparator_value_with_abi(
|
||||
self, return_value_test: ReturnValueTest
|
||||
) -> ReturnValueTest:
|
||||
|
@ -419,3 +437,25 @@ class ContractCondition(RPCCondition):
|
|||
abi=self.contract_function.contract_abi[0],
|
||||
return_value_test=return_value_test,
|
||||
)
|
||||
|
||||
|
||||
def _validate_contract_function_expected_return_type(
|
||||
contract_function: ContractFunction, return_value_test: ReturnValueTest
|
||||
) -> None:
|
||||
output_abi_types = _get_abi_types(contract_function.contract_abi[0])
|
||||
comparator_value = return_value_test.value
|
||||
comparator_index = return_value_test.index
|
||||
index_string = f"@index={comparator_index}" if comparator_index is not None else ""
|
||||
failure_message = (
|
||||
f"Invalid return value comparison type '{type(comparator_value)}' for "
|
||||
f"'{contract_function.fn_name}'{index_string} based on ABI types {output_abi_types}"
|
||||
)
|
||||
|
||||
if len(output_abi_types) == 1:
|
||||
_validate_single_output_type(
|
||||
output_abi_types[0], comparator_value, comparator_index, failure_message
|
||||
)
|
||||
else:
|
||||
_validate_multiple_output_types(
|
||||
output_abi_types, comparator_value, comparator_index, failure_message
|
||||
)
|
||||
|
|
|
@ -7,9 +7,9 @@ class InvalidConditionLingo(Exception):
|
|||
class NoConnectionToChain(RuntimeError):
|
||||
"""Raised when a node does not have an associated provider for a chain."""
|
||||
|
||||
def __init__(self, chain: int):
|
||||
def __init__(self, chain: int, message: str = None):
|
||||
self.chain = chain
|
||||
message = f"No connection to chain ID {chain}"
|
||||
message = message or f"No connection to chain ID {chain}"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
|
|
|
@ -2,9 +2,10 @@ import ast
|
|||
import base64
|
||||
import json
|
||||
import operator as pyoperator
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from hashlib import md5
|
||||
from typing import Any, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from hexbytes import HexBytes
|
||||
from marshmallow import (
|
||||
|
@ -19,8 +20,14 @@ from marshmallow import (
|
|||
)
|
||||
from marshmallow.validate import OneOf, Range
|
||||
from packaging.version import parse as parse_version
|
||||
from web3 import HTTPProvider
|
||||
|
||||
from nucypher.policy.conditions.base import AccessControlCondition, _Serializable
|
||||
from nucypher.policy.conditions.base import (
|
||||
AccessControlCondition,
|
||||
ExecutionCall,
|
||||
MultiConditionAccessControl,
|
||||
_Serializable,
|
||||
)
|
||||
from nucypher.policy.conditions.context import (
|
||||
_resolve_context_variable,
|
||||
is_context_variable,
|
||||
|
@ -52,19 +59,8 @@ class _ConditionField(fields.Dict):
|
|||
instance = condition_class.from_dict(condition_data)
|
||||
return instance
|
||||
|
||||
#
|
||||
# CONDITION = BASE_CONDITION | COMPOUND_CONDITION
|
||||
#
|
||||
# BASE_CONDITION = {
|
||||
# // ..
|
||||
# }
|
||||
#
|
||||
# COMPOUND_CONDITION = {
|
||||
# "operator": OPERATOR,
|
||||
# "operands": [CONDITION*]
|
||||
# }
|
||||
|
||||
|
||||
# CONDITION = TIME | CONTRACT | RPC | JSON_API | COMPOUND | SEQUENTIAL
|
||||
class ConditionType(Enum):
|
||||
"""
|
||||
Defines the types of conditions that can be evaluated.
|
||||
|
@ -75,13 +71,27 @@ class ConditionType(Enum):
|
|||
RPC = "rpc"
|
||||
JSONAPI = "json-api"
|
||||
COMPOUND = "compound"
|
||||
SEQUENTIAL = "sequential"
|
||||
|
||||
@classmethod
|
||||
def values(cls) -> List[str]:
|
||||
return [condition.value for condition in cls]
|
||||
|
||||
|
||||
class CompoundAccessControlCondition(AccessControlCondition):
|
||||
class CompoundAccessControlCondition(MultiConditionAccessControl):
|
||||
"""
|
||||
A combination of two or more conditions connected by logical operators such as AND, OR, NOT.
|
||||
|
||||
CompoundCondition grammar:
|
||||
OPERATOR = AND | OR | NOT
|
||||
|
||||
COMPOUND_CONDITION = {
|
||||
"name": ... (Optional)
|
||||
"conditionType": "compound",
|
||||
"operator": OPERATOR,
|
||||
"operands": [CONDITION*]
|
||||
}
|
||||
"""
|
||||
AND_OPERATOR = "and"
|
||||
OR_OPERATOR = "or"
|
||||
NOT_OPERATOR = "not"
|
||||
|
@ -93,28 +103,32 @@ class CompoundAccessControlCondition(AccessControlCondition):
|
|||
def _validate_operator_and_operands(
|
||||
cls,
|
||||
operator: str,
|
||||
operands: List,
|
||||
operands: List[Union[Dict, AccessControlCondition]],
|
||||
exception_class: Union[Type[ValidationError], Type[InvalidCondition]],
|
||||
):
|
||||
if operator not in cls.OPERATORS:
|
||||
raise exception_class(f"{operator} is not a valid operator")
|
||||
|
||||
num_operands = len(operands)
|
||||
if operator == cls.NOT_OPERATOR:
|
||||
if len(operands) != 1:
|
||||
if num_operands != 1:
|
||||
raise exception_class(
|
||||
f"Only 1 operand permitted for '{operator}' compound condition"
|
||||
)
|
||||
elif len(operands) < 2:
|
||||
elif num_operands < 2:
|
||||
raise exception_class(
|
||||
f"Minimum of 2 operand needed for '{operator}' compound condition"
|
||||
)
|
||||
elif num_operands > cls.MAX_NUM_CONDITIONS:
|
||||
raise exception_class(
|
||||
f"Maximum of {cls.MAX_NUM_CONDITIONS} operands allowed for '{operator}' compound condition"
|
||||
)
|
||||
|
||||
class Schema(CamelCaseSchema):
|
||||
SKIP_VALUES = (None,)
|
||||
|
||||
class Schema(AccessControlCondition.Schema):
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.COMPOUND.value), required=True
|
||||
)
|
||||
name = fields.Str(required=False)
|
||||
operator = fields.Str(required=True)
|
||||
operands = fields.List(_ConditionField, required=True)
|
||||
|
||||
|
@ -147,20 +161,16 @@ class CompoundAccessControlCondition(AccessControlCondition):
|
|||
"operands": [CONDITION*]
|
||||
}
|
||||
"""
|
||||
if condition_type != self.CONDITION_TYPE:
|
||||
raise InvalidCondition(
|
||||
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
|
||||
)
|
||||
|
||||
self._validate_operator_and_operands(operator, operands, InvalidCondition)
|
||||
|
||||
self.condition_type = condition_type
|
||||
self.operator = operator
|
||||
self.operands = operands
|
||||
self.condition_type = condition_type
|
||||
self.name = name
|
||||
self.id = md5(bytes(self)).hexdigest()[:6]
|
||||
|
||||
super().__init__(condition_type=condition_type, name=name)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Operator={self.operator} (NumOperands={len(self.operands)}), id={self.id})"
|
||||
|
||||
|
@ -186,6 +196,10 @@ class CompoundAccessControlCondition(AccessControlCondition):
|
|||
|
||||
return overall_result, values
|
||||
|
||||
@property
|
||||
def conditions(self):
|
||||
return self.operands
|
||||
|
||||
|
||||
class OrCompoundCondition(CompoundAccessControlCondition):
|
||||
def __init__(self, operands: List[AccessControlCondition]):
|
||||
|
@ -212,6 +226,126 @@ _COMPARATOR_FUNCTIONS = {
|
|||
}
|
||||
|
||||
|
||||
class ConditionVariable(_Serializable):
|
||||
class Schema(CamelCaseSchema):
|
||||
var_name = fields.Str(required=True) # TODO: should this be required?
|
||||
condition = _ConditionField(required=True)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return ConditionVariable(**data)
|
||||
|
||||
def __init__(self, var_name: str, condition: AccessControlCondition):
|
||||
self.var_name = var_name
|
||||
self.condition = condition
|
||||
|
||||
|
||||
class SequentialAccessControlCondition(MultiConditionAccessControl):
|
||||
"""
|
||||
A series of conditions that are evaluated in a specific order, where the result of one
|
||||
condition can be used in subsequent conditions.
|
||||
|
||||
SequentialCondition grammar:
|
||||
CONDITION_VARIABLE = {
|
||||
"varName": STR,
|
||||
"condition": {
|
||||
CONDITION
|
||||
}
|
||||
}
|
||||
|
||||
SEQUENTIAL_CONDITION = {
|
||||
"name": ... (Optional)
|
||||
"conditionType": "sequential",
|
||||
"conditionVariables": [CONDITION_VARIABLE*]
|
||||
}
|
||||
"""
|
||||
|
||||
CONDITION_TYPE = ConditionType.SEQUENTIAL.value
|
||||
|
||||
@classmethod
|
||||
def _validate_condition_variables(
|
||||
cls,
|
||||
condition_variables: List[Union[Dict, ConditionVariable]],
|
||||
exception_class: Union[Type[ValidationError], Type[InvalidCondition]],
|
||||
):
|
||||
num_condition_variables = len(condition_variables)
|
||||
if num_condition_variables < 2:
|
||||
raise exception_class("At least two conditions must be specified")
|
||||
|
||||
if num_condition_variables > cls.MAX_NUM_CONDITIONS:
|
||||
raise exception_class(
|
||||
f"Maximum of {cls.MAX_NUM_CONDITIONS} conditions are allowed"
|
||||
)
|
||||
|
||||
class Schema(AccessControlCondition.Schema):
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.SEQUENTIAL.value), required=True
|
||||
)
|
||||
condition_variables = fields.List(
|
||||
fields.Nested(ConditionVariable.Schema(), required=True)
|
||||
)
|
||||
|
||||
# maintain field declaration ordering
|
||||
class Meta:
|
||||
ordered = True
|
||||
|
||||
@validates_schema
|
||||
def validate_condition_variables(self, data, **kwargs):
|
||||
condition_variables = data["condition_variables"]
|
||||
SequentialAccessControlCondition._validate_condition_variables(
|
||||
condition_variables, ValidationError
|
||||
)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return SequentialAccessControlCondition(**data)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
condition_variables: List[ConditionVariable],
|
||||
condition_type: str = CONDITION_TYPE,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
self._validate_condition_variables(
|
||||
condition_variables=condition_variables, exception_class=InvalidCondition
|
||||
)
|
||||
self.condition_variables = condition_variables
|
||||
super().__init__(condition_type=condition_type, name=name)
|
||||
|
||||
def __repr__(self):
|
||||
r = f"{self.__class__.__name__}(num_condition_variables={len(self.condition_variables)})"
|
||||
return r
|
||||
|
||||
# TODO - think about not dereferencing context but using a dict;
|
||||
# may allows more freedom for params
|
||||
def verify(
|
||||
self, providers: Dict[int, Set[HTTPProvider]], **context
|
||||
) -> Tuple[bool, Any]:
|
||||
values = []
|
||||
latest_success = False
|
||||
inner_context = dict(context) # don't modify passed in context - use a copy
|
||||
# resolve variables
|
||||
for condition_variable in self.condition_variables:
|
||||
latest_success, result = condition_variable.condition.verify(
|
||||
providers=providers, **inner_context
|
||||
)
|
||||
values.append(result)
|
||||
if not latest_success:
|
||||
# short circuit due to failed condition
|
||||
break
|
||||
|
||||
inner_context[f":{condition_variable.var_name}"] = result
|
||||
|
||||
return latest_success, values
|
||||
|
||||
@property
|
||||
def conditions(self):
|
||||
return [
|
||||
condition_variable.condition
|
||||
for condition_variable in self.condition_variables
|
||||
]
|
||||
|
||||
|
||||
class ReturnValueTest:
|
||||
class InvalidExpression(ValueError):
|
||||
pass
|
||||
|
@ -357,16 +491,6 @@ class ConditionLingo(_Serializable):
|
|||
"""
|
||||
|
||||
def __init__(self, condition: AccessControlCondition, version: str = VERSION):
|
||||
"""
|
||||
CONDITION = BASE_CONDITION | COMPOUND_CONDITION
|
||||
BASE_CONDITION = {
|
||||
// ..
|
||||
}
|
||||
COMPOUND_CONDITION = {
|
||||
"operator": OPERATOR,
|
||||
"operands": [CONDITION*]
|
||||
}
|
||||
"""
|
||||
self.condition = condition
|
||||
self.check_version_compatibility(version)
|
||||
self.version = version
|
||||
|
@ -412,8 +536,7 @@ class ConditionLingo(_Serializable):
|
|||
cls, condition: ConditionDict, version: int = None
|
||||
) -> Type[AccessControlCondition]:
|
||||
"""
|
||||
TODO: This feels like a jenky way to resolve data types from JSON blobs, but it works.
|
||||
Inspects a given bloc of JSON and attempts to resolve it's intended datatype within the
|
||||
Inspects a given block of JSON and attempts to resolve it's intended datatype within the
|
||||
conditions expression framework.
|
||||
"""
|
||||
from nucypher.policy.conditions.evm import ContractCondition, RPCCondition
|
||||
|
@ -429,12 +552,13 @@ class ConditionLingo(_Serializable):
|
|||
RPCCondition,
|
||||
CompoundAccessControlCondition,
|
||||
JsonApiCondition,
|
||||
SequentialAccessControlCondition,
|
||||
):
|
||||
if condition.CONDITION_TYPE == condition_type:
|
||||
return condition
|
||||
|
||||
raise InvalidConditionLingo(
|
||||
f"Cannot resolve condition lingo with condition type {condition_type}"
|
||||
f"Cannot resolve condition lingo, {condition}, with condition type {condition_type}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -443,3 +567,50 @@ class ConditionLingo(_Serializable):
|
|||
raise InvalidConditionLingo(
|
||||
f"Version provided, {version}, is incompatible with current version {cls.VERSION}"
|
||||
)
|
||||
|
||||
|
||||
class ExecutionCallAccessControlCondition(AccessControlCondition):
|
||||
"""
|
||||
Conditions that utilize underlying ExecutionCall objects.
|
||||
"""
|
||||
|
||||
class Schema(AccessControlCondition.Schema):
|
||||
return_value_test = fields.Nested(
|
||||
ReturnValueTest.ReturnValueTestSchema(), required=True
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
condition_type: str,
|
||||
return_value_test: ReturnValueTest,
|
||||
name: Optional[str] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self.return_value_test = return_value_test
|
||||
try:
|
||||
self.execution_call = self._create_execution_call(*args, **kwargs)
|
||||
except ValueError as e:
|
||||
raise InvalidCondition(str(e))
|
||||
|
||||
super().__init__(condition_type=condition_type, name=name)
|
||||
|
||||
@abstractmethod
|
||||
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
|
||||
"""
|
||||
Returns the execution call that the condition executes.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def verify(self, *args, **kwargs) -> Tuple[bool, Any]:
|
||||
"""
|
||||
Verifies the condition is met by performing execution call and
|
||||
evaluating the return value test.
|
||||
"""
|
||||
result = self.execution_call.execute(*args, **kwargs)
|
||||
|
||||
resolved_return_value_test = self.return_value_test.with_resolved_context(
|
||||
**kwargs
|
||||
)
|
||||
eval_result = resolved_return_value_test.eval(result) # test
|
||||
return eval_result, result
|
||||
|
|
|
@ -6,13 +6,15 @@ from jsonpath_ng.ext import parse
|
|||
from marshmallow import fields, post_load, validate
|
||||
from marshmallow.fields import Field, Url
|
||||
|
||||
from nucypher.policy.conditions.base import AccessControlCondition
|
||||
from nucypher.policy.conditions.base import ExecutionCall
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
ConditionEvaluationFailed,
|
||||
InvalidCondition,
|
||||
)
|
||||
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
|
||||
from nucypher.policy.conditions.utils import CamelCaseSchema
|
||||
from nucypher.policy.conditions.lingo import (
|
||||
ConditionType,
|
||||
ExecutionCallAccessControlCondition,
|
||||
)
|
||||
from nucypher.utilities.logging import Logger
|
||||
|
||||
|
||||
|
@ -32,58 +34,29 @@ class JSONPathField(Field):
|
|||
return value
|
||||
|
||||
|
||||
class JsonApiCondition(AccessControlCondition):
|
||||
"""
|
||||
A JSON API condition is a condition that can be evaluated by reading from a JSON
|
||||
HTTPS endpoint. The response must return an HTTP 200 with valid JSON in the response body.
|
||||
The response will be deserialized as JSON and parsed using jsonpath.
|
||||
"""
|
||||
|
||||
CONDITION_TYPE = ConditionType.JSONAPI.value
|
||||
LOGGER = Logger("nucypher.policy.conditions.JsonApiCondition")
|
||||
class JsonApiCall(ExecutionCall):
|
||||
TIMEOUT = 5 # seconds
|
||||
|
||||
class Schema(CamelCaseSchema):
|
||||
|
||||
name = fields.Str(required=False)
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.JSONAPI.value), required=True
|
||||
)
|
||||
parameters = fields.Dict(required=False, allow_none=True)
|
||||
endpoint = Url(required=True, relative=False, schemes=["https"])
|
||||
query = JSONPathField(required=False, allow_none=True)
|
||||
return_value_test = fields.Nested(
|
||||
ReturnValueTest.ReturnValueTestSchema(), required=True
|
||||
)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return JsonApiCondition(**data)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
return_value_test: ReturnValueTest,
|
||||
query: Optional[str] = None,
|
||||
parameters: Optional[dict] = None,
|
||||
condition_type: str = ConditionType.JSONAPI.value,
|
||||
query: Optional[str] = None,
|
||||
):
|
||||
if condition_type != self.CONDITION_TYPE:
|
||||
raise InvalidCondition(
|
||||
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
|
||||
)
|
||||
|
||||
self.condition_type = condition_type
|
||||
self.endpoint = endpoint
|
||||
self.parameters = parameters or {}
|
||||
self.query = query
|
||||
self.return_value_test = return_value_test
|
||||
|
||||
self.timeout = self.TIMEOUT
|
||||
self.logger = self.LOGGER
|
||||
self.logger = Logger(__name__)
|
||||
|
||||
super().__init__()
|
||||
def execute(self, *args, **kwargs) -> Any:
|
||||
response = self._fetch()
|
||||
data = self._deserialize_response(response)
|
||||
result = self._query_response(data)
|
||||
return result
|
||||
|
||||
def fetch(self) -> requests.Response:
|
||||
def _fetch(self) -> requests.Response:
|
||||
"""Fetches data from the endpoint."""
|
||||
try:
|
||||
response = requests.get(
|
||||
|
@ -111,7 +84,7 @@ class JsonApiCondition(AccessControlCondition):
|
|||
|
||||
return response
|
||||
|
||||
def deserialize_response(self, response: requests.Response) -> Any:
|
||||
def _deserialize_response(self, response: requests.Response) -> Any:
|
||||
"""Deserializes the JSON response from the endpoint."""
|
||||
try:
|
||||
data = response.json()
|
||||
|
@ -122,7 +95,7 @@ class JsonApiCondition(AccessControlCondition):
|
|||
)
|
||||
return data
|
||||
|
||||
def query_response(self, data: Any) -> Any:
|
||||
def _query_response(self, data: Any) -> Any:
|
||||
|
||||
if not self.query:
|
||||
return data # primitive value
|
||||
|
@ -148,6 +121,55 @@ class JsonApiCondition(AccessControlCondition):
|
|||
|
||||
return result
|
||||
|
||||
|
||||
class JsonApiCondition(ExecutionCallAccessControlCondition):
|
||||
"""
|
||||
A JSON API condition is a condition that can be evaluated by reading from a JSON
|
||||
HTTPS endpoint. The response must return an HTTP 200 with valid JSON in the response body.
|
||||
The response will be deserialized as JSON and parsed using jsonpath.
|
||||
"""
|
||||
|
||||
CONDITION_TYPE = ConditionType.JSONAPI.value
|
||||
|
||||
class Schema(ExecutionCallAccessControlCondition.Schema):
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.JSONAPI.value), required=True
|
||||
)
|
||||
endpoint = Url(required=True, relative=False, schemes=["https"])
|
||||
parameters = fields.Dict(required=False, allow_none=True)
|
||||
query = JSONPathField(required=False, allow_none=True)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return JsonApiCondition(**data)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
condition_type: str = ConditionType.JSONAPI.value,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(condition_type=condition_type, *args, **kwargs)
|
||||
|
||||
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
|
||||
return JsonApiCall(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def endpoint(self):
|
||||
return self.execution_call.endpoint
|
||||
|
||||
@property
|
||||
def query(self):
|
||||
return self.execution_call.query
|
||||
|
||||
@property
|
||||
def parameters(self):
|
||||
return self.execution_call.parameters
|
||||
|
||||
@property
|
||||
def timeout(self):
|
||||
return self.execution_call.timeout
|
||||
|
||||
@staticmethod
|
||||
def _process_result_for_eval(result: Any):
|
||||
# strings that are not already quoted will cause a problem for literal_eval
|
||||
|
@ -170,14 +192,11 @@ class JsonApiCondition(AccessControlCondition):
|
|||
and evaluating the return value test with the result. Parses the endpoint's JSON response using
|
||||
JSONPath.
|
||||
"""
|
||||
response = self.fetch()
|
||||
data = self.deserialize_response(response)
|
||||
result = self.query_response(data)
|
||||
result = self.execution_call.execute(**context)
|
||||
result_for_eval = self._process_result_for_eval(result)
|
||||
|
||||
resolved_return_value_test = self.return_value_test.with_resolved_context(
|
||||
**context
|
||||
)
|
||||
|
||||
result_for_eval = self._process_result_for_eval(result)
|
||||
eval_result = resolved_return_value_test.eval(result_for_eval) # test
|
||||
return eval_result, result
|
||||
|
|
|
@ -1,33 +1,53 @@
|
|||
from typing import Any, List, Optional
|
||||
|
||||
from marshmallow import fields, post_load, validate
|
||||
from marshmallow.validate import Equal, OneOf
|
||||
from marshmallow.validate import Equal
|
||||
from web3 import Web3
|
||||
|
||||
from nucypher.policy.conditions.evm import _CONDITION_CHAINS, RPCCondition
|
||||
from nucypher.policy.conditions.base import ExecutionCall
|
||||
from nucypher.policy.conditions.evm import RPCCall, RPCCondition
|
||||
from nucypher.policy.conditions.exceptions import InvalidCondition
|
||||
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
|
||||
from nucypher.policy.conditions.utils import CamelCaseSchema
|
||||
from nucypher.policy.conditions.lingo import ConditionType
|
||||
|
||||
|
||||
class TimeRPCCall(RPCCall):
|
||||
METHOD = "blocktime"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chain: int,
|
||||
method: str = METHOD,
|
||||
parameters: Optional[List[Any]] = None,
|
||||
):
|
||||
if parameters:
|
||||
raise ValueError(f"{self.METHOD} does not take any parameters")
|
||||
|
||||
super().__init__(chain=chain, method=method, parameters=parameters)
|
||||
|
||||
def _validate_method(self, method):
|
||||
if method != self.METHOD:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} must be instantiated with the {self.METHOD} method."
|
||||
)
|
||||
return method
|
||||
|
||||
def _execute(self, w3: Web3, resolved_parameters: List[Any]) -> Any:
|
||||
"""Execute onchain read and return result."""
|
||||
# TODO may need to rethink as part of #3051 (multicall work).
|
||||
latest_block = w3.eth.get_block("latest")
|
||||
return latest_block.timestamp
|
||||
|
||||
|
||||
class TimeCondition(RPCCondition):
|
||||
METHOD = "blocktime"
|
||||
CONDITION_TYPE = ConditionType.TIME.value
|
||||
|
||||
class Schema(CamelCaseSchema):
|
||||
SKIP_VALUES = (None,)
|
||||
class Schema(RPCCondition.Schema):
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.TIME.value), required=True
|
||||
)
|
||||
name = fields.Str(required=False)
|
||||
chain = fields.Int(
|
||||
required=True, strict=True, validate=OneOf(_CONDITION_CHAINS)
|
||||
)
|
||||
method = fields.Str(
|
||||
dump_default="blocktime", required=True, validate=Equal("blocktime")
|
||||
)
|
||||
return_value_test = fields.Nested(
|
||||
ReturnValueTest.ReturnValueTestSchema(), required=True
|
||||
)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
|
@ -39,28 +59,21 @@ class TimeCondition(RPCCondition):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
return_value_test: ReturnValueTest,
|
||||
chain: int,
|
||||
method: str = METHOD,
|
||||
method: str = TimeRPCCall.METHOD,
|
||||
condition_type: str = CONDITION_TYPE,
|
||||
name: Optional[str] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if method != self.METHOD:
|
||||
raise InvalidCondition(
|
||||
f"{self.__class__.__name__} must be instantiated with the {self.METHOD} method."
|
||||
)
|
||||
|
||||
# call to super must be at the end for proper validation
|
||||
super().__init__(
|
||||
chain=chain,
|
||||
method=method,
|
||||
return_value_test=return_value_test,
|
||||
name=name,
|
||||
condition_type=condition_type,
|
||||
method=method,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _validate_method(self, method):
|
||||
return method
|
||||
def _create_execution_call(self, *args, **kwargs) -> ExecutionCall:
|
||||
return TimeRPCCall(*args, **kwargs)
|
||||
|
||||
def _validate_expected_return_type(self):
|
||||
comparator_value = self.return_value_test.value
|
||||
|
@ -72,9 +85,3 @@ class TimeCondition(RPCCondition):
|
|||
@property
|
||||
def timestamp(self):
|
||||
return self.return_value_test.value
|
||||
|
||||
def _execute_call(self, parameters: List[Any]) -> Any:
|
||||
"""Execute onchain read and return result."""
|
||||
# TODO may need to rethink as part of #3051 (multicall work).
|
||||
latest_block = self.w3.eth.get_block("latest")
|
||||
return latest_block.timestamp
|
||||
|
|
|
@ -34,16 +34,20 @@ class ReturnValueTestDict(TypedDict):
|
|||
key: NotRequired[Union[str, int]]
|
||||
|
||||
|
||||
# Conditions
|
||||
class _AccessControlCondition(TypedDict):
|
||||
name: NotRequired[str]
|
||||
|
||||
|
||||
class RPCConditionDict(_AccessControlCondition):
|
||||
conditionType: str
|
||||
|
||||
|
||||
class BaseExecConditionDict(_AccessControlCondition):
|
||||
returnValueTest: ReturnValueTestDict
|
||||
|
||||
|
||||
class RPCConditionDict(BaseExecConditionDict):
|
||||
chain: int
|
||||
method: str
|
||||
parameters: NotRequired[List[Any]]
|
||||
returnValueTest: ReturnValueTestDict
|
||||
|
||||
|
||||
class TimeConditionDict(RPCConditionDict):
|
||||
|
@ -56,17 +60,43 @@ class ContractConditionDict(RPCConditionDict):
|
|||
functionAbi: NotRequired[ABIFunction]
|
||||
|
||||
|
||||
class JsonApiConditionDict(BaseExecConditionDict):
|
||||
endpoint: str
|
||||
query: NotRequired[str]
|
||||
parameters: NotRequired[Dict]
|
||||
|
||||
#
|
||||
# CompoundCondition represents:
|
||||
# {
|
||||
# "operator": ["and" | "or"]
|
||||
# "operands": List[AccessControlCondition | CompoundCondition]
|
||||
# "operator": ["and" | "or" | "not"]
|
||||
# "operands": List[AccessControlCondition]
|
||||
# }
|
||||
#
|
||||
class CompoundConditionDict(_AccessControlCondition):
|
||||
operator: Literal["and", "or", "not"]
|
||||
operands: List["ConditionDict"]
|
||||
|
||||
|
||||
#
|
||||
class CompoundConditionDict(TypedDict):
|
||||
conditionType: str
|
||||
operator: Literal["and", "or"]
|
||||
operands: List["Lingo"]
|
||||
# ConditionVariable represents:
|
||||
# {
|
||||
# varName: str
|
||||
# condition: AccessControlCondition
|
||||
# }
|
||||
#
|
||||
class ConditionVariableDict(TypedDict):
|
||||
varName: str
|
||||
condition: "ConditionDict"
|
||||
|
||||
|
||||
#
|
||||
# SequentialCondition represents:
|
||||
# {
|
||||
# "conditionVariables": List[ConditionVariable]
|
||||
# }
|
||||
#
|
||||
class SequentialConditionDict(_AccessControlCondition):
|
||||
conditionVariables = List[ConditionVariableDict]
|
||||
|
||||
|
||||
#
|
||||
|
@ -75,8 +105,15 @@ class CompoundConditionDict(TypedDict):
|
|||
# - RPCCondition
|
||||
# - ContractCondition
|
||||
# - CompoundConditionDict
|
||||
# - JsonApiConditionDict
|
||||
# - SequentialConditionDict
|
||||
ConditionDict = Union[
|
||||
TimeConditionDict, RPCConditionDict, ContractConditionDict, CompoundConditionDict
|
||||
TimeConditionDict,
|
||||
RPCConditionDict,
|
||||
ContractConditionDict,
|
||||
CompoundConditionDict,
|
||||
JsonApiConditionDict,
|
||||
SequentialConditionDict,
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -155,7 +155,7 @@ def _align_comparator_value_with_abi(
|
|||
)
|
||||
|
||||
|
||||
def _validate_condition_function_abi(function_abi: Dict, method_name: str) -> None:
|
||||
def _validate_function_abi(function_abi: Dict, method_name: str) -> None:
|
||||
"""validates a dictionary as valid for use as a condition function ABI"""
|
||||
abi = ABIFunction(function_abi)
|
||||
|
||||
|
@ -171,7 +171,7 @@ def _validate_condition_function_abi(function_abi: Dict, method_name: str) -> No
|
|||
raise ValueError(f"Invalid ABI stateMutability {abi}")
|
||||
|
||||
|
||||
def _validate_condition_abi(
|
||||
def _validate_contract_call_abi(
|
||||
standard_contract_type: str,
|
||||
function_abi: Dict,
|
||||
method_name: str,
|
||||
|
@ -181,4 +181,4 @@ def _validate_condition_abi(
|
|||
f"Provide 'standardContractType' or 'functionAbi'; got ({standard_contract_type}, {function_abi})."
|
||||
)
|
||||
if function_abi:
|
||||
_validate_condition_function_abi(function_abi, method_name=method_name)
|
||||
_validate_function_abi(function_abi, method_name=method_name)
|
||||
|
|
|
@ -15,9 +15,11 @@ from nucypher.characters.lawful import Enrico, Ursula
|
|||
from nucypher.policy.conditions.evm import ContractCondition, RPCCondition
|
||||
from nucypher.policy.conditions.lingo import (
|
||||
ConditionLingo,
|
||||
ConditionVariable,
|
||||
NotCompoundCondition,
|
||||
OrCompoundCondition,
|
||||
ReturnValueTest,
|
||||
SequentialAccessControlCondition,
|
||||
)
|
||||
from nucypher.policy.conditions.time import TimeCondition
|
||||
from tests.constants import TEST_ETH_PROVIDER_URI, TESTERCHAIN_CHAIN_ID
|
||||
|
@ -93,7 +95,14 @@ def condition(test_registry):
|
|||
)
|
||||
|
||||
not_not_condition = NotCompoundCondition(
|
||||
operand=NotCompoundCondition(operand=and_condition)
|
||||
operand=NotCompoundCondition(operand=rpc_condition)
|
||||
)
|
||||
|
||||
sequential_condition = SequentialAccessControlCondition(
|
||||
condition_variables=[
|
||||
ConditionVariable("rpc", rpc_condition),
|
||||
ConditionVariable("contract", contract_condition),
|
||||
]
|
||||
)
|
||||
|
||||
conditions = [
|
||||
|
@ -103,6 +112,7 @@ def condition(test_registry):
|
|||
or_condition,
|
||||
and_condition,
|
||||
not_not_condition,
|
||||
sequential_condition,
|
||||
]
|
||||
|
||||
condition_to_use = random.choice(conditions)
|
||||
|
|
|
@ -8,7 +8,6 @@ from nucypher.blockchain.eth.agents import (
|
|||
from nucypher.policy.conditions.context import USER_ADDRESS_CONTEXT
|
||||
from nucypher.policy.conditions.evm import ContractCondition
|
||||
from nucypher.policy.conditions.lingo import (
|
||||
AndCompoundCondition,
|
||||
ConditionLingo,
|
||||
OrCompoundCondition,
|
||||
ReturnValueTest,
|
||||
|
@ -21,7 +20,6 @@ def condition_providers(testerchain):
|
|||
providers = {testerchain.client.chain_id: {testerchain.provider}}
|
||||
return providers
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def compound_lingo(
|
||||
erc721_evm_condition_balanceof,
|
||||
|
@ -35,12 +33,8 @@ def compound_lingo(
|
|||
operands=[
|
||||
erc721_evm_condition_balanceof,
|
||||
time_condition,
|
||||
AndCompoundCondition(
|
||||
operands=[
|
||||
rpc_condition,
|
||||
erc20_evm_condition_balanceof,
|
||||
]
|
||||
),
|
||||
rpc_condition,
|
||||
erc20_evm_condition_balanceof,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
|
|
@ -22,7 +22,6 @@ from nucypher.policy.conditions.evm import (
|
|||
RPCCondition,
|
||||
)
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
InvalidCondition,
|
||||
NoConnectionToChain,
|
||||
RequiredContextVariable,
|
||||
RPCExecutionFailed,
|
||||
|
@ -84,10 +83,10 @@ def test_rpc_condition_evaluation_invalid_provider_for_chain(
|
|||
):
|
||||
context = {USER_ADDRESS_CONTEXT: {"address": accounts.unassigned_accounts[0]}}
|
||||
new_chain = 23
|
||||
rpc_condition.chain = new_chain
|
||||
rpc_condition.execution_call.chain = new_chain
|
||||
condition_providers = {new_chain: {testerchain.provider}}
|
||||
with pytest.raises(
|
||||
InvalidCondition, match=f"can only be evaluated on chain ID {new_chain}"
|
||||
NoConnectionToChain, match=f"can only be evaluated on chain ID {new_chain}"
|
||||
):
|
||||
_ = rpc_condition.verify(providers=condition_providers, **context)
|
||||
|
||||
|
@ -156,10 +155,8 @@ def test_rpc_condition_evaluation_multiple_providers_no_valid_fallback(
|
|||
}
|
||||
|
||||
mocker.patch.object(
|
||||
rpc_condition, "_check_chain_id", return_value=None
|
||||
) # skip chain check
|
||||
mocker.patch.object(rpc_condition, "_configure_w3", my_configure_w3)
|
||||
|
||||
rpc_condition.execution_call, "_configure_provider", my_configure_w3
|
||||
)
|
||||
with pytest.raises(RPCExecutionFailed):
|
||||
_ = rpc_condition.verify(providers=condition_providers, **context)
|
||||
|
||||
|
@ -186,9 +183,8 @@ def test_rpc_condition_evaluation_multiple_providers_valid_fallback(
|
|||
}
|
||||
|
||||
mocker.patch.object(
|
||||
rpc_condition, "_check_chain_id", return_value=None
|
||||
) # skip chain check
|
||||
mocker.patch.object(rpc_condition, "_configure_w3", my_configure_w3)
|
||||
rpc_condition.execution_call, "_configure_provider", my_configure_w3
|
||||
)
|
||||
|
||||
condition_result, call_result = rpc_condition.verify(
|
||||
providers=condition_providers, **context
|
||||
|
|
|
@ -1,9 +1,16 @@
|
|||
from collections import defaultdict
|
||||
|
||||
import pytest
|
||||
from web3 import Web3
|
||||
|
||||
from nucypher.policy.conditions.evm import RPCCondition
|
||||
from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
|
||||
from nucypher.policy.conditions.evm import RPCCall, RPCCondition
|
||||
from nucypher.policy.conditions.lingo import (
|
||||
CompoundAccessControlCondition,
|
||||
ConditionLingo,
|
||||
ConditionType,
|
||||
ReturnValueTest,
|
||||
)
|
||||
from nucypher.policy.conditions.time import TimeCondition, TimeRPCCall
|
||||
from nucypher.utilities.logging import GlobalLoggerSettings
|
||||
from tests.utils.policy import make_message_kits
|
||||
|
||||
|
@ -14,22 +21,28 @@ def make_multichain_evm_conditions(bob, chain_ids):
|
|||
"""This is a helper function to make a set of conditions that are valid on multiple chains."""
|
||||
operands = list()
|
||||
for chain_id in chain_ids:
|
||||
operand = [
|
||||
{
|
||||
"conditionType": ConditionType.TIME.value,
|
||||
"returnValueTest": {"value": 0, "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": chain_id,
|
||||
},
|
||||
{
|
||||
"conditionType": ConditionType.RPC.value,
|
||||
"chain": chain_id,
|
||||
"method": "eth_getBalance",
|
||||
"parameters": [bob.checksum_address, "latest"],
|
||||
"returnValueTest": {"comparator": ">=", "value": 10000000000000},
|
||||
},
|
||||
]
|
||||
operands.extend(operand)
|
||||
compound_and_condition = CompoundAccessControlCondition(
|
||||
operator="and",
|
||||
operands=[
|
||||
TimeCondition(
|
||||
chain=chain_id,
|
||||
return_value_test=ReturnValueTest(
|
||||
comparator=">",
|
||||
value=0,
|
||||
),
|
||||
),
|
||||
RPCCondition(
|
||||
chain=chain_id,
|
||||
method="eth_getBalance",
|
||||
parameters=[bob.checksum_address, "latest"],
|
||||
return_value_test=ReturnValueTest(
|
||||
comparator=">=",
|
||||
value=10000000000000,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
operands.append(compound_and_condition.to_dict())
|
||||
|
||||
_conditions = {
|
||||
"version": ConditionLingo.VERSION,
|
||||
|
@ -69,7 +82,7 @@ def test_single_retrieve_with_multichain_conditions(
|
|||
|
||||
|
||||
def test_single_decryption_request_with_faulty_rpc_endpoint(
|
||||
enacted_policy, bob, multichain_ursulas, conditions, mock_rpc_condition
|
||||
monkeymodule, testerchain, enacted_policy, bob, multichain_ursulas, conditions
|
||||
):
|
||||
bob.remember_node(multichain_ursulas[0])
|
||||
bob.start_learning_loop()
|
||||
|
@ -80,30 +93,43 @@ def test_single_decryption_request_with_faulty_rpc_endpoint(
|
|||
alice_verifying_key=enacted_policy.publisher_verifying_key,
|
||||
)
|
||||
|
||||
def _mock_configure_provider(*args, **kwargs):
|
||||
rpc_call_type = args[0]
|
||||
if isinstance(rpc_call_type, TimeRPCCall):
|
||||
# time condition call - only RPCCall is made faulty
|
||||
return testerchain.w3
|
||||
|
||||
# rpc condition call
|
||||
provider = args[1]
|
||||
w3 = Web3(provider)
|
||||
return w3
|
||||
|
||||
monkeymodule.setattr(RPCCall, "_configure_provider", _mock_configure_provider)
|
||||
|
||||
calls = defaultdict(int)
|
||||
original_execute_call = RPCCondition._execute_call
|
||||
original_execute_call = RPCCall._execute
|
||||
|
||||
def faulty_execute_call(*args, **kwargs):
|
||||
"""Intercept the call to the RPC endpoint and raise an exception on the second call."""
|
||||
nonlocal calls
|
||||
rpc_call = args[0]
|
||||
calls[rpc_call.chain] += 1
|
||||
if (
|
||||
calls[rpc_call.chain] == 2
|
||||
and "tester://multichain.0" in rpc_call.provider.endpoint_uri
|
||||
):
|
||||
rpc_call_object = args[0]
|
||||
resolved_parameters = args[2]
|
||||
calls[rpc_call_object.chain] += 1
|
||||
if calls[rpc_call_object.chain] % 2 == 0:
|
||||
# simulate a network error
|
||||
raise ConnectionError("Something went wrong with the network")
|
||||
elif calls[rpc_call.chain] == 3:
|
||||
# check the provider is the fallback
|
||||
this_uri = rpc_call.provider.endpoint_uri
|
||||
assert "fallback" in this_uri
|
||||
return original_execute_call(*args, **kwargs)
|
||||
|
||||
RPCCondition._execute_call = faulty_execute_call
|
||||
# replace w3 object with fake provider, with proper w3 object for actual execution
|
||||
return original_execute_call(
|
||||
rpc_call_object, testerchain.w3, resolved_parameters
|
||||
)
|
||||
|
||||
RPCCall._execute = faulty_execute_call
|
||||
|
||||
cleartexts = bob.retrieve_and_decrypt(
|
||||
message_kits=message_kits,
|
||||
**policy_info_kwargs,
|
||||
)
|
||||
assert cleartexts == messages
|
||||
RPCCondition._execute_call = original_execute_call
|
||||
|
||||
RPCCall._execute = original_execute_call
|
||||
|
|
|
@ -14,7 +14,7 @@ from nucypher.blockchain.eth.agents import (
|
|||
)
|
||||
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
|
||||
from nucypher.blockchain.eth.registry import ContractRegistry, RegistrySourceManager
|
||||
from nucypher.policy.conditions.evm import RPCCondition
|
||||
from nucypher.policy.conditions.evm import RPCCall
|
||||
from nucypher.utilities.logging import Logger
|
||||
from tests.constants import (
|
||||
BONUS_TOKENS_FOR_TESTS,
|
||||
|
@ -430,16 +430,11 @@ def taco_child_application_agent(testerchain, test_registry):
|
|||
#
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mock_rpc_condition(module_mocker, testerchain, monkeymodule):
|
||||
def configure_mock(condition, provider, *args, **kwargs):
|
||||
condition.provider = provider
|
||||
def mock_rpc_condition(testerchain, monkeymodule):
|
||||
def configure_mock(*args, **kwargs):
|
||||
return testerchain.w3
|
||||
|
||||
monkeymodule.setattr(RPCCondition, "_configure_w3", configure_mock)
|
||||
configure_spy = module_mocker.spy(RPCCondition, "_configure_w3")
|
||||
|
||||
chain_id_check_mock = module_mocker.patch.object(RPCCondition, "_check_chain_id")
|
||||
return configure_spy, chain_id_check_mock
|
||||
monkeymodule.setattr(RPCCall, "_configure_provider", configure_mock)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
|
|
@ -2,6 +2,7 @@ import json
|
|||
|
||||
import pytest
|
||||
|
||||
from nucypher.policy.conditions.base import AccessControlCondition
|
||||
from nucypher.policy.conditions.context import USER_ADDRESS_CONTEXT
|
||||
from nucypher.policy.conditions.evm import ContractCondition
|
||||
from nucypher.policy.conditions.lingo import (
|
||||
|
@ -94,3 +95,8 @@ def erc721_evm_condition(test_registry):
|
|||
]
|
||||
)
|
||||
return condition
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def mock_skip_schema_validation(mocker):
|
||||
mocker.patch.object(AccessControlCondition.Schema, "validate", return_value=None)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import random
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
@ -8,8 +9,10 @@ from nucypher.policy.conditions.lingo import (
|
|||
AndCompoundCondition,
|
||||
CompoundAccessControlCondition,
|
||||
ConditionType,
|
||||
ConditionVariable,
|
||||
NotCompoundCondition,
|
||||
OrCompoundCondition,
|
||||
SequentialAccessControlCondition,
|
||||
)
|
||||
|
||||
|
||||
|
@ -63,7 +66,10 @@ def test_invalid_compound_condition(time_condition, rpc_condition):
|
|||
|
||||
# no operands
|
||||
with pytest.raises(InvalidCondition):
|
||||
_ = CompoundAccessControlCondition(operator=operator, operands=[])
|
||||
_ = CompoundAccessControlCondition(
|
||||
operator=random.choice(CompoundAccessControlCondition.OPERATORS),
|
||||
operands=[],
|
||||
)
|
||||
|
||||
# > 1 operand for not operator
|
||||
with pytest.raises(InvalidCondition):
|
||||
|
@ -86,6 +92,21 @@ def test_invalid_compound_condition(time_condition, rpc_condition):
|
|||
operands=[rpc_condition],
|
||||
)
|
||||
|
||||
# exceeds max operands
|
||||
operands = list()
|
||||
for i in range(CompoundAccessControlCondition.MAX_NUM_CONDITIONS + 1):
|
||||
operands.append(rpc_condition)
|
||||
with pytest.raises(InvalidCondition):
|
||||
_ = CompoundAccessControlCondition(
|
||||
operator=CompoundAccessControlCondition.OR_OPERATOR,
|
||||
operands=operands,
|
||||
)
|
||||
with pytest.raises(InvalidCondition):
|
||||
_ = CompoundAccessControlCondition(
|
||||
operator=CompoundAccessControlCondition.AND_OPERATOR,
|
||||
operands=operands,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("operator", CompoundAccessControlCondition.OPERATORS)
|
||||
def test_compound_condition_schema_validation(operator, time_condition, rpc_condition):
|
||||
|
@ -131,7 +152,8 @@ def test_compound_condition_schema_validation(operator, time_condition, rpc_cond
|
|||
CompoundAccessControlCondition.validate(compound_condition_dict)
|
||||
|
||||
|
||||
def test_and_condition_and_short_circuit(mock_conditions):
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_and_condition_and_short_circuit(mocker, mock_conditions):
|
||||
condition_1, condition_2, condition_3, condition_4 = mock_conditions
|
||||
|
||||
and_condition = AndCompoundCondition(
|
||||
|
@ -144,14 +166,14 @@ def test_and_condition_and_short_circuit(mock_conditions):
|
|||
)
|
||||
|
||||
# ensure that all conditions evaluated when all return True
|
||||
result, value = and_condition.verify()
|
||||
result, value = and_condition.verify(providers={})
|
||||
assert result is True
|
||||
assert len(value) == 4, "all conditions evaluated"
|
||||
assert value == [1, 2, 3, 4]
|
||||
|
||||
# ensure that short circuit happens when 1st condition is false
|
||||
condition_1.verify.return_value = (False, 1)
|
||||
result, value = and_condition.verify()
|
||||
result, value = and_condition.verify(providers={})
|
||||
assert result is False
|
||||
assert len(value) == 1, "only one condition evaluated"
|
||||
assert value == [1]
|
||||
|
@ -159,12 +181,13 @@ def test_and_condition_and_short_circuit(mock_conditions):
|
|||
# short circuit occurs for 3rd entry
|
||||
condition_1.verify.return_value = (True, 1)
|
||||
condition_3.verify.return_value = (False, 3)
|
||||
result, value = and_condition.verify()
|
||||
result, value = and_condition.verify(providers={})
|
||||
assert result is False
|
||||
assert len(value) == 3, "3-of-4 conditions evaluated"
|
||||
assert value == [1, 2, 3]
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_or_condition_and_short_circuit(mock_conditions):
|
||||
condition_1, condition_2, condition_3, condition_4 = mock_conditions
|
||||
|
||||
|
@ -179,7 +202,7 @@ def test_or_condition_and_short_circuit(mock_conditions):
|
|||
|
||||
# ensure that only first condition evaluated when first is True
|
||||
condition_1.verify.return_value = (True, 1) # short circuit here
|
||||
result, value = or_condition.verify()
|
||||
result, value = or_condition.verify(providers={})
|
||||
assert result is True
|
||||
assert len(value) == 1, "only first condition needs to be evaluated"
|
||||
assert value == [1]
|
||||
|
@ -189,7 +212,7 @@ def test_or_condition_and_short_circuit(mock_conditions):
|
|||
condition_2.verify.return_value = (False, 2)
|
||||
condition_3.verify.return_value = (True, 3) # short circuit here
|
||||
|
||||
result, value = or_condition.verify()
|
||||
result, value = or_condition.verify(providers={})
|
||||
assert result is True
|
||||
assert len(value) == 3, "third condition causes short circuit"
|
||||
assert value == [1, 2, 3]
|
||||
|
@ -200,12 +223,13 @@ def test_or_condition_and_short_circuit(mock_conditions):
|
|||
condition_3.verify.return_value = (False, 3)
|
||||
condition_4.verify.return_value = (False, 4)
|
||||
|
||||
result, value = or_condition.verify()
|
||||
result, value = or_condition.verify(providers={})
|
||||
assert result is False
|
||||
assert len(value) == 4, "all conditions evaluated"
|
||||
assert value == [1, 2, 3, 4]
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_compound_condition(mock_conditions):
|
||||
condition_1, condition_2, condition_3, condition_4 = mock_conditions
|
||||
|
||||
|
@ -223,7 +247,7 @@ def test_compound_condition(mock_conditions):
|
|||
)
|
||||
|
||||
# all conditions are True
|
||||
result, value = compound_condition.verify()
|
||||
result, value = compound_condition.verify(providers={})
|
||||
assert result is True
|
||||
assert len(value) == 2, "or_condition and condition_4"
|
||||
assert value == [[1], 4]
|
||||
|
@ -232,7 +256,7 @@ def test_compound_condition(mock_conditions):
|
|||
condition_1.verify.return_value = (False, 1)
|
||||
condition_2.verify.return_value = (False, 2)
|
||||
condition_3.verify.return_value = (False, 3)
|
||||
result, value = compound_condition.verify()
|
||||
result, value = compound_condition.verify(providers={})
|
||||
assert result is False
|
||||
assert len(value) == 1, "or_condition"
|
||||
assert value == [
|
||||
|
@ -243,7 +267,7 @@ def test_compound_condition(mock_conditions):
|
|||
condition_1.verify.return_value = (True, 1)
|
||||
condition_4.verify.return_value = (False, 4)
|
||||
|
||||
result, value = compound_condition.verify()
|
||||
result, value = compound_condition.verify(providers={})
|
||||
assert result is False
|
||||
assert len(value) == 2, "or_condition and condition_4"
|
||||
assert value == [
|
||||
|
@ -253,7 +277,7 @@ def test_compound_condition(mock_conditions):
|
|||
|
||||
# condition_4 is now true
|
||||
condition_4.verify.return_value = (True, 4)
|
||||
result, value = compound_condition.verify()
|
||||
result, value = compound_condition.verify(providers={})
|
||||
assert result is True
|
||||
assert len(value) == 2, "or_condition and condition_4"
|
||||
assert value == [
|
||||
|
@ -262,6 +286,57 @@ def test_compound_condition(mock_conditions):
|
|||
] # or-condition short-circuited because condition_1 was True
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_nested_compound_condition_too_many_nested_levels(mock_conditions):
|
||||
condition_1, condition_2, condition_3, condition_4 = mock_conditions
|
||||
|
||||
with pytest.raises(
|
||||
InvalidCondition, match="nested levels of multi-conditions are allowed"
|
||||
):
|
||||
_ = AndCompoundCondition(
|
||||
operands=[
|
||||
OrCompoundCondition(
|
||||
operands=[
|
||||
condition_1,
|
||||
AndCompoundCondition(
|
||||
operands=[
|
||||
condition_2,
|
||||
condition_3,
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
condition_4,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_nested_sequential_condition_too_many_nested_levels(mock_conditions):
|
||||
condition_1, condition_2, condition_3, condition_4 = mock_conditions
|
||||
|
||||
with pytest.raises(
|
||||
InvalidCondition, match="nested levels of multi-conditions are allowed"
|
||||
):
|
||||
_ = AndCompoundCondition(
|
||||
operands=[
|
||||
OrCompoundCondition(
|
||||
operands=[
|
||||
condition_1,
|
||||
SequentialAccessControlCondition(
|
||||
condition_variables=[
|
||||
ConditionVariable("var2", condition_2),
|
||||
ConditionVariable("var3", condition_3),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
condition_4,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_nested_compound_condition(mock_conditions):
|
||||
condition_1, condition_2, condition_3, condition_4 = mock_conditions
|
||||
|
||||
|
@ -270,12 +345,8 @@ def test_nested_compound_condition(mock_conditions):
|
|||
OrCompoundCondition(
|
||||
operands=[
|
||||
condition_1,
|
||||
AndCompoundCondition(
|
||||
operands=[
|
||||
condition_2,
|
||||
condition_3,
|
||||
]
|
||||
),
|
||||
condition_2,
|
||||
condition_3,
|
||||
]
|
||||
),
|
||||
condition_4,
|
||||
|
@ -283,30 +354,43 @@ def test_nested_compound_condition(mock_conditions):
|
|||
)
|
||||
|
||||
# all conditions are True
|
||||
result, value = nested_compound_condition.verify()
|
||||
result, value = nested_compound_condition.verify(providers={})
|
||||
assert result is True
|
||||
assert len(value) == 2, "or_condition and condition_4"
|
||||
assert len(value) == 2, "or_condition (condition_1) and condition_4"
|
||||
assert value == [[1], 4] # or short-circuited since condition_1 is True
|
||||
|
||||
# set condition_1 to False so nested and-condition must be evaluated
|
||||
# set condition_1 to False so condition_2 must be evaluated
|
||||
condition_1.verify.return_value = (False, 1)
|
||||
|
||||
result, value = nested_compound_condition.verify()
|
||||
result, value = nested_compound_condition.verify(providers={})
|
||||
assert result is True
|
||||
assert len(value) == 2, "or_condition and condition_4"
|
||||
assert len(value) == 2, "or_condition (condition_2) and condition_4"
|
||||
assert value == [
|
||||
[1, [2, 3]],
|
||||
[1, 2],
|
||||
4,
|
||||
] # nested and-condition was evaluated and evaluated to True
|
||||
] # or short-circuited since condition_2 is True
|
||||
|
||||
# set condition_3 to False so condition_3 must be evaluated
|
||||
condition_2.verify.return_value = (False, 2)
|
||||
|
||||
result, value = nested_compound_condition.verify(providers={})
|
||||
assert result is True
|
||||
assert len(value) == 2, "or_condition (condition_3) and condition_4"
|
||||
assert value == [
|
||||
[1, 2, 3],
|
||||
4,
|
||||
] # or short-circuited since condition_3 is True
|
||||
|
||||
# set condition_4 to False so that overall result flips to False
|
||||
# (even though condition_3 is still True)
|
||||
condition_4.verify.return_value = (False, 4)
|
||||
result, value = nested_compound_condition.verify()
|
||||
result, value = nested_compound_condition.verify(providers={})
|
||||
assert result is False
|
||||
assert len(value) == 2, "or_condition and condition_4"
|
||||
assert value == [[1, [2, 3]], 4]
|
||||
assert value == [[1, 2, 3], 4]
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_not_compound_condition(mock_conditions):
|
||||
condition_1, condition_2, condition_3, condition_4 = mock_conditions
|
||||
|
||||
|
@ -316,12 +400,12 @@ def test_not_compound_condition(mock_conditions):
|
|||
# simple `not`
|
||||
#
|
||||
condition_1.verify.return_value = (True, 1)
|
||||
result, value = not_condition.verify()
|
||||
result, value = not_condition.verify(providers={})
|
||||
assert result is False
|
||||
assert value == 1
|
||||
|
||||
condition_1.verify.return_value = (False, 2)
|
||||
result, value = not_condition.verify()
|
||||
result, value = not_condition.verify(providers={})
|
||||
assert result is True
|
||||
assert value == 2
|
||||
|
||||
|
@ -342,8 +426,8 @@ def test_not_compound_condition(mock_conditions):
|
|||
]
|
||||
)
|
||||
not_condition = NotCompoundCondition(operand=or_condition)
|
||||
or_result, or_value = or_condition.verify()
|
||||
result, value = not_condition.verify()
|
||||
or_result, or_value = or_condition.verify(providers={})
|
||||
result, value = not_condition.verify(providers={})
|
||||
assert result is False
|
||||
assert result is (not or_result)
|
||||
assert value == or_value
|
||||
|
@ -352,8 +436,8 @@ def test_not_compound_condition(mock_conditions):
|
|||
condition_1.verify.return_value = (False, 1)
|
||||
condition_2.verify.return_value = (False, 2)
|
||||
condition_3.verify.return_value = (False, 3)
|
||||
or_result, or_value = or_condition.verify()
|
||||
result, value = not_condition.verify()
|
||||
or_result, or_value = or_condition.verify(providers={})
|
||||
result, value = not_condition.verify(providers={})
|
||||
assert result is True
|
||||
assert result is (not or_result)
|
||||
assert value == or_value
|
||||
|
@ -362,8 +446,8 @@ def test_not_compound_condition(mock_conditions):
|
|||
condition_1.verify.return_value = (False, 1)
|
||||
condition_2.verify.return_value = (False, 2)
|
||||
condition_3.verify.return_value = (True, 3)
|
||||
or_result, or_value = or_condition.verify()
|
||||
result, value = not_condition.verify()
|
||||
or_result, or_value = or_condition.verify(providers={})
|
||||
result, value = not_condition.verify(providers={})
|
||||
assert result is False
|
||||
assert result is (not or_result)
|
||||
assert value == or_value
|
||||
|
@ -386,8 +470,8 @@ def test_not_compound_condition(mock_conditions):
|
|||
)
|
||||
not_condition = NotCompoundCondition(operand=and_condition)
|
||||
|
||||
and_result, and_value = and_condition.verify()
|
||||
result, value = not_condition.verify()
|
||||
and_result, and_value = and_condition.verify(providers={})
|
||||
result, value = not_condition.verify(providers={})
|
||||
assert result is False
|
||||
assert result is (not and_result)
|
||||
assert value == and_value
|
||||
|
@ -396,8 +480,8 @@ def test_not_compound_condition(mock_conditions):
|
|||
condition_1.verify.return_value = (False, 1)
|
||||
condition_2.verify.return_value = (False, 2)
|
||||
condition_3.verify.return_value = (False, 3)
|
||||
and_result, and_value = and_condition.verify()
|
||||
result, value = not_condition.verify()
|
||||
and_result, and_value = and_condition.verify(providers={})
|
||||
result, value = not_condition.verify(providers={})
|
||||
assert result is True
|
||||
assert result is (not and_result)
|
||||
assert value == and_value
|
||||
|
@ -406,59 +490,8 @@ def test_not_compound_condition(mock_conditions):
|
|||
condition_1.verify.return_value = (False, 1)
|
||||
condition_2.verify.return_value = (True, 2)
|
||||
condition_3.verify.return_value = (False, 3)
|
||||
and_result, and_value = and_condition.verify()
|
||||
result, value = not_condition.verify()
|
||||
and_result, and_value = and_condition.verify(providers={})
|
||||
result, value = not_condition.verify(providers={})
|
||||
assert result is True
|
||||
assert result is (not and_result)
|
||||
assert value == and_value
|
||||
|
||||
#
|
||||
# Complex nested `or` and `and` (reused nested compound condition in previous test)
|
||||
#
|
||||
nested_compound_condition = AndCompoundCondition(
|
||||
operands=[
|
||||
OrCompoundCondition(
|
||||
operands=[
|
||||
condition_1,
|
||||
AndCompoundCondition(
|
||||
operands=[
|
||||
condition_2,
|
||||
condition_3,
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
condition_4,
|
||||
]
|
||||
)
|
||||
|
||||
not_condition = NotCompoundCondition(operand=nested_compound_condition)
|
||||
|
||||
# reset all conditions to True
|
||||
condition_1.verify.return_value = (True, 1)
|
||||
condition_2.verify.return_value = (True, 2)
|
||||
condition_3.verify.return_value = (True, 3)
|
||||
condition_4.verify.return_value = (True, 4)
|
||||
|
||||
nested_result, nested_value = nested_compound_condition.verify()
|
||||
result, value = not_condition.verify()
|
||||
assert result is False
|
||||
assert result is (not nested_result)
|
||||
assert value == nested_value
|
||||
|
||||
# set condition_1 to False so nested and-condition must be evaluated
|
||||
condition_1.verify.return_value = (False, 1)
|
||||
|
||||
nested_result, nested_value = nested_compound_condition.verify()
|
||||
result, value = not_condition.verify()
|
||||
assert result is False
|
||||
assert result is (not nested_result)
|
||||
assert value == nested_value
|
||||
|
||||
# set condition_4 to False so that overall result flips to False, so `not` is now True
|
||||
condition_4.verify.return_value = (False, 4)
|
||||
nested_result, nested_value = nested_compound_condition.verify()
|
||||
result, value = not_condition.verify()
|
||||
assert result is True
|
||||
assert result is (not nested_result)
|
||||
assert value == nested_value
|
||||
|
|
|
@ -33,16 +33,15 @@ def lingo_with_compound_conditions(get_random_checksum_address):
|
|||
"operands": [
|
||||
{
|
||||
"conditionType": ConditionType.TIME.value,
|
||||
"returnValueTest": {"value": 0, "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
"returnValueTest": {"value": 0, "comparator": ">"},
|
||||
},
|
||||
{
|
||||
"conditionType": ConditionType.CONTRACT.value,
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
"method": "isPolicyActive",
|
||||
"parameters": [":hrac"],
|
||||
"returnValueTest": {"comparator": "==", "value": True},
|
||||
"contractAddress": get_random_checksum_address(),
|
||||
"functionAbi": {
|
||||
"type": "function",
|
||||
|
@ -59,38 +58,114 @@ def lingo_with_compound_conditions(get_random_checksum_address):
|
|||
{"name": "", "type": "bool", "internalType": "bool"}
|
||||
],
|
||||
},
|
||||
"returnValueTest": {"comparator": "==", "value": True},
|
||||
},
|
||||
# sequential condition
|
||||
{
|
||||
"conditionType": ConditionType.COMPOUND.value,
|
||||
"operator": "or",
|
||||
"operands": [
|
||||
"conditionType": ConditionType.SEQUENTIAL.value,
|
||||
"conditionVariables": [
|
||||
{
|
||||
"conditionType": ConditionType.TIME.value,
|
||||
"returnValueTest": {"value": 0, "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
"varName": "timeValue",
|
||||
"condition": {
|
||||
# Time
|
||||
"conditionType": ConditionType.TIME.value,
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
"returnValueTest": {
|
||||
"value": 0,
|
||||
"comparator": ">",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"conditionType": ConditionType.RPC.value,
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
"method": "eth_getBalance",
|
||||
"parameters": [get_random_checksum_address(), "latest"],
|
||||
"returnValueTest": {
|
||||
"comparator": ">=",
|
||||
"value": 10000000000000,
|
||||
"varName": "rpcValue",
|
||||
"condition": {
|
||||
# RPC
|
||||
"conditionType": ConditionType.RPC.value,
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
"method": "eth_getBalance",
|
||||
"parameters": [
|
||||
get_random_checksum_address(),
|
||||
"latest",
|
||||
],
|
||||
"returnValueTest": {
|
||||
"comparator": ">=",
|
||||
"value": 10000000000000,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"varName": "contractValue",
|
||||
"condition": {
|
||||
# Contract
|
||||
"conditionType": ConditionType.CONTRACT.value,
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
"method": "isPolicyActive",
|
||||
"parameters": [":hrac"],
|
||||
"contractAddress": get_random_checksum_address(),
|
||||
"functionAbi": {
|
||||
"type": "function",
|
||||
"name": "isPolicyActive",
|
||||
"stateMutability": "view",
|
||||
"inputs": [
|
||||
{
|
||||
"name": "_policyID",
|
||||
"type": "bytes16",
|
||||
"internalType": "bytes16",
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "",
|
||||
"type": "bool",
|
||||
"internalType": "bool",
|
||||
}
|
||||
],
|
||||
},
|
||||
"returnValueTest": {
|
||||
"comparator": "==",
|
||||
"value": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"varName": "jsonValue",
|
||||
"condition": {
|
||||
# JSON API
|
||||
"conditionType": ConditionType.JSONAPI.value,
|
||||
"endpoint": "https://api.example.com/data",
|
||||
"query": "$.store.book[0].price",
|
||||
"parameters": {
|
||||
"ids": "ethereum",
|
||||
"vs_currencies": "usd",
|
||||
},
|
||||
"returnValueTest": {
|
||||
"comparator": "==",
|
||||
"value": 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"conditionType": ConditionType.RPC.value,
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
"method": "eth_getBalance",
|
||||
"parameters": [get_random_checksum_address(), "latest"],
|
||||
"returnValueTest": {
|
||||
"comparator": ">=",
|
||||
"value": 10000000000000,
|
||||
},
|
||||
},
|
||||
{
|
||||
"conditionType": ConditionType.COMPOUND.value,
|
||||
"operator": "not",
|
||||
"operands": [
|
||||
{
|
||||
"conditionType": ConditionType.TIME.value,
|
||||
"returnValueTest": {"value": 0, "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
"returnValueTest": {"value": 0, "comparator": ">"},
|
||||
},
|
||||
],
|
||||
},
|
||||
|
@ -98,7 +173,6 @@ def lingo_with_compound_conditions(get_random_checksum_address):
|
|||
},
|
||||
}
|
||||
|
||||
|
||||
def test_invalid_condition():
|
||||
# no version or condition
|
||||
data = dict()
|
||||
|
@ -127,6 +201,9 @@ def test_invalid_condition():
|
|||
with pytest.raises(InvalidConditionLingo):
|
||||
ConditionLingo.from_json(json.dumps(data))
|
||||
|
||||
|
||||
def test_invalid_compound_condition():
|
||||
|
||||
# invalid operator
|
||||
invalid_operator = {
|
||||
"version": ConditionLingo.VERSION,
|
||||
|
@ -275,7 +352,7 @@ def test_condition_lingo_to_from_json(lingo_with_compound_conditions):
|
|||
assert clingo_from_json.to_dict() == lingo_with_compound_conditions
|
||||
|
||||
|
||||
def test_condition_lingo_repr(lingo_with_compound_conditions):
|
||||
def test_compound_condition_lingo_repr(lingo_with_compound_conditions):
|
||||
clingo = ConditionLingo.from_dict(lingo_with_compound_conditions)
|
||||
clingo_string = f"{clingo}"
|
||||
assert f"{clingo.__class__.__name__}" in clingo_string
|
||||
|
|
|
@ -11,7 +11,7 @@ from nucypher.policy.conditions.context import (
|
|||
_resolve_user_address,
|
||||
get_context_value,
|
||||
is_context_variable,
|
||||
resolve_any_context_variables,
|
||||
resolve_parameter_context_variables,
|
||||
)
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
ContextVariableVerificationFailed,
|
||||
|
@ -86,9 +86,8 @@ def test_resolve_any_context_variables():
|
|||
params, resolved_params = params_with_resolution
|
||||
value, resolved_value = value_with_resolution
|
||||
return_value_test = ReturnValueTest(comparator="==", value=value)
|
||||
resolved_parameters, resolved_return_value = resolve_any_context_variables(
|
||||
[params], return_value_test, **CONTEXT
|
||||
)
|
||||
resolved_parameters = resolve_parameter_context_variables([params], **CONTEXT)
|
||||
resolved_return_value = return_value_test.with_resolved_context(**CONTEXT)
|
||||
assert resolved_parameters == [resolved_params]
|
||||
assert resolved_return_value.comparator == return_value_test.comparator
|
||||
assert resolved_return_value.index == return_value_test.index
|
||||
|
|
|
@ -11,7 +11,7 @@ from hexbytes import HexBytes
|
|||
from marshmallow import post_load
|
||||
from web3.providers import BaseProvider
|
||||
|
||||
from nucypher.policy.conditions.evm import ContractCondition
|
||||
from nucypher.policy.conditions.evm import ContractCall, ContractCondition
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
InvalidCondition,
|
||||
InvalidConditionLingo,
|
||||
|
@ -44,6 +44,17 @@ CONTRACT_CONDITION = {
|
|||
|
||||
|
||||
class FakeExecutionContractCondition(ContractCondition):
|
||||
class FakeRPCCall(ContractCall):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.execution_return_value = None
|
||||
|
||||
def set_execution_return_value(self, value: Any):
|
||||
self.execution_return_value = value
|
||||
|
||||
def execute(self, providers: Dict, **context) -> Any:
|
||||
return self.execution_return_value
|
||||
|
||||
class Schema(ContractCondition.Schema):
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
|
@ -51,16 +62,12 @@ class FakeExecutionContractCondition(ContractCondition):
|
|||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.execution_return_value = None
|
||||
|
||||
def _create_execution_call(self, *args, **kwargs) -> ContractCall:
|
||||
return self.FakeRPCCall(*args, **kwargs)
|
||||
|
||||
def set_execution_return_value(self, value: Any):
|
||||
self.execution_return_value = value
|
||||
|
||||
def _execute_call(self, parameters: List[Any]) -> Any:
|
||||
return self.execution_return_value
|
||||
|
||||
def _configure_provider(self, provider: BaseProvider):
|
||||
return
|
||||
self.execution_call.set_execution_return_value(value)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
|
|
@ -7,7 +7,6 @@ from marshmallow import ValidationError
|
|||
from nucypher.policy.conditions.exceptions import (
|
||||
ConditionEvaluationFailed,
|
||||
InvalidCondition,
|
||||
InvalidConditionLingo,
|
||||
)
|
||||
from nucypher.policy.conditions.lingo import ConditionLingo, ReturnValueTest
|
||||
from nucypher.policy.conditions.offchain import (
|
||||
|
@ -56,7 +55,7 @@ def test_json_api_condition_invalid_type():
|
|||
|
||||
|
||||
def test_https_enforcement():
|
||||
with pytest.raises(InvalidConditionLingo) as excinfo:
|
||||
with pytest.raises(InvalidCondition) as excinfo:
|
||||
JsonApiCondition(
|
||||
endpoint="http://api.example.com/data",
|
||||
query="$.store.book[0].price",
|
||||
|
@ -88,7 +87,7 @@ def test_json_api_condition_fetch(mocker):
|
|||
query="$.store.book[0].title",
|
||||
return_value_test=ReturnValueTest("==", "'Test Title'"),
|
||||
)
|
||||
response = condition.fetch()
|
||||
response = condition.execution_call._fetch()
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"store": {"book": [{"title": "Test Title"}]}}
|
||||
|
||||
|
@ -104,7 +103,7 @@ def test_json_api_condition_fetch_failure(mocker):
|
|||
return_value_test=ReturnValueTest("==", 1),
|
||||
)
|
||||
with pytest.raises(InvalidCondition) as excinfo:
|
||||
condition.fetch()
|
||||
condition.execution_call._fetch()
|
||||
assert "Failed to fetch endpoint" in str(excinfo.value)
|
||||
|
||||
|
||||
|
@ -224,7 +223,7 @@ def test_json_api_condition_from_lingo_expression():
|
|||
},
|
||||
}
|
||||
|
||||
cls = ConditionLingo.resolve_condition_class(lingo_dict, version=1.0)
|
||||
cls = ConditionLingo.resolve_condition_class(lingo_dict, version=1)
|
||||
assert cls == JsonApiCondition
|
||||
|
||||
lingo_json = json.dumps(lingo_dict)
|
||||
|
|
|
@ -0,0 +1,226 @@
|
|||
import pytest
|
||||
from web3.exceptions import Web3Exception
|
||||
|
||||
from nucypher.policy.conditions.base import (
|
||||
AccessControlCondition,
|
||||
)
|
||||
from nucypher.policy.conditions.exceptions import InvalidCondition
|
||||
from nucypher.policy.conditions.lingo import (
|
||||
ConditionType,
|
||||
ConditionVariable,
|
||||
OrCompoundCondition,
|
||||
SequentialAccessControlCondition,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def mock_condition_variables(mocker):
|
||||
cond_1 = mocker.Mock(spec=AccessControlCondition)
|
||||
cond_1.verify.return_value = (True, 1)
|
||||
cond_1.to_dict.return_value = {"value": 1}
|
||||
var_1 = ConditionVariable(var_name="var1", condition=cond_1)
|
||||
|
||||
cond_2 = mocker.Mock(spec=AccessControlCondition)
|
||||
cond_2.verify.return_value = (True, 2)
|
||||
cond_2.to_dict.return_value = {"value": 2}
|
||||
var_2 = ConditionVariable(var_name="var2", condition=cond_2)
|
||||
|
||||
cond_3 = mocker.Mock(spec=AccessControlCondition)
|
||||
cond_3.verify.return_value = (True, 3)
|
||||
cond_3.to_dict.return_value = {"value": 3}
|
||||
var_3 = ConditionVariable(var_name="var3", condition=cond_3)
|
||||
|
||||
cond_4 = mocker.Mock(spec=AccessControlCondition)
|
||||
cond_4.verify.return_value = (True, 4)
|
||||
cond_4.to_dict.return_value = {"value": 4}
|
||||
var_4 = ConditionVariable(var_name="var4", condition=cond_4)
|
||||
|
||||
return var_1, var_2, var_3, var_4
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_invalid_sequential_condition(mock_condition_variables):
|
||||
var_1, *_ = mock_condition_variables
|
||||
|
||||
# invalid condition type
|
||||
with pytest.raises(InvalidCondition, match=ConditionType.SEQUENTIAL.value):
|
||||
_ = SequentialAccessControlCondition(
|
||||
condition_type=ConditionType.TIME.value,
|
||||
condition_variables=list(mock_condition_variables),
|
||||
)
|
||||
|
||||
# no variables
|
||||
with pytest.raises(InvalidCondition, match="At least two conditions"):
|
||||
_ = SequentialAccessControlCondition(
|
||||
condition_variables=[],
|
||||
)
|
||||
|
||||
# only one variable
|
||||
with pytest.raises(InvalidCondition, match="At least two conditions"):
|
||||
_ = SequentialAccessControlCondition(
|
||||
condition_variables=[var_1],
|
||||
)
|
||||
|
||||
# too many variables
|
||||
too_many_variables = list(mock_condition_variables)
|
||||
too_many_variables.extend(mock_condition_variables) # duplicate list length
|
||||
assert len(too_many_variables) > SequentialAccessControlCondition.MAX_NUM_CONDITIONS
|
||||
with pytest.raises(InvalidCondition, match="Maximum of"):
|
||||
_ = SequentialAccessControlCondition(
|
||||
condition_variables=too_many_variables,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_nested_sequential_condition_too_many_nested_levels(mock_condition_variables):
|
||||
var_1, var_2, var_3, var_4 = mock_condition_variables
|
||||
|
||||
with pytest.raises(
|
||||
InvalidCondition, match="nested levels of multi-conditions are allowed"
|
||||
):
|
||||
_ = (
|
||||
SequentialAccessControlCondition(
|
||||
condition_variables=[
|
||||
var_1,
|
||||
ConditionVariable(
|
||||
"seq_1",
|
||||
SequentialAccessControlCondition(
|
||||
condition_variables=[
|
||||
var_2,
|
||||
ConditionVariable(
|
||||
"seq_2",
|
||||
SequentialAccessControlCondition(
|
||||
condition_variables=[
|
||||
var_3,
|
||||
var_4,
|
||||
],
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_nested_compound_condition_too_many_nested_levels(mock_condition_variables):
|
||||
var_1, var_2, var_3, var_4 = mock_condition_variables
|
||||
|
||||
with pytest.raises(
|
||||
InvalidCondition, match="nested levels of multi-conditions are allowed"
|
||||
):
|
||||
_ = SequentialAccessControlCondition(
|
||||
condition_variables=[
|
||||
ConditionVariable(
|
||||
"var1",
|
||||
OrCompoundCondition(
|
||||
operands=[
|
||||
var_1.condition,
|
||||
SequentialAccessControlCondition(
|
||||
condition_variables=[
|
||||
var_2,
|
||||
var_3,
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
var_4,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_sequential_condition(mock_condition_variables):
|
||||
var_1, var_2, var_3, var_4 = mock_condition_variables
|
||||
|
||||
var_1.condition.verify.return_value = (True, 1)
|
||||
|
||||
var_2.condition.verify = lambda providers, **context: (
|
||||
True,
|
||||
context[f":{var_1.var_name}"] * 2,
|
||||
)
|
||||
|
||||
var_3.condition.verify = lambda providers, **context: (
|
||||
True,
|
||||
context[f":{var_2.var_name}"] * 3,
|
||||
)
|
||||
|
||||
var_4.condition.verify = lambda providers, **context: (
|
||||
True,
|
||||
context[f":{var_3.var_name}"] * 4,
|
||||
)
|
||||
|
||||
sequential_condition = SequentialAccessControlCondition(
|
||||
condition_variables=[var_1, var_2, var_3, var_4],
|
||||
)
|
||||
|
||||
original_context = dict()
|
||||
result, value = sequential_condition.verify(providers={}, **original_context)
|
||||
assert result is True
|
||||
assert value == [1, 1 * 2, 1 * 2 * 3, 1 * 2 * 3 * 4]
|
||||
# only a copy of the context is modified internally
|
||||
assert len(original_context) == 0, "original context remains unchanged"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_sequential_condition_all_prior_vars_passed_to_subsequent_calls(
|
||||
mock_condition_variables,
|
||||
):
|
||||
var_1, var_2, var_3, var_4 = mock_condition_variables
|
||||
|
||||
var_1.condition.verify.return_value = (True, 1)
|
||||
|
||||
var_2.condition.verify = lambda providers, **context: (
|
||||
True,
|
||||
context[f":{var_1.var_name}"] + 1,
|
||||
)
|
||||
|
||||
var_3.condition.verify = lambda providers, **context: (
|
||||
True,
|
||||
context[f":{var_1.var_name}"] + context[f":{var_2.var_name}"] + 1,
|
||||
)
|
||||
|
||||
var_4.condition.verify = lambda providers, **context: (
|
||||
True,
|
||||
context[f":{var_1.var_name}"]
|
||||
+ context[f":{var_2.var_name}"]
|
||||
+ context[f":{var_3.var_name}"]
|
||||
+ 1,
|
||||
)
|
||||
|
||||
sequential_condition = SequentialAccessControlCondition(
|
||||
condition_variables=[var_1, var_2, var_3, var_4],
|
||||
)
|
||||
|
||||
expected_var_1_value = 1
|
||||
expected_var_2_value = expected_var_1_value + 1
|
||||
expected_var_3_value = expected_var_1_value + expected_var_2_value + 1
|
||||
|
||||
original_context = dict()
|
||||
result, value = sequential_condition.verify(providers={}, **original_context)
|
||||
assert result is True
|
||||
assert value == [
|
||||
expected_var_1_value,
|
||||
expected_var_2_value,
|
||||
expected_var_3_value,
|
||||
(expected_var_1_value + expected_var_2_value + expected_var_3_value + 1),
|
||||
]
|
||||
# only a copy of the context is modified internally
|
||||
assert len(original_context) == 0, "original context remains unchanged"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_skip_schema_validation")
|
||||
def test_sequential_condition_a_call_fails(mock_condition_variables):
|
||||
var_1, var_2, var_3, var_4 = mock_condition_variables
|
||||
|
||||
var_4.condition.verify.side_effect = Web3Exception
|
||||
|
||||
sequential_condition = SequentialAccessControlCondition(
|
||||
condition_variables=[var_1, var_2, var_3, var_4],
|
||||
)
|
||||
|
||||
with pytest.raises(Web3Exception):
|
||||
_ = sequential_condition.verify(providers={})
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
|
||||
from nucypher.policy.conditions.exceptions import InvalidCondition
|
||||
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
|
||||
from nucypher.policy.conditions.time import TimeCondition
|
||||
from nucypher.policy.conditions.time import TimeCondition, TimeRPCCall
|
||||
from tests.constants import TESTERCHAIN_CHAIN_ID
|
||||
|
||||
|
||||
|
@ -13,7 +13,7 @@ def test_invalid_time_condition():
|
|||
condition_type=ConditionType.COMPOUND.value,
|
||||
return_value_test=ReturnValueTest(">", 0),
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
method=TimeCondition.METHOD,
|
||||
method=TimeRPCCall.METHOD,
|
||||
)
|
||||
|
||||
# invalid method
|
||||
|
@ -29,7 +29,7 @@ def test_invalid_time_condition():
|
|||
_ = TimeCondition(
|
||||
return_value_test=ReturnValueTest(">", 0),
|
||||
chain=90210, # Beverly Hills Chain :)
|
||||
method=TimeCondition.METHOD,
|
||||
method=TimeRPCCall.METHOD,
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue