Update SequentialCondition to use a list of conditions instead of calls.

pull/3500/head
derekpierre 2024-09-14 13:01:57 -04:00
parent 3f8d8f91a2
commit c6a3a6d72e
No known key found for this signature in database
3 changed files with 127 additions and 176 deletions

View File

@ -23,7 +23,6 @@ from web3 import HTTPProvider
from nucypher.policy.conditions.base import (
AccessControlCondition,
ExecutionCall,
_Serializable,
)
from nucypher.policy.conditions.context import (
@ -35,7 +34,7 @@ from nucypher.policy.conditions.exceptions import (
InvalidConditionLingo,
ReturnValueEvaluationError,
)
from nucypher.policy.conditions.types import ConditionDict, ExecutionCallDict, Lingo
from nucypher.policy.conditions.types import ConditionDict, Lingo
from nucypher.policy.conditions.utils import CamelCaseSchema
@ -232,109 +231,68 @@ _COMPARATOR_FUNCTIONS = {
#
# CONDITION = BASE_CONDITION | COMPOUND_CONDITION
#
# EXECUTION_CALL = RPC_CALL | TIME_CALL | CONTRACT_CALL | JSON_API_CALL ...
#
# EXECUTION_VARIABLE = {
# CONDITION_VARIABLE = {
# "varName": STR,
# "call": {
# EXECUTION_CALL
# "condition": {
# CONDITION
# }
# }
#
# SEQUENTIAL_CONDITION = {
# "name": ... (Optional)
# "conditionType": "sequential",
# "variables": [EXECUTION_VARIABLE*]
# "condition": CONDITION
# "conditionVariables": [CONDITION_VARIABLE*]
# }
class _ExecutionCallField(fields.Dict):
"""Serializes/Deserializes Conditions to/from dictionaries"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _serialize(self, value, attr, obj, **kwargs):
return value.to_dict()
def _deserialize(self, value, attr, data, **kwargs):
execution_call_dict = value
execution_call_class = self.resolve_execution_call_class(
execution_call_dict=execution_call_dict
)
instance = execution_call_class.from_dict(execution_call_dict)
return instance
@classmethod
def resolve_execution_call_class(cls, execution_call_dict: ExecutionCallDict):
from nucypher.policy.conditions.evm import (
ContractCall,
RPCCall,
)
from nucypher.policy.conditions.offchain import JsonApiCall
from nucypher.policy.conditions.time import TimeRPCCall
call_type = execution_call_dict.get("callType")
for execution_call_type in (RPCCall, TimeRPCCall, ContractCall, JsonApiCall):
if execution_call_type.CALL_TYPE == call_type:
return execution_call_type
raise InvalidConditionLingo(
f"Cannot resolve condition lingo with call type {call_type}"
)
class ExecutionVariable(_Serializable):
class ConditionVariable(_Serializable):
class Schema(CamelCaseSchema):
var_name = fields.Str(required=True)
call = _ExecutionCallField(required=True)
condition = _ConditionField(required=True)
def __init__(self, var_name: str, call: ExecutionCall):
def __init__(self, var_name: str, condition: AccessControlCondition):
self.var_name = var_name
self.call = call
self.condition = condition
class SequentialAccessControlCondition(AccessControlCondition):
CONDITION_TYPE = ConditionType.SEQUENTIAL.value
MAX_NUM_VARIABLES = 5
MAX_NUM_CONDITION_VARIABLES = 5
@classmethod
def _validate_variables(
def _validate_condition_variables(
cls,
variables: List[ExecutionVariable],
condition_variables: List[ConditionVariable],
exception_class: Union[Type[ValidationError], Type[InvalidCondition]],
):
num_variables = len(variables)
if num_variables == 0:
raise exception_class("Must be at least one variable")
num_condition_variables = len(condition_variables)
if num_condition_variables == 0:
raise exception_class("Must be at least one condition variable")
if num_variables > cls.MAX_NUM_VARIABLES:
if num_condition_variables > cls.MAX_NUM_CONDITION_VARIABLES:
raise exception_class(
f"Maximum of {cls.MAX_NUM_VARIABLES} variables are allowed"
f"Maximum of {cls.MAX_NUM_CONDITION_VARIABLES} condition variables are allowed"
)
class Schema(CamelCaseSchema):
SKIP_VALUES = (None,)
name = fields.Str(required=False)
condition_type = fields.Str(
validate=validate.Equal(ConditionType.SEQUENTIAL.value), required=True
)
variables = fields.List(
fields.Nested(ExecutionVariable.Schema(), required=True)
condition_variables = fields.List(
fields.Nested(ConditionVariable.Schema(), required=True)
)
name = fields.Str(required=False)
condition = _ConditionField(required=True)
# maintain field declaration ordering
class Meta:
ordered = True
@validates_schema
def validate_calls(self, data, **kwargs):
variables = data["variables"]
SequentialAccessControlCondition._validate_variables(
variables, ValidationError
def validate_condition_variables(self, data, **kwargs):
condition_variables = data["condition_variables"]
SequentialAccessControlCondition._validate_condition_variables(
condition_variables, ValidationError
)
@post_load
@ -343,8 +301,7 @@ class SequentialAccessControlCondition(AccessControlCondition):
def __init__(
self,
variables: List[ExecutionVariable],
condition: AccessControlCondition,
condition_variables: List[ConditionVariable],
condition_type: str = CONDITION_TYPE,
name: Optional[str] = None,
):
@ -352,15 +309,16 @@ class SequentialAccessControlCondition(AccessControlCondition):
raise InvalidCondition(
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
)
self._validate_variables(variables=variables, exception_class=InvalidCondition)
self._validate_condition_variables(
condition_variables=condition_variables, exception_class=InvalidCondition
)
self.name = name
self.variables = variables
self.condition_variables = condition_variables
self.condition_type = condition_type
self.condition = condition
def __repr__(self):
r = f"{self.__class__.__name__}(num_vars={len(self.variables)}, condition={self.condition})"
r = f"{self.__class__.__name__}(num_condition_variables={len(self.condition_variables)})"
return r
# TODO - think about not dereferencing context but using a dict;
@ -368,18 +326,22 @@ class SequentialAccessControlCondition(AccessControlCondition):
def verify(
self, providers: Dict[int, Set[HTTPProvider]], **context
) -> Tuple[bool, Any]:
values = []
latest_success = False
inner_context = dict(context) # don't modify passed in context - use a copy
# resolve variables
for var in self.variables:
result = var.call.execute(providers=providers, **inner_context)
inner_context[f":{var.var_name}"] = result
for condition_variable in self.condition_variables:
latest_success, result = condition_variable.condition.verify(
providers=providers, **inner_context
)
values.append(result)
if not latest_success:
# short circuit due to failed condition
break
# check condition
condition_check, condition_result = self.condition.verify(
providers=providers, **inner_context
)
# TODO should the variable results be included in the overall result?
return condition_check, condition_result
inner_context[f":{condition_variable.var_name}"] = result
return latest_success, values
class ReturnValueTest:

View File

@ -6,13 +6,10 @@ from packaging.version import parse as parse_version
import nucypher
from nucypher.blockchain.eth.constants import NULL_ADDRESS
from nucypher.policy.conditions.context import USER_ADDRESS_CONTEXT
from nucypher.policy.conditions.evm import ContractCall, RPCCall
from nucypher.policy.conditions.exceptions import (
InvalidConditionLingo,
)
from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
from nucypher.policy.conditions.offchain import JsonApiCall
from nucypher.policy.conditions.time import TimeRPCCall
from tests.constants import TESTERCHAIN_CHAIN_ID
@ -36,16 +33,15 @@ def lingo_with_compound_conditions(get_random_checksum_address):
"operands": [
{
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
"returnValueTest": {"value": 0, "comparator": ">"},
},
{
"conditionType": ConditionType.CONTRACT.value,
"chain": TESTERCHAIN_CHAIN_ID,
"method": "isPolicyActive",
"parameters": [":hrac"],
"returnValueTest": {"comparator": "==", "value": True},
"contractAddress": get_random_checksum_address(),
"functionAbi": {
"type": "function",
@ -62,6 +58,7 @@ def lingo_with_compound_conditions(get_random_checksum_address):
{"name": "", "type": "bool", "internalType": "bool"}
],
},
"returnValueTest": {"comparator": "==", "value": True},
},
{
"conditionType": ConditionType.COMPOUND.value,
@ -70,34 +67,42 @@ def lingo_with_compound_conditions(get_random_checksum_address):
# sequential condition
{
"conditionType": ConditionType.SEQUENTIAL.value,
"variables": [
"conditionVariables": [
{
"varName": "timeValue",
"call": {
# TimeRPCCall
"callType": TimeRPCCall.CALL_TYPE,
"condition": {
# Time
"conditionType": ConditionType.TIME.value,
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
"returnValueTest": {
"value": 0,
"comparator": ">",
},
},
},
{
"varName": "rpcValue",
"call": {
# RPCCall
"callType": RPCCall.CALL_TYPE,
"condition": {
# RPC
"conditionType": ConditionType.RPC.value,
"chain": TESTERCHAIN_CHAIN_ID,
"method": "eth_getBalance",
"parameters": [
get_random_checksum_address(),
"latest",
],
"returnValueTest": {
"comparator": ">=",
"value": 10000000000000,
},
},
},
{
"varName": "contractValue",
"call": {
# ContractCall
"callType": ContractCall.CALL_TYPE,
"condition": {
# Contract
"conditionType": ConditionType.CONTRACT.value,
"chain": TESTERCHAIN_CHAIN_ID,
"method": "isPolicyActive",
"parameters": [":hrac"],
@ -121,28 +126,30 @@ def lingo_with_compound_conditions(get_random_checksum_address):
}
],
},
"returnValueTest": {
"comparator": "==",
"value": True,
},
},
},
{
"varName": "jsonValue",
"call": {
# JsonApiCall
"callType": JsonApiCall.CALL_TYPE,
"condition": {
# JSON API
"conditionType": ConditionType.JSONAPI.value,
"endpoint": "https://api.example.com/data",
"query": "$.store.book[0].price",
"parameters": {
"ids": "ethereum",
"vs_currencies": "usd",
},
"returnValueTest": {
"comparator": "==",
"value": 2,
},
},
},
],
"condition": {
"conditionType": ConditionType.TIME.value,
"chain": TESTERCHAIN_CHAIN_ID,
"method": "blocktime",
"returnValueTest": {"value": 0, "comparator": ">"},
},
},
{
"conditionType": ConditionType.RPC.value,
@ -162,9 +169,9 @@ def lingo_with_compound_conditions(get_random_checksum_address):
"operands": [
{
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
"returnValueTest": {"value": 0, "comparator": ">"},
},
],
},

View File

@ -3,33 +3,32 @@ from web3.exceptions import Web3Exception
from nucypher.policy.conditions.base import (
AccessControlCondition,
ExecutionCall,
)
from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.lingo import (
ConditionType,
ExecutionVariable,
ConditionVariable,
SequentialAccessControlCondition,
)
@pytest.fixture(scope="function")
def mock_execution_variables(mocker):
call_1 = mocker.Mock(spec=ExecutionCall)
call_1.execute.return_value = 1
var_1 = ExecutionVariable(var_name="var1", call=call_1)
cond_1 = mocker.Mock(spec=AccessControlCondition)
cond_1.verify.return_value = (True, 1)
var_1 = ConditionVariable(var_name="var1", condition=cond_1)
call_2 = mocker.Mock(spec=ExecutionCall)
call_2.execute.return_value = 2
var_2 = ExecutionVariable(var_name="var2", call=call_2)
cond_2 = mocker.Mock(spec=AccessControlCondition)
cond_2.verify.return_value = (True, 2)
var_2 = ConditionVariable(var_name="var2", condition=cond_2)
call_3 = mocker.Mock(spec=ExecutionCall)
call_3.execute.return_value = 3
var_3 = ExecutionVariable(var_name="var3", call=call_3)
cond_3 = mocker.Mock(spec=AccessControlCondition)
cond_3.verify.return_value = (True, 3)
var_3 = ConditionVariable(var_name="var3", condition=cond_3)
call_4 = mocker.Mock(spec=ExecutionCall)
call_4.execute.return_value = 4
var_4 = ExecutionVariable(var_name="var4", call=call_4)
cond_4 = mocker.Mock(spec=AccessControlCondition)
cond_4.verify.return_value = (True, 4)
var_4 = ConditionVariable(var_name="var4", condition=cond_4)
return var_1, var_2, var_3, var_4
@ -39,56 +38,58 @@ def test_invalid_sequential_condition(mock_execution_variables, rpc_condition):
with pytest.raises(InvalidCondition, match=ConditionType.SEQUENTIAL.value):
_ = SequentialAccessControlCondition(
condition_type=ConditionType.TIME.value,
variables=list(mock_execution_variables),
condition=rpc_condition,
condition_variables=list(mock_execution_variables),
)
# no variables
with pytest.raises(InvalidCondition):
_ = SequentialAccessControlCondition(
condition_type=ConditionType.TIME.value,
variables=[],
condition=rpc_condition,
condition_variables=[],
)
# too many variables
too_many_variables = list(mock_execution_variables)
too_many_variables.extend(mock_execution_variables) # duplicate list length
assert len(too_many_variables) > SequentialAccessControlCondition.MAX_NUM_VARIABLES
assert (
len(too_many_variables)
> SequentialAccessControlCondition.MAX_NUM_CONDITION_VARIABLES
)
with pytest.raises(InvalidCondition):
_ = SequentialAccessControlCondition(
condition_type=ConditionType.TIME.value,
variables=too_many_variables,
condition=rpc_condition,
condition_variables=too_many_variables,
)
def test_sequential_condition(mocker, mock_execution_variables):
var_1, var_2, var_3, var_4 = mock_execution_variables
var_1.call.execute.return_value = 1
var_1.condition.verify.return_value = (True, 1)
var_2.call.execute = lambda providers, **context: context[f":{var_1.var_name}"] * 2
var_3.call.execute = lambda providers, **context: context[f":{var_2.var_name}"] * 3
var_4.call.execute = lambda providers, **context: context[f":{var_3.var_name}"] * 4
condition = mocker.Mock(spec=AccessControlCondition)
condition.verify = lambda providers, **context: (
var_2.condition.verify = lambda providers, **context: (
True,
context[f":{var_4.var_name}"] * 5,
context[f":{var_1.var_name}"] * 2,
)
var_3.condition.verify = lambda providers, **context: (
True,
context[f":{var_2.var_name}"] * 3,
)
var_4.condition.verify = lambda providers, **context: (
True,
context[f":{var_3.var_name}"] * 4,
)
sequential_condition = SequentialAccessControlCondition(
variables=[var_1, var_2, var_3, var_4],
condition=condition,
condition_variables=[var_1, var_2, var_3, var_4],
)
original_context = dict()
result, value = sequential_condition.verify(providers={}, **original_context)
assert result is True
assert value == (1 * 2 * 3 * 4 * 5)
assert value == [1, 1 * 2, 1 * 2 * 3, 1 * 2 * 3 * 4]
# only a copy of the context is modified internally
assert len(original_context) == 0, "original context remains unchanged"
@ -98,55 +99,43 @@ def test_sequential_condition_all_prior_vars_passed_to_subsequent_calls(
):
var_1, var_2, var_3, var_4 = mock_execution_variables
var_1.call.execute.return_value = 1
var_1.condition.verify.return_value = (True, 1)
var_2.call.execute = lambda providers, **context: context[f":{var_1.var_name}"] + 1
var_3.call.execute = (
lambda providers, **context: context[f":{var_1.var_name}"]
+ context[f":{var_2.var_name}"]
+ 1
var_2.condition.verify = lambda providers, **context: (
True,
context[f":{var_1.var_name}"] + 1,
)
var_4.call.execute = (
lambda providers, **context: context[f":{var_1.var_name}"]
+ context[f":{var_2.var_name}"]
+ context[f":{var_3.var_name}"]
+ 1
var_3.condition.verify = lambda providers, **context: (
True,
context[f":{var_1.var_name}"] + context[f":{var_2.var_name}"] + 1,
)
condition = mocker.Mock(spec=AccessControlCondition)
condition.verify = lambda providers, **context: (
var_4.condition.verify = lambda providers, **context: (
True,
context[f":{var_1.var_name}"]
+ context[f":{var_2.var_name}"]
+ context[f":{var_3.var_name}"]
+ context[f":{var_4.var_name}"]
+ 1,
)
sequential_condition = SequentialAccessControlCondition(
variables=[var_1, var_2, var_3, var_4],
condition=condition,
condition_variables=[var_1, var_2, var_3, var_4],
)
expected_var_1_value = 1
expected_var_2_value = expected_var_1_value + 1
expected_var_3_value = expected_var_1_value + expected_var_2_value + 1
expected_var_4_value = (
expected_var_1_value + expected_var_2_value + expected_var_3_value + 1
)
original_context = dict()
result, value = sequential_condition.verify(providers={}, **original_context)
assert result is True
assert value == (
expected_var_1_value
+ expected_var_2_value
+ expected_var_3_value
+ expected_var_4_value
+ 1
)
assert value == [
expected_var_1_value,
expected_var_2_value,
expected_var_3_value,
(expected_var_1_value + expected_var_2_value + expected_var_3_value + 1),
]
# only a copy of the context is modified internally
assert len(original_context) == 0, "original context remains unchanged"
@ -154,17 +143,10 @@ def test_sequential_condition_all_prior_vars_passed_to_subsequent_calls(
def test_sequential_condition_a_call_fails(mocker, mock_execution_variables):
var_1, var_2, var_3, var_4 = mock_execution_variables
var_4.call.execute.side_effect = Web3Exception
condition = mocker.Mock(spec=AccessControlCondition)
condition.verify = lambda providers, **context: (
True,
5,
)
var_4.condition.verify.side_effect = Web3Exception
sequential_condition = SequentialAccessControlCondition(
variables=[var_1, var_2, var_3, var_4],
condition=condition,
condition_variables=[var_1, var_2, var_3, var_4],
)
with pytest.raises(Web3Exception):