Merge pull request #3201 from nucypher/condition-type

Add `condition_type` to condition schemas
pull/3204/head
LunarBytes 2023-08-24 11:44:01 +02:00 committed by GitHub
commit f6ee932dac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 157 additions and 44 deletions

View File

@ -0,0 +1 @@
Add a mandatory condition_type field to condition schemas

View File

@ -2,7 +2,7 @@ from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address
from marshmallow import fields, post_load, validates_schema
from marshmallow import fields, post_load, validate, validates_schema
from web3 import HTTPProvider, Web3
from web3.contract.contract import ContractFunction
from web3.middleware import geth_poa_middleware
@ -18,7 +18,7 @@ from nucypher.policy.conditions.exceptions import (
NoConnectionToChain,
RPCExecutionFailed,
)
from nucypher.policy.conditions.lingo import ReturnValueTest
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
from nucypher.policy.conditions.utils import CamelCaseSchema, camel_case_to_snake
# TODO: Move this to a more appropriate location,
@ -108,17 +108,19 @@ def _validate_chain(chain: int) -> None:
class RPCCondition(AccessControlCondition):
ETH_PREFIX = "eth_"
ALLOWED_METHODS = (
# RPC
"eth_getBalance",
) # TODO other allowed methods (tDEC #64)
LOG = logging.Logger(__name__)
CONDITION_TYPE = ConditionType.RPC.value
class Schema(CamelCaseSchema):
SKIP_VALUES = (None,)
name = fields.Str(required=False)
condition_type = fields.Str(
validate=validate.Equal(ConditionType.RPC.value), required=True
)
chain = fields.Int(required=True)
method = fields.Str(required=True)
parameters = fields.List(fields.Field, attribute="parameters", required=False)
@ -139,6 +141,7 @@ class RPCCondition(AccessControlCondition):
chain: int,
method: str,
return_value_test: ReturnValueTest,
condition_type: str = CONDITION_TYPE,
name: Optional[str] = None,
parameters: Optional[List[Any]] = None,
):
@ -147,6 +150,7 @@ class RPCCondition(AccessControlCondition):
_validate_chain(chain=chain)
# internal
self.condition_type = condition_type
self.name = name
self.chain = chain
self.provider: Optional[BaseProvider] = None # set in _configure_provider
@ -258,7 +262,12 @@ class RPCCondition(AccessControlCondition):
class ContractCondition(RPCCondition):
CONDITION_TYPE = ConditionType.CONTRACT.value
class Schema(RPCCondition.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.CONTRACT.value), required=True
)
standard_contract_type = fields.Str(required=False)
contract_address = fields.Str(required=True)
function_abi = fields.Dict(required=False)
@ -279,6 +288,7 @@ class ContractCondition(RPCCondition):
def __init__(
self,
contract_address: ChecksumAddress,
condition_type: str = CONDITION_TYPE,
standard_contract_type: Optional[str] = None,
function_abi: Optional[ABIFunction] = None,
*args,
@ -298,6 +308,7 @@ class ContractCondition(RPCCondition):
# spec
self.contract_address = contract_address
self.condition_type = condition_type
self.standard_contract_type = standard_contract_type
self.function_abi = function_abi
self.contract_function = self._get_unbound_contract_function()

View File

@ -1,6 +1,7 @@
import ast
import base64
import operator as pyoperator
from enum import Enum
from hashlib import md5
from typing import Any, List, Optional, Tuple, Type
@ -57,13 +58,32 @@ class _ConditionField(fields.Dict):
# }
class ConditionType(Enum):
"""
Defines the types of conditions that can be evaluated.
"""
TIME = "time"
CONTRACT = "contract"
RPC = "rpc"
COMPOUND = "compound"
@classmethod
def values(cls) -> List[str]:
return [condition.value for condition in cls]
class CompoundAccessControlCondition(AccessControlCondition):
AND_OPERATOR = "and"
OR_OPERATOR = "or"
OPERATORS = (AND_OPERATOR, OR_OPERATOR)
CONDITION_TYPE = ConditionType.COMPOUND.value
class Schema(CamelCaseSchema):
SKIP_VALUES = (None,)
condition_type = fields.Str(
validate=validate.Equal(ConditionType.COMPOUND.value), required=True
)
name = fields.Str(required=False)
operator = fields.Str(required=True, validate=validate.OneOf(["and", "or"]))
operands = fields.List(
@ -82,6 +102,7 @@ class CompoundAccessControlCondition(AccessControlCondition):
self,
operator: str,
operands: List[AccessControlCondition],
condition_type: str = CONDITION_TYPE,
name: Optional[str] = None,
):
"""
@ -90,10 +111,12 @@ class CompoundAccessControlCondition(AccessControlCondition):
"operands": [CONDITION*]
}
"""
self.condition_type = condition_type
if operator not in self.OPERATORS:
raise InvalidLogicalOperator(f"{operator} is not a valid operator")
self.operator = operator
self.operands = operands
self.condition_type = condition_type
self.name = name
self.id = md5(bytes(self)).hexdigest()[:6]
@ -322,24 +345,19 @@ class ConditionLingo(_Serializable):
from nucypher.policy.conditions.time import TimeCondition
# version logical adjustments can be made here as required
# Inspect
method = condition.get("method")
operator = condition.get("operator")
contract = condition.get("contractAddress")
# Resolve
if method:
if method == TimeCondition.METHOD:
return TimeCondition
elif contract:
return ContractCondition
elif method in RPCCondition.ALLOWED_METHODS:
return RPCCondition
elif operator:
return CompoundAccessControlCondition
condition_type = condition.get("conditionType")
for condition in (
TimeCondition,
ContractCondition,
RPCCondition,
CompoundAccessControlCondition,
):
if condition.CONDITION_TYPE == condition_type:
return condition
raise InvalidConditionLingo(
f"Cannot resolve condition lingo type from data {condition}"
f"Cannot resolve condition lingo with condition type {condition_type}"
)
@classmethod

View File

@ -1,18 +1,22 @@
from typing import Any, List, Optional
from marshmallow import fields, post_load
from marshmallow import fields, post_load, validate
from nucypher.policy.conditions.evm import RPCCondition
from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.lingo import ReturnValueTest
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
from nucypher.policy.conditions.utils import CamelCaseSchema
class TimeCondition(RPCCondition):
METHOD = "blocktime"
CONDITION_TYPE = ConditionType.TIME.value
class Schema(CamelCaseSchema):
SKIP_VALUES = (None,)
condition_type = fields.Str(
validate=validate.Equal(ConditionType.TIME.value), required=True
)
name = fields.Str(required=False)
chain = fields.Int(required=True)
method = fields.Str(dump_default="blocktime", required=True)
@ -33,6 +37,7 @@ class TimeCondition(RPCCondition):
return_value_test: ReturnValueTest,
chain: int,
method: str = METHOD,
condition_type: str = ConditionType.TIME.value,
name: Optional[str] = None,
):
if method != self.METHOD:
@ -40,7 +45,11 @@ class TimeCondition(RPCCondition):
f"{self.__class__.__name__} must be instantiated with the {self.METHOD} method."
)
super().__init__(
chain=chain, method=method, return_value_test=return_value_test, name=name
chain=chain,
method=method,
return_value_test=return_value_test,
name=name,
condition_type=condition_type,
)
def validate_method(self, method):

View File

@ -39,6 +39,7 @@ class _AccessControlCondition(TypedDict):
class RPCConditionDict(_AccessControlCondition):
conditionType: str
chain: int
method: str
parameters: NotRequired[List[Any]]
@ -63,6 +64,7 @@ class ContractConditionDict(RPCConditionDict):
#
#
class CompoundConditionDict(TypedDict):
conditionType: str
operator: Literal["and", "or"]
operands: List["Lingo"]

View File

@ -21,6 +21,7 @@ PLAINTEXT = "peace at dawn"
CONDITIONS = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "time",
"returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,

View File

@ -521,14 +521,17 @@ def test_single_retrieve_with_onchain_conditions(enacted_policy, bob, ursulas):
conditions = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "compound",
"operator": "and",
"operands": [
{
"conditionType": "time",
"returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
},
{
"conditionType": "rpc",
"chain": TESTERCHAIN_CHAIN_ID,
"method": "eth_getBalance",
"parameters": [bob.checksum_address, "latest"],

View File

@ -1,8 +1,13 @@
{
"customABIMultipleParameters" : {
"customABIMultipleParameters": {
"conditionType": "contract",
"contractAddress": "0x01B67b1194C75264d06F808A921228a95C765dd7",
"method": "isSubscribedToToken",
"parameters": [":userAddress", "subscriptionCode", 4],
"parameters": [
":userAddress",
"subscriptionCode",
4
],
"functionAbi": {
"inputs": [
{
@ -72,6 +77,7 @@
}
},
"TStaking": {
"conditionType": "contract",
"contractAddress": "0x01B67b1194C75264d06F808A921228a95C765dd7",
"chain": 1,
"method": "stakes",
@ -114,6 +120,7 @@
}
},
"SubscriptionManagerPayment": {
"conditionType": "contract",
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
"chain": 137,
"method": "isValidPolicy",
@ -126,6 +133,7 @@
}
},
"ERC1155_balance": {
"conditionType": "contract",
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
"standardContractType": "ERC1155",
"chain": 1,
@ -140,20 +148,32 @@
}
},
"ERC1155_balance_batch": {
"conditionType": "contract",
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
"standardContractType": "ERC1155",
"chain": 1,
"method": "balanceOfBatch",
"parameters": [
[":userAddress",":userAddress",":userAddress",":userAddress"],
[1,2,10003,10004]
[
":userAddress",
":userAddress",
":userAddress",
":userAddress"
],
[
1,
2,
10003,
10004
]
],
"returnValueTest": {
"comparator": ">",
"value": 0
}
},
"ERC721_ownership": {
"ERC721_ownership": {
"conditionType": "contract",
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
"standardContractType": "ERC721",
"chain": 1,
@ -166,7 +186,8 @@
"value": ":userAddress"
}
},
"ERC721_balance": {
"ERC721_balance": {
"conditionType": "contract",
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
"standardContractType": "ERC721",
"chain": 1,
@ -179,7 +200,8 @@
"value": 0
}
},
"ERC20_balance": {
"ERC20_balance": {
"conditionType": "contract",
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
"standardContractType": "ERC20",
"chain": 1,
@ -192,7 +214,8 @@
"value": 0
}
},
"ETH_balance": {
"ETH_balance": {
"conditionType": "contract",
"contractAddress": "",
"standardContractType": "",
"chain": 1,
@ -206,7 +229,8 @@
"value": 10000000000000
}
},
"specific_wallet_address": {
"specific_wallet_address": {
"conditionType": "contract",
"contractAddress": "",
"standardContractType": "",
"chain": 1,
@ -219,12 +243,15 @@
"value": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118"
}
},
"timestamp": {
"timestamp": {
"conditionType": "contract",
"contractAddress": "",
"standardContractType": "timestamp",
"chain": 1,
"method": "eth_getBlockByNumber",
"parameters": ["latest"],
"parameters": [
"latest"
],
"returnValueTest": {
"comparator": ">=",
"value": 1234567890

View File

@ -612,14 +612,17 @@ def compound_blocktime_lingo():
return {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "compound",
"operator": "and",
"operands": [
{
"conditionType": "time",
"returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
},
{
"conditionType": "time",
"returnValueTest": {
"value": "99999999999999999",
"comparator": "<",
@ -628,6 +631,7 @@ def compound_blocktime_lingo():
"chain": TESTERCHAIN_CHAIN_ID,
},
{
"conditionType": "time",
"returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,

View File

@ -22,6 +22,7 @@ PLAINTEXT = "peace at dawn"
CONDITIONS = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "time",
"returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,

View File

@ -28,14 +28,17 @@ def test_single_retrieve_with_truthy_conditions(enacted_policy, bob, ursulas, mo
conditions = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "compound",
"operator": "and",
"operands": [
{
"conditionType": "time",
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
},
{
"conditionType": "time",
"returnValueTest": {"value": 99999999999999999, "comparator": "<"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
@ -69,6 +72,7 @@ def test_single_retrieve_with_falsy_conditions(enacted_policy, bob, ursulas, moc
{
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "time",
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
@ -130,6 +134,7 @@ def test_middleware_handling_of_failed_condition_responses(
{
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "time",
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,

View File

@ -29,6 +29,7 @@ def _attempt_decryption(BobClass, plaintext):
definitely_false_condition = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "time",
"chain": TESTERCHAIN_CHAIN_ID,
"method": "blocktime",
"returnValueTest": {"comparator": "<", "value": 0},

View File

@ -182,6 +182,7 @@ def test_ritualist(temp_dir_path, testerchain, dkg_public_key):
CONDITIONS = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "time",
"returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,

View File

@ -7,23 +7,36 @@ import nucypher
from nucypher.blockchain.eth.constants import NULL_ADDRESS
from nucypher.policy.conditions.context import USER_ADDRESS_CONTEXT
from nucypher.policy.conditions.exceptions import InvalidConditionLingo
from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
from tests.constants import TESTERCHAIN_CHAIN_ID
@pytest.fixture(scope='module')
def lingo():
@pytest.fixture(scope="module")
def lingo_with_condition():
return {
"conditionType": "time",
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
}
@pytest.fixture(scope="module")
def lingo_with_compound_condition():
return {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "compound",
"operator": "and",
"operands": [
{
"conditionType": "time",
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
},
{
"conditionType": "time",
"returnValueTest": {"value": 99999999999999999, "comparator": "<"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
@ -55,9 +68,11 @@ def test_invalid_condition():
invalid_operator_position_lingo = {
"version": ConditionLingo.VERSION,
"condition": {
"conditionType": "compound",
"operator": "and",
"operands": [
{
"conditionType": "time",
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
@ -87,6 +102,7 @@ def test_invalid_condition_version(case):
lingo_dict = {
"version": newer_version_string,
"condition": {
"conditionType": "time",
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
@ -102,24 +118,24 @@ def test_invalid_condition_version(case):
_ = ConditionLingo.from_dict(lingo_dict)
def test_condition_lingo_to_from_dict(lingo):
clingo = ConditionLingo.from_dict(lingo)
def test_condition_lingo_to_from_dict(lingo_with_compound_condition):
clingo = ConditionLingo.from_dict(lingo_with_compound_condition)
clingo_dict = clingo.to_dict()
assert clingo_dict == lingo
assert clingo_dict == lingo_with_compound_condition
def test_condition_lingo_to_from_json(lingo):
def test_condition_lingo_to_from_json(lingo_with_compound_condition):
# A bit more convoluted because fields aren't
# necessarily ordered - so string comparison is tricky
clingo_from_dict = ConditionLingo.from_dict(lingo)
clingo_from_dict = ConditionLingo.from_dict(lingo_with_compound_condition)
lingo_json = clingo_from_dict.to_json()
clingo_from_json = ConditionLingo.from_json(lingo_json)
assert clingo_from_json.to_dict() == lingo
assert clingo_from_json.to_dict() == lingo_with_compound_condition
def test_condition_lingo_repr(lingo):
clingo = ConditionLingo.from_dict(lingo)
def test_condition_lingo_repr(lingo_with_compound_condition):
clingo = ConditionLingo.from_dict(lingo_with_compound_condition)
clingo_string = f"{clingo}"
assert f"{clingo.__class__.__name__}" in clingo_string
assert f"version={ConditionLingo.VERSION}" in clingo_string
@ -144,3 +160,16 @@ def test_lingo_parameter_int_type_preservation(custom_abi_with_multiple_paramete
clingo = ConditionLingo.from_json(clingo_json)
conditions = clingo.to_dict()
assert conditions["condition"]["parameters"][2] == 4
def test_lingo_resolves_condition_type(lingo_with_condition):
for condition_type in ConditionType.values():
lingo_with_condition["conditionType"] = condition_type
ConditionLingo.resolve_condition_class(lingo_with_condition)
def test_lingo_rejects_invalid_condition_type(lingo_with_condition):
for condition_type in ["invalid", "", None]:
lingo_with_condition["conditionType"] = condition_type
with pytest.raises(InvalidConditionLingo):
ConditionLingo.resolve_condition_class(lingo_with_condition)