Merge pull request #3293 from derekpierre/what-not

Add Not Operator Functionality
pull/3304/head
KPrasch 2023-10-20 11:03:27 +02:00 committed by GitHub
commit 80aa3c5ca4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 643 additions and 107 deletions

View File

@ -2,7 +2,7 @@ from nucypher.blockchain.eth.domains import TACoDomain
from nucypher.blockchain.eth.signers import InMemorySigner
from nucypher.characters.chaotic import NiceGuyEddie as _Enrico
from nucypher.characters.chaotic import ThisBobAlwaysDecrypts
from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
plaintext = b"paz al amanecer"
THIS_IS_NOT_A_TRINKET = 55 # sometimes called "public key"
@ -16,7 +16,7 @@ ANYTHING_CAN_BE_PASSED_AS_RITUAL_ID = 55
before_the_beginning_of_time = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"chain": 1,
"method": "blocktime",
"returnValueTest": {"comparator": "<", "value": 0},

View File

@ -7,7 +7,7 @@ from nucypher.blockchain.eth.domains import TACoDomain
from nucypher.blockchain.eth.registry import ContractRegistry
from nucypher.blockchain.eth.signers import InMemorySigner
from nucypher.characters.lawful import Bob, Enrico
from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
from nucypher.utilities.logging import GlobalLoggerSettings
from tests.constants import DEFAULT_TEST_ENRICO_PRIVATE_KEY
@ -56,32 +56,32 @@ print(
conditions = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "compound",
"conditionType": ConditionType.COMPOUND.value,
"operator": "and",
"operands": [
{
"conditionType": "rpc",
"conditionType": ConditionType.RPC.value,
"chain": 1,
"method": "eth_getBalance",
"parameters": ["0x210eeAC07542F815ebB6FD6689637D8cA2689392", "latest"],
"returnValueTest": {"comparator": "==", "value": 0},
},
{
"conditionType": "rpc",
"conditionType": ConditionType.RPC.value,
"chain": 137,
"method": "eth_getBalance",
"parameters": ["0x210eeAC07542F815ebB6FD6689637D8cA2689392", "latest"],
"returnValueTest": {"comparator": "==", "value": 0},
},
{
"conditionType": "rpc",
"conditionType": ConditionType.RPC.value,
"chain": 5,
"method": "eth_getBalance",
"parameters": ["0x210eeAC07542F815ebB6FD6689637D8cA2689392", "latest"],
"returnValueTest": {"comparator": ">", "value": 1},
},
{
"conditionType": "rpc",
"conditionType": ConditionType.RPC.value,
"chain": 80001,
"method": "eth_getBalance",
"parameters": ["0x210eeAC07542F815ebB6FD6689637D8cA2689392", "latest"],

View File

@ -7,7 +7,7 @@ from nucypher.blockchain.eth.domains import TACoDomain
from nucypher.blockchain.eth.registry import ContractRegistry
from nucypher.blockchain.eth.signers import InMemorySigner
from nucypher.characters.lawful import Bob, Enrico
from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
from nucypher.utilities.logging import GlobalLoggerSettings
from nucypher.utilities.profiler import Profiler
from tests.constants import DEFAULT_TEST_ENRICO_PRIVATE_KEY
@ -57,7 +57,7 @@ print(
eth_balance_condition = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "rpc",
"conditionType": ConditionType.RPC.value,
"chain": 80001,
"method": "eth_getBalance",
"parameters": ["0x210eeAC07542F815ebB6FD6689637D8cA2689392", "latest"],

View File

@ -0,0 +1 @@
Added ``not`` operator functionality to ``CompoundAccessControlCondition`` so that the logical inverse of conditions can be evaluated.

View File

@ -1,8 +1,8 @@
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Type, Union
from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address
from marshmallow import fields, post_load, validate, validates_schema
from marshmallow import ValidationError, fields, post_load, validate, validates_schema
from web3 import HTTPProvider, Web3
from web3.contract.contract import ContractFunction
from web3.middleware import geth_poa_middleware
@ -151,6 +151,11 @@ class RPCCondition(AccessControlCondition):
_validate_chain(chain=chain)
# internal
if condition_type != self.CONDITION_TYPE:
raise InvalidCondition(
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
)
self.condition_type = condition_type
self.name = name
self.chain = chain
@ -265,6 +270,18 @@ class RPCCondition(AccessControlCondition):
class ContractCondition(RPCCondition):
CONDITION_TYPE = ConditionType.CONTRACT.value
@classmethod
def _validate_contract_type_or_function_abi(
cls,
standard_contract_type: str,
function_abi: Dict,
exception_class: Union[Type[ValidationError], Type[InvalidCondition]],
):
if not (bool(standard_contract_type) ^ bool(function_abi)):
raise exception_class(
f"Provide 'standardContractType' or 'functionAbi'; got ({standard_contract_type}, {function_abi})."
)
class Schema(RPCCondition.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.CONTRACT.value), required=True
@ -281,10 +298,9 @@ class ContractCondition(RPCCondition):
def check_standard_contract_type_or_function_abi(self, data, **kwargs):
standard_contract_type = data.get("standard_contract_type")
function_abi = data.get("function_abi")
if not (bool(standard_contract_type) ^ bool(function_abi)):
raise InvalidCondition(
f"Provide 'standardContractType' or 'functionAbi'; got ({standard_contract_type}, {function_abi})."
)
ContractCondition._validate_contract_type_or_function_abi(
standard_contract_type, function_abi, ValidationError
)
def __init__(
self,
@ -296,13 +312,12 @@ class ContractCondition(RPCCondition):
**kwargs,
):
# internal
super().__init__(*args, **kwargs)
super().__init__(condition_type=condition_type, *args, **kwargs)
self.w3 = Web3() # used to instantiate contract function without a provider
if not (bool(standard_contract_type) ^ bool(function_abi)):
raise InvalidCondition(
f"Provide 'standard_contract_type' or 'function_abi'; got ({standard_contract_type}, {function_abi})."
)
ContractCondition._validate_contract_type_or_function_abi(
standard_contract_type, function_abi, InvalidCondition
)
# preprocessing
contract_address = to_checksum_address(contract_address)

View File

@ -3,10 +3,6 @@ class InvalidConditionLingo(Exception):
"""Invalid lingo grammar."""
class InvalidLogicalOperator(InvalidConditionLingo):
"""Invalid definition of logical lingo operator."""
# Connectivity
class NoConnectionToChain(RuntimeError):
"""Raised when a node does not have an associated provider for a chain."""

View File

@ -3,7 +3,7 @@ import base64
import operator as pyoperator
from enum import Enum
from hashlib import md5
from typing import Any, List, Optional, Tuple, Type
from typing import Any, List, Optional, Tuple, Type, Union
from marshmallow import (
Schema,
@ -13,14 +13,15 @@ from marshmallow import (
pre_load,
validate,
validates,
validates_schema,
)
from packaging.version import parse as parse_version
from nucypher.policy.conditions.base import AccessControlCondition, _Serializable
from nucypher.policy.conditions.context import is_context_variable
from nucypher.policy.conditions.exceptions import (
InvalidCondition,
InvalidConditionLingo,
InvalidLogicalOperator,
ReturnValueEvaluationError,
)
from nucypher.policy.conditions.types import ConditionDict, Lingo
@ -76,24 +77,52 @@ class ConditionType(Enum):
class CompoundAccessControlCondition(AccessControlCondition):
AND_OPERATOR = "and"
OR_OPERATOR = "or"
OPERATORS = (AND_OPERATOR, OR_OPERATOR)
NOT_OPERATOR = "not"
OPERATORS = (AND_OPERATOR, OR_OPERATOR, NOT_OPERATOR)
CONDITION_TYPE = ConditionType.COMPOUND.value
@classmethod
def _validate_operator_and_operands(
cls,
operator: str,
operands: List,
exception_class: Union[Type[ValidationError], Type[InvalidCondition]],
):
if operator not in cls.OPERATORS:
raise exception_class(f"{operator} is not a valid operator")
if operator == cls.NOT_OPERATOR:
if len(operands) != 1:
raise exception_class(
f"Only 1 operand permitted for '{operator}' compound condition"
)
elif len(operands) < 2:
raise exception_class(
f"Minimum of 2 operand needed for '{operator}' compound condition"
)
class Schema(CamelCaseSchema):
SKIP_VALUES = (None,)
condition_type = fields.Str(
validate=validate.Equal(ConditionType.COMPOUND.value), required=True
)
name = fields.Str(required=False)
operator = fields.Str(required=True, validate=validate.OneOf(["and", "or"]))
operands = fields.List(
_ConditionField, required=True, validate=validate.Length(min=2)
)
operator = fields.Str(required=True)
operands = fields.List(_ConditionField, required=True)
# maintain field declaration ordering
class Meta:
ordered = True
@validates_schema
def validate_operator_and_operands(self, data, **kwargs):
operator = data["operator"]
operands = data["operands"]
CompoundAccessControlCondition._validate_operator_and_operands(
operator, operands, ValidationError
)
@post_load
def make(self, data, **kwargs):
return CompoundAccessControlCondition(**data)
@ -111,9 +140,14 @@ class CompoundAccessControlCondition(AccessControlCondition):
"operands": [CONDITION*]
}
"""
if condition_type != self.CONDITION_TYPE:
raise InvalidCondition(
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
)
self._validate_operator_and_operands(operator, operands, InvalidCondition)
self.condition_type = condition_type
if operator not in self.OPERATORS:
raise InvalidLogicalOperator(f"{operator} is not a valid operator")
self.operator = operator
self.operands = operands
self.condition_type = condition_type
@ -134,12 +168,16 @@ class CompoundAccessControlCondition(AccessControlCondition):
# short-circuit check
if overall_result is False:
break
else:
# or operator
elif self.operator == self.OR_OPERATOR:
overall_result = overall_result or current_result
# short-circuit check
if overall_result is True:
break
elif self.operator == self.NOT_OPERATOR:
return not current_result, current_value
else:
# should never get here; raise just in case
raise ValueError(f"Invalid operator {self.operator}")
return overall_result, values
@ -154,6 +192,11 @@ class AndCompoundCondition(CompoundAccessControlCondition):
super().__init__(operator=self.AND_OPERATOR, operands=operands)
class NotCompoundCondition(CompoundAccessControlCondition):
def __init__(self, operand: AccessControlCondition):
super().__init__(operator=self.NOT_OPERATOR, operands=[operand])
class ReturnValueTest:
class InvalidExpression(ValueError):
pass

View File

@ -37,7 +37,7 @@ class TimeCondition(RPCCondition):
return_value_test: ReturnValueTest,
chain: int,
method: str = METHOD,
condition_type: str = ConditionType.TIME.value,
condition_type: str = CONDITION_TYPE,
name: Optional[str] = None,
):
if method != self.METHOD:

View File

@ -16,7 +16,7 @@ from nucypher.blockchain.eth.registry import ContractRegistry
from nucypher.blockchain.eth.signers import InMemorySigner, Signer
from nucypher.characters.lawful import Bob, Enrico
from nucypher.crypto.powers import TransactingPower
from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
from nucypher.utilities.emitters import StdoutEmitter
from nucypher.utilities.logging import GlobalLoggerSettings
from tests.constants import DEFAULT_TEST_ENRICO_PRIVATE_KEY, GLOBAL_ALLOW_LIST
@ -303,7 +303,7 @@ def nucypher_dkg(
CONDITIONS = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime",
"chain": application_agent.blockchain.client.chain_id,

View File

@ -5,7 +5,7 @@ from twisted.internet.threads import deferToThread
from nucypher.blockchain.eth.signers.software import Web3Signer
from nucypher.blockchain.eth.trackers.dkg import EventScannerTask
from nucypher.characters.lawful import Enrico, Ursula
from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
from tests.constants import TESTERCHAIN_CHAIN_ID
# constants
@ -21,7 +21,7 @@ PLAINTEXT = "peace at dawn"
CONDITIONS = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,

View File

@ -24,7 +24,12 @@ from nucypher.policy.conditions.exceptions import (
RequiredContextVariable,
RPCExecutionFailed,
)
from nucypher.policy.conditions.lingo import ConditionLingo, ReturnValueTest
from nucypher.policy.conditions.lingo import (
ConditionLingo,
ConditionType,
NotCompoundCondition,
ReturnValueTest,
)
from tests.constants import (
TEST_ETH_PROVIDER_URI,
TEST_POLYGON_PROVIDER_URI,
@ -492,16 +497,47 @@ def test_time_condition_evaluation(testerchain, time_condition, condition_provid
assert condition_result is True
def test_simple_compound_conditions_evaluation(
def test_not_time_condition_evaluation(
testerchain, time_condition, condition_providers
):
not_condition = NotCompoundCondition(operand=time_condition)
condition_result, call_value = time_condition.verify(providers=condition_providers)
assert condition_result is True
not_condition_result, not_call_value = not_condition.verify(
providers=condition_providers
)
assert not_condition_result is (not condition_result)
assert not_call_value == call_value
def test_simple_compound_conditions_lingo_evaluation(
testerchain, compound_blocktime_lingo, condition_providers
):
# TODO Improve internals of evaluation here (natural vs recursive approach)
conditions = json.dumps(compound_blocktime_lingo)
lingo = ConditionLingo.from_json(conditions)
result = lingo.eval(providers=condition_providers)
assert result is True
def test_not_of_simple_compound_conditions_lingo_evaluation(
testerchain, compound_blocktime_lingo, condition_providers
):
# evaluate base condition
access_condition_lingo = ConditionLingo.from_dict(compound_blocktime_lingo)
result = access_condition_lingo.eval(providers=condition_providers)
assert result is True
# evaluate not of base condition
not_access_condition = NotCompoundCondition(
operand=access_condition_lingo.condition
)
not_access_condition_lingo = ConditionLingo(condition=not_access_condition)
not_result = not_access_condition_lingo.eval(providers=condition_providers)
assert not_result is False
assert not_result is (not result)
@mock.patch(
"nucypher.policy.conditions.evm.get_context_value",
side_effect=_dont_validate_user_address,
@ -517,23 +553,46 @@ def test_onchain_conditions_lingo_evaluation(
assert result is True
@mock.patch(
"nucypher.policy.conditions.evm.get_context_value",
side_effect=_dont_validate_user_address,
)
def test_not_of_onchain_conditions_lingo_evaluation(
get_context_value_mock,
testerchain,
compound_lingo,
condition_providers,
):
context = {USER_ADDRESS_CONTEXT: {"address": testerchain.etherbase_account}}
result = compound_lingo.eval(providers=condition_providers, **context)
assert result is True
not_condition = NotCompoundCondition(operand=compound_lingo.condition)
not_access_condition_lingo = ConditionLingo(condition=not_condition)
not_result = not_access_condition_lingo.eval(
providers=condition_providers, **context
)
assert not_result is False
assert not_result is (not result)
def test_single_retrieve_with_onchain_conditions(enacted_policy, bob, ursulas):
bob.remember_node(ursulas[0])
bob.start_learning_loop()
conditions = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "compound",
"conditionType": ConditionType.COMPOUND.value,
"operator": "and",
"operands": [
{
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
},
{
"conditionType": "rpc",
"conditionType": ConditionType.RPC.value,
"chain": TESTERCHAIN_CHAIN_ID,
"method": "eth_getBalance",
"parameters": [bob.checksum_address, "latest"],

View File

@ -34,7 +34,11 @@ from nucypher.crypto.keystore import Keystore
from nucypher.network.nodes import TEACHER_NODES
from nucypher.policy.conditions.context import USER_ADDRESS_CONTEXT
from nucypher.policy.conditions.evm import RPCCondition
from nucypher.policy.conditions.lingo import ConditionLingo, ReturnValueTest
from nucypher.policy.conditions.lingo import (
ConditionLingo,
ConditionType,
ReturnValueTest,
)
from nucypher.policy.conditions.time import TimeCondition
from nucypher.policy.payment import SubscriptionManagerPayment
from nucypher.utilities.emitters import StdoutEmitter
@ -592,17 +596,17 @@ def compound_blocktime_lingo():
return {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "compound",
"conditionType": ConditionType.COMPOUND.value,
"operator": "and",
"operands": [
{
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
},
{
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"returnValueTest": {
"value": "99999999999999999",
"comparator": "<",
@ -611,7 +615,7 @@ def compound_blocktime_lingo():
"chain": TESTERCHAIN_CHAIN_ID,
},
{
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,

View File

@ -12,7 +12,7 @@ from web3.datastructures import AttributeDict
from nucypher.blockchain.eth.agents import CoordinatorAgent
from nucypher.blockchain.eth.signers.software import Web3Signer
from nucypher.characters.lawful import Enrico, Ursula
from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
from tests.constants import TESTERCHAIN_CHAIN_ID
from tests.mock.coordinator import MockCoordinatorAgent
from tests.mock.interfaces import MockBlockchain
@ -22,7 +22,7 @@ PLAINTEXT = "peace at dawn"
CONDITIONS = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,

View File

@ -4,8 +4,17 @@ import pytest
from nucypher_core import Conditions
from nucypher.characters.lawful import Ursula
from nucypher.policy.conditions.exceptions import *
from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.conditions.exceptions import (
ConditionEvaluationFailed,
ContextVariableVerificationFailed,
InvalidCondition,
InvalidConditionLingo,
InvalidContextVariableData,
NoConnectionToChain,
RequiredContextVariable,
ReturnValueEvaluationError,
)
from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
from tests.constants import TESTERCHAIN_CHAIN_ID
from tests.utils.middleware import MockRestMiddleware
@ -28,17 +37,17 @@ def test_single_retrieve_with_truthy_conditions(enacted_policy, bob, ursulas, mo
conditions = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "compound",
"conditionType": ConditionType.COMPOUND.value,
"operator": "and",
"operands": [
{
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
},
{
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 99999999999999999, "comparator": "<"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
@ -72,7 +81,7 @@ def test_single_retrieve_with_falsy_conditions(enacted_policy, bob, ursulas, moc
{
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
@ -134,7 +143,7 @@ def test_middleware_handling_of_failed_condition_responses(
{
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,

View File

@ -8,7 +8,7 @@ from nucypher.characters.chaotic import (
ThisBobAlwaysFails,
)
from nucypher.characters.lawful import Ursula
from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
from tests.constants import (
MOCK_ETH_PROVIDER_URI,
MOCK_REGISTRY_FILEPATH,
@ -30,7 +30,7 @@ def _attempt_decryption(BobClass, plaintext, testerchain):
definitely_false_condition = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"chain": TESTERCHAIN_CHAIN_ID,
"method": "blocktime",
"returnValueTest": {"comparator": "<", "value": 0},

View File

@ -22,7 +22,7 @@ from nucypher.crypto.powers import (
TLSHostingPower,
)
from nucypher.network.server import ProxyRESTServer
from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
from nucypher.policy.payment import SubscriptionManagerPayment
from nucypher.utilities.networking import LOOPBACK_ADDRESS
from tests.constants import (
@ -182,7 +182,7 @@ def test_ritualist(temp_dir_path, testerchain, dkg_public_key):
CONDITIONS = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,

View File

@ -42,6 +42,7 @@ def registry_2(domain_2, test_registry):
return ContractRegistry(MockRegistrySource(domain=domain_2))
@pytest.mark.skip("inconsistent behaviour on CI - see #3289")
def test_learner_learns_about_domains_separately(
lonely_ursula_maker, domain_1, domain_2, registry_1, registry_2, caplog
):

View File

@ -3,7 +3,11 @@ from unittest.mock import Mock
import pytest
from nucypher.policy.conditions.base import AccessControlCondition
from nucypher.policy.conditions.lingo import AndCompoundCondition, OrCompoundCondition
from nucypher.policy.conditions.lingo import (
AndCompoundCondition,
NotCompoundCondition,
OrCompoundCondition,
)
@pytest.fixture(scope="function")
@ -203,3 +207,160 @@ def test_nested_compound_condition(mock_conditions):
assert result is False
assert len(value) == 2, "or_condition and condition_4"
assert value == [[1, [2, 3]], 4]
def test_not_compound_condition(mock_conditions):
condition_1, condition_2, condition_3, condition_4 = mock_conditions
not_condition = NotCompoundCondition(operand=condition_1)
#
# simple `not`
#
condition_1.verify.return_value = (True, 1)
result, value = not_condition.verify()
assert result is False
assert value == 1
condition_1.verify.return_value = (False, 2)
result, value = not_condition.verify()
assert result is True
assert value == 2
#
# `not` of `or` condition
#
# only True
condition_1.verify.return_value = (True, 1)
condition_2.verify.return_value = (True, 2)
condition_3.verify.return_value = (True, 3)
or_condition = OrCompoundCondition(
operands=[
condition_1,
condition_2,
condition_3,
]
)
not_condition = NotCompoundCondition(operand=or_condition)
or_result, or_value = or_condition.verify()
result, value = not_condition.verify()
assert result is False
assert result is (not or_result)
assert value == or_value
# only False
condition_1.verify.return_value = (False, 1)
condition_2.verify.return_value = (False, 2)
condition_3.verify.return_value = (False, 3)
or_result, or_value = or_condition.verify()
result, value = not_condition.verify()
assert result is True
assert result is (not or_result)
assert value == or_value
# mixture of True/False
condition_1.verify.return_value = (False, 1)
condition_2.verify.return_value = (False, 2)
condition_3.verify.return_value = (True, 3)
or_result, or_value = or_condition.verify()
result, value = not_condition.verify()
assert result is False
assert result is (not or_result)
assert value == or_value
#
# `not` of `and` condition
#
# only True
condition_1.verify.return_value = (True, 1)
condition_2.verify.return_value = (True, 2)
condition_3.verify.return_value = (True, 3)
and_condition = AndCompoundCondition(
operands=[
condition_1,
condition_2,
condition_3,
]
)
not_condition = NotCompoundCondition(operand=and_condition)
and_result, and_value = and_condition.verify()
result, value = not_condition.verify()
assert result is False
assert result is (not and_result)
assert value == and_value
# only False
condition_1.verify.return_value = (False, 1)
condition_2.verify.return_value = (False, 2)
condition_3.verify.return_value = (False, 3)
and_result, and_value = and_condition.verify()
result, value = not_condition.verify()
assert result is True
assert result is (not and_result)
assert value == and_value
# mixture of True/False
condition_1.verify.return_value = (False, 1)
condition_2.verify.return_value = (True, 2)
condition_3.verify.return_value = (False, 3)
and_result, and_value = and_condition.verify()
result, value = not_condition.verify()
assert result is True
assert result is (not and_result)
assert value == and_value
#
# Complex nested `or` and `and` (reused nested compound condition in previous test)
#
nested_compound_condition = AndCompoundCondition(
operands=[
OrCompoundCondition(
operands=[
condition_1,
AndCompoundCondition(
operands=[
condition_2,
condition_3,
]
),
]
),
condition_4,
]
)
not_condition = NotCompoundCondition(operand=nested_compound_condition)
# reset all conditions to True
condition_1.verify.return_value = (True, 1)
condition_2.verify.return_value = (True, 2)
condition_3.verify.return_value = (True, 3)
condition_4.verify.return_value = (True, 4)
nested_result, nested_value = nested_compound_condition.verify()
result, value = not_condition.verify()
assert result is False
assert result is (not nested_result)
assert value == nested_value
# set condition_1 to False so nested and-condition must be evaluated
condition_1.verify.return_value = (False, 1)
nested_result, nested_value = nested_compound_condition.verify()
result, value = not_condition.verify()
assert result is False
assert result is (not nested_result)
assert value == nested_value
# set condition_4 to False so that overall result flips to False, so `not` is now True
condition_4.verify.return_value = (False, 4)
nested_result, nested_value = nested_compound_condition.verify()
result, value = not_condition.verify()
assert result is True
assert result is (not nested_result)
assert value == nested_value

View File

@ -6,7 +6,9 @@ from packaging.version import parse as parse_version
import nucypher
from nucypher.blockchain.eth.constants import NULL_ADDRESS
from nucypher.policy.conditions.context import USER_ADDRESS_CONTEXT
from nucypher.policy.conditions.exceptions import InvalidConditionLingo
from nucypher.policy.conditions.exceptions import (
InvalidConditionLingo,
)
from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
from tests.constants import TESTERCHAIN_CHAIN_ID
@ -14,7 +16,7 @@ from tests.constants import TESTERCHAIN_CHAIN_ID
@pytest.fixture(scope="module")
def lingo_with_condition():
return {
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
@ -22,24 +24,75 @@ def lingo_with_condition():
@pytest.fixture(scope="module")
def lingo_with_compound_condition():
def lingo_with_compound_conditions(get_random_checksum_address):
return {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "compound",
"conditionType": ConditionType.COMPOUND.value,
"operator": "and",
"operands": [
{
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
},
{
"conditionType": "time",
"returnValueTest": {"value": 99999999999999999, "comparator": "<"},
"method": "blocktime",
"conditionType": ConditionType.CONTRACT.value,
"chain": TESTERCHAIN_CHAIN_ID,
"method": "isPolicyActive",
"parameters": [":hrac"],
"returnValueTest": {"comparator": "==", "value": True},
"contractAddress": get_random_checksum_address(),
"functionAbi": {
"type": "function",
"name": "isPolicyActive",
"stateMutability": "view",
"inputs": [
{
"name": "_policyID",
"type": "bytes16",
"internalType": "bytes16",
}
],
"outputs": [
{"name": "", "type": "bool", "internalType": "bool"}
],
},
},
{
"conditionType": ConditionType.COMPOUND.value,
"operator": "or",
"operands": [
{
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
},
{
"conditionType": ConditionType.RPC.value,
"chain": TESTERCHAIN_CHAIN_ID,
"method": "eth_getBalance",
"parameters": [get_random_checksum_address(), "latest"],
"returnValueTest": {
"comparator": ">=",
"value": "10000000000000",
},
},
],
},
{
"conditionType": ConditionType.COMPOUND.value,
"operator": "not",
"operands": [
{
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
},
],
},
],
},
@ -64,15 +117,34 @@ def test_invalid_condition():
}
)
# < 2 operands for and condition
invalid_operator_position_lingo = {
# invalid operator
invalid_operator = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "compound",
"conditionType": ConditionType.COMPOUND.value,
"operator": "xTrue",
"operands": [
{
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
},
],
},
}
with pytest.raises(InvalidConditionLingo):
ConditionLingo.from_dict(invalid_operator)
# < 2 operands for and condition
invalid_and_operands_lingo = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": ConditionType.COMPOUND.value,
"operator": "and",
"operands": [
{
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
@ -81,7 +153,51 @@ def test_invalid_condition():
},
}
with pytest.raises(InvalidConditionLingo):
ConditionLingo.from_dict(invalid_operator_position_lingo)
ConditionLingo.from_dict(invalid_and_operands_lingo)
# < 2 operands for or condition
invalid_or_operands_lingo = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": ConditionType.COMPOUND.value,
"operator": "or",
"operands": [
{
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
}
],
},
}
with pytest.raises(InvalidConditionLingo):
ConditionLingo.from_dict(invalid_or_operands_lingo)
# > 1 operand for `not` condition
invalid_not_operands_lingo = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": ConditionType.COMPOUND.value,
"operator": "not",
"operands": [
{
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
},
{
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 99999999999999999, "comparator": "<"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
},
],
},
}
with pytest.raises(InvalidConditionLingo):
ConditionLingo.from_dict(invalid_not_operands_lingo)
@pytest.mark.parametrize("case", ["major", "minor", "patch"])
@ -102,7 +218,7 @@ def test_invalid_condition_version(case):
lingo_dict = {
"version": newer_version_string,
"condition": {
"conditionType": "time",
"conditionType": ConditionType.TIME.value,
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
@ -118,24 +234,24 @@ def test_invalid_condition_version(case):
_ = ConditionLingo.from_dict(lingo_dict)
def test_condition_lingo_to_from_dict(lingo_with_compound_condition):
clingo = ConditionLingo.from_dict(lingo_with_compound_condition)
def test_condition_lingo_to_from_dict(lingo_with_compound_conditions):
clingo = ConditionLingo.from_dict(lingo_with_compound_conditions)
clingo_dict = clingo.to_dict()
assert clingo_dict == lingo_with_compound_condition
assert clingo_dict == lingo_with_compound_conditions
def test_condition_lingo_to_from_json(lingo_with_compound_condition):
def test_condition_lingo_to_from_json(lingo_with_compound_conditions):
# A bit more convoluted because fields aren't
# necessarily ordered - so string comparison is tricky
clingo_from_dict = ConditionLingo.from_dict(lingo_with_compound_condition)
clingo_from_dict = ConditionLingo.from_dict(lingo_with_compound_conditions)
lingo_json = clingo_from_dict.to_json()
clingo_from_json = ConditionLingo.from_json(lingo_json)
assert clingo_from_json.to_dict() == lingo_with_compound_condition
assert clingo_from_json.to_dict() == lingo_with_compound_conditions
def test_condition_lingo_repr(lingo_with_compound_condition):
clingo = ConditionLingo.from_dict(lingo_with_compound_condition)
def test_condition_lingo_repr(lingo_with_compound_conditions):
clingo = ConditionLingo.from_dict(lingo_with_compound_conditions)
clingo_string = f"{clingo}"
assert f"{clingo.__class__.__name__}" in clingo_string
assert f"version={ConditionLingo.VERSION}" in clingo_string

View File

@ -2,12 +2,26 @@ import pytest
from nucypher.policy.conditions.evm import ContractCondition, RPCCondition
from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.lingo import ReturnValueTest
from nucypher.policy.conditions.lingo import (
CompoundAccessControlCondition,
ConditionType,
ReturnValueTest,
)
from nucypher.policy.conditions.time import TimeCondition
from tests.constants import TESTERCHAIN_CHAIN_ID
def test_invalid_time_condition():
# invalid condition type
with pytest.raises(InvalidCondition, match=ConditionType.TIME.value):
_ = TimeCondition(
condition_type=ConditionType.COMPOUND.value,
return_value_test=ReturnValueTest(">", 0),
chain=TESTERCHAIN_CHAIN_ID,
method=TimeCondition.METHOD,
)
# invalid method
with pytest.raises(InvalidCondition):
_ = TimeCondition(
return_value_test=ReturnValueTest('>', 0),
@ -17,6 +31,16 @@ def test_invalid_time_condition():
def test_invalid_rpc_condition():
# invalid condition type
with pytest.raises(InvalidCondition, match=ConditionType.RPC.value):
_ = RPCCondition(
condition_type=ConditionType.TIME.value,
method="eth_getBalance",
chain=TESTERCHAIN_CHAIN_ID,
return_value_test=ReturnValueTest("==", 0),
parameters=["0xaDD9D957170dF6F33982001E4c22eCCdd5539118"],
)
# no eth_ prefix for method
with pytest.raises(InvalidCondition):
_ = RPCCondition(
@ -46,30 +70,42 @@ def test_invalid_rpc_condition():
def test_invalid_contract_condition():
# invalid condition type
with pytest.raises(InvalidCondition, match=ConditionType.CONTRACT.value):
_ = ContractCondition(
condition_type=ConditionType.RPC.value,
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
method="balanceOf",
chain=TESTERCHAIN_CHAIN_ID,
standard_contract_type="ERC20",
return_value_test=ReturnValueTest("!=", 0),
parameters=["0xaDD9D957170dF6F33982001E4c22eCCdd5539118"],
)
# no abi or contract type
with pytest.raises(InvalidCondition):
_ = ContractCondition(
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
return_value_test=ReturnValueTest('!=', 0),
parameters=[
':hrac',
]
)
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
return_value_test=ReturnValueTest("!=", 0),
parameters=[
":hrac",
],
)
# invalid contract type
# invalid standard contract type
with pytest.raises(InvalidCondition):
_ = ContractCondition(
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
standard_contract_type="ERC90210", # Beverly Hills contract type :)
return_value_test=ReturnValueTest('!=', 0),
parameters=[
':hrac',
]
)
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
standard_contract_type="ERC90210", # Beverly Hills contract type :)
return_value_test=ReturnValueTest("!=", 0),
parameters=[
":hrac",
],
)
# invalid ABI
with pytest.raises(InvalidCondition):
@ -112,6 +148,57 @@ def test_invalid_contract_condition():
)
def test_invalid_compound_condition(time_condition, rpc_condition):
for operator in CompoundAccessControlCondition.OPERATORS:
if operator == CompoundAccessControlCondition.NOT_OPERATOR:
operands = [time_condition]
else:
operands = [time_condition, rpc_condition]
# invalid condition type
with pytest.raises(InvalidCondition, match=ConditionType.COMPOUND.value):
_ = CompoundAccessControlCondition(
condition_type=ConditionType.TIME.value,
operator=operator,
operands=operands,
)
# invalid operator - 1 operand
with pytest.raises(InvalidCondition):
_ = CompoundAccessControlCondition(operator="5True", operands=[time_condition])
# invalid operator - 2 operands
with pytest.raises(InvalidCondition):
_ = CompoundAccessControlCondition(
operator="5True", operands=[time_condition, rpc_condition]
)
# no operands
with pytest.raises(InvalidCondition):
_ = CompoundAccessControlCondition(operator=operator, operands=[])
# > 1 operand for not operator
with pytest.raises(InvalidCondition):
_ = CompoundAccessControlCondition(
operator=CompoundAccessControlCondition.NOT_OPERATOR,
operands=[time_condition, rpc_condition],
)
# < 2 operands for or operator
with pytest.raises(InvalidCondition):
_ = CompoundAccessControlCondition(
operator=CompoundAccessControlCondition.OR_OPERATOR,
operands=[time_condition],
)
# < 2 operands for and operator
with pytest.raises(InvalidCondition):
_ = CompoundAccessControlCondition(
operator=CompoundAccessControlCondition.AND_OPERATOR,
operands=[rpc_condition],
)
def test_time_condition_schema_validation(time_condition):
condition_dict = time_condition.to_dict()
@ -229,3 +316,47 @@ def test_contract_condition_schema_validation():
condition_dict = contract_condition.to_dict()
del condition_dict["returnValueTest"]
ContractCondition.validate(condition_dict)
@pytest.mark.parametrize("operator", CompoundAccessControlCondition.OPERATORS)
def test_compound_condition_schema_validation(operator, time_condition, rpc_condition):
if operator == CompoundAccessControlCondition.NOT_OPERATOR:
operands = [time_condition]
else:
operands = [time_condition, rpc_condition]
compound_condition = CompoundAccessControlCondition(
operator=operator, operands=operands
)
compound_condition_dict = compound_condition.to_dict()
# no issues here
CompoundAccessControlCondition.validate(compound_condition_dict)
# no issues with optional name
compound_condition_dict["name"] = "my_contract_condition"
CompoundAccessControlCondition.validate(compound_condition_dict)
with pytest.raises(InvalidCondition):
# incorrect condition type
compound_condition_dict = compound_condition.to_dict()
compound_condition_dict["condition_type"] = ConditionType.RPC.value
CompoundAccessControlCondition.validate(compound_condition_dict)
with pytest.raises(InvalidCondition):
# invalid operator
compound_condition_dict = compound_condition.to_dict()
compound_condition_dict["operator"] = "5True"
CompoundAccessControlCondition.validate(compound_condition_dict)
with pytest.raises(InvalidCondition):
# no operator
compound_condition_dict = compound_condition.to_dict()
del compound_condition_dict["operator"]
CompoundAccessControlCondition.validate(compound_condition_dict)
with pytest.raises(InvalidCondition):
# no operands
compound_condition_dict = compound_condition.to_dict()
del compound_condition_dict["operands"]
CompoundAccessControlCondition.validate(compound_condition_dict)