Limit number of operands for CompoundAccessControlCondition - still need to handle the case of nested compound conditions.

pull/3500/head
derekpierre 2024-05-01 16:47:54 -04:00
parent 33d382d9cb
commit 77c99ed847
No known key found for this signature in database
3 changed files with 55 additions and 19 deletions

View File

@ -89,6 +89,8 @@ class CompoundAccessControlCondition(AccessControlCondition):
OPERATORS = (AND_OPERATOR, OR_OPERATOR, NOT_OPERATOR)
CONDITION_TYPE = ConditionType.COMPOUND.value
MAX_OPERANDS = 5
@classmethod
def _validate_operator_and_operands(
cls,
@ -99,15 +101,22 @@ class CompoundAccessControlCondition(AccessControlCondition):
if operator not in cls.OPERATORS:
raise exception_class(f"{operator} is not a valid operator")
num_operands = len(operands)
if operator == cls.NOT_OPERATOR:
if len(operands) != 1:
if num_operands != 1:
raise exception_class(
f"Only 1 operand permitted for '{operator}' compound condition"
)
elif len(operands) < 2:
elif num_operands < 2:
raise exception_class(
f"Minimum of 2 operand needed for '{operator}' compound condition"
)
elif num_operands > cls.MAX_OPERANDS:
raise exception_class(
f"Maximum of {cls.MAX_OPERANDS} operands allowed for '{operator}' compound condition"
)
# TODO nested operands
class Schema(CamelCaseSchema):
SKIP_VALUES = (None,)

View File

@ -3,7 +3,13 @@ from collections import defaultdict
import pytest
from nucypher.policy.conditions.evm import RPCCondition
from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
from nucypher.policy.conditions.lingo import (
CompoundAccessControlCondition,
ConditionLingo,
ConditionType,
ReturnValueTest,
)
from nucypher.policy.conditions.time import TimeCondition
from nucypher.utilities.logging import GlobalLoggerSettings
from tests.utils.policy import make_message_kits
@ -14,22 +20,28 @@ def make_multichain_evm_conditions(bob, chain_ids):
"""This is a helper function to make a set of conditions that are valid on multiple chains."""
operands = list()
for chain_id in chain_ids:
operand = [
{
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": chain_id,
},
{
"conditionType": ConditionType.RPC.value,
"chain": chain_id,
"method": "eth_getBalance",
"parameters": [bob.checksum_address, "latest"],
"returnValueTest": {"comparator": ">=", "value": 10000000000000},
},
]
operands.extend(operand)
compound_and_condition = CompoundAccessControlCondition(
operator="and",
operands=[
TimeCondition(
chain=chain_id,
return_value_test=ReturnValueTest(
comparator=">",
value=0,
),
),
RPCCondition(
chain=chain_id,
method="eth_getBalance",
parameters=[bob.checksum_address, "latest"],
return_value_test=ReturnValueTest(
comparator=">=",
value=10000000000000,
),
),
],
)
operands.append(compound_and_condition.to_dict())
_conditions = {
"version": ConditionLingo.VERSION,

View File

@ -86,6 +86,21 @@ def test_invalid_compound_condition(time_condition, rpc_condition):
operands=[rpc_condition],
)
# exceeds max operands
operands = list()
for i in range(CompoundAccessControlCondition.MAX_OPERANDS + 1):
operands.append(rpc_condition)
with pytest.raises(InvalidCondition):
_ = CompoundAccessControlCondition(
operator=CompoundAccessControlCondition.OR_OPERATOR,
operands=operands,
)
with pytest.raises(InvalidCondition):
_ = CompoundAccessControlCondition(
operator=CompoundAccessControlCondition.AND_OPERATOR,
operands=operands,
)
@pytest.mark.parametrize("operator", CompoundAccessControlCondition.OPERATORS)
def test_compound_condition_schema_validation(operator, time_condition, rpc_condition):