Merge pull request #3201 from nucypher/condition-type

Add `condition_type` to condition schemas
pull/3204/head
LunarBytes 2023-08-24 11:44:01 +02:00 committed by GitHub
commit f6ee932dac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 157 additions and 44 deletions

View File

@ -0,0 +1 @@
Add a mandatory condition_type field to condition schemas

View File

@ -2,7 +2,7 @@ from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
from eth_typing import ChecksumAddress from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address from eth_utils import to_checksum_address
from marshmallow import fields, post_load, validates_schema from marshmallow import fields, post_load, validate, validates_schema
from web3 import HTTPProvider, Web3 from web3 import HTTPProvider, Web3
from web3.contract.contract import ContractFunction from web3.contract.contract import ContractFunction
from web3.middleware import geth_poa_middleware from web3.middleware import geth_poa_middleware
@ -18,7 +18,7 @@ from nucypher.policy.conditions.exceptions import (
NoConnectionToChain, NoConnectionToChain,
RPCExecutionFailed, RPCExecutionFailed,
) )
from nucypher.policy.conditions.lingo import ReturnValueTest from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
from nucypher.policy.conditions.utils import CamelCaseSchema, camel_case_to_snake from nucypher.policy.conditions.utils import CamelCaseSchema, camel_case_to_snake
# TODO: Move this to a more appropriate location, # TODO: Move this to a more appropriate location,
@ -108,17 +108,19 @@ def _validate_chain(chain: int) -> None:
class RPCCondition(AccessControlCondition): class RPCCondition(AccessControlCondition):
ETH_PREFIX = "eth_" ETH_PREFIX = "eth_"
ALLOWED_METHODS = ( ALLOWED_METHODS = (
# RPC # RPC
"eth_getBalance", "eth_getBalance",
) # TODO other allowed methods (tDEC #64) ) # TODO other allowed methods (tDEC #64)
LOG = logging.Logger(__name__) LOG = logging.Logger(__name__)
CONDITION_TYPE = ConditionType.RPC.value
class Schema(CamelCaseSchema): class Schema(CamelCaseSchema):
SKIP_VALUES = (None,) SKIP_VALUES = (None,)
name = fields.Str(required=False) name = fields.Str(required=False)
condition_type = fields.Str(
validate=validate.Equal(ConditionType.RPC.value), required=True
)
chain = fields.Int(required=True) chain = fields.Int(required=True)
method = fields.Str(required=True) method = fields.Str(required=True)
parameters = fields.List(fields.Field, attribute="parameters", required=False) parameters = fields.List(fields.Field, attribute="parameters", required=False)
@ -139,6 +141,7 @@ class RPCCondition(AccessControlCondition):
chain: int, chain: int,
method: str, method: str,
return_value_test: ReturnValueTest, return_value_test: ReturnValueTest,
condition_type: str = CONDITION_TYPE,
name: Optional[str] = None, name: Optional[str] = None,
parameters: Optional[List[Any]] = None, parameters: Optional[List[Any]] = None,
): ):
@ -147,6 +150,7 @@ class RPCCondition(AccessControlCondition):
_validate_chain(chain=chain) _validate_chain(chain=chain)
# internal # internal
self.condition_type = condition_type
self.name = name self.name = name
self.chain = chain self.chain = chain
self.provider: Optional[BaseProvider] = None # set in _configure_provider self.provider: Optional[BaseProvider] = None # set in _configure_provider
@ -258,7 +262,12 @@ class RPCCondition(AccessControlCondition):
class ContractCondition(RPCCondition): class ContractCondition(RPCCondition):
CONDITION_TYPE = ConditionType.CONTRACT.value
class Schema(RPCCondition.Schema): class Schema(RPCCondition.Schema):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.CONTRACT.value), required=True
)
standard_contract_type = fields.Str(required=False) standard_contract_type = fields.Str(required=False)
contract_address = fields.Str(required=True) contract_address = fields.Str(required=True)
function_abi = fields.Dict(required=False) function_abi = fields.Dict(required=False)
@ -279,6 +288,7 @@ class ContractCondition(RPCCondition):
def __init__( def __init__(
self, self,
contract_address: ChecksumAddress, contract_address: ChecksumAddress,
condition_type: str = CONDITION_TYPE,
standard_contract_type: Optional[str] = None, standard_contract_type: Optional[str] = None,
function_abi: Optional[ABIFunction] = None, function_abi: Optional[ABIFunction] = None,
*args, *args,
@ -298,6 +308,7 @@ class ContractCondition(RPCCondition):
# spec # spec
self.contract_address = contract_address self.contract_address = contract_address
self.condition_type = condition_type
self.standard_contract_type = standard_contract_type self.standard_contract_type = standard_contract_type
self.function_abi = function_abi self.function_abi = function_abi
self.contract_function = self._get_unbound_contract_function() self.contract_function = self._get_unbound_contract_function()

View File

@ -1,6 +1,7 @@
import ast import ast
import base64 import base64
import operator as pyoperator import operator as pyoperator
from enum import Enum
from hashlib import md5 from hashlib import md5
from typing import Any, List, Optional, Tuple, Type from typing import Any, List, Optional, Tuple, Type
@ -57,13 +58,32 @@ class _ConditionField(fields.Dict):
# } # }
class ConditionType(Enum):
"""
Defines the types of conditions that can be evaluated.
"""
TIME = "time"
CONTRACT = "contract"
RPC = "rpc"
COMPOUND = "compound"
@classmethod
def values(cls) -> List[str]:
return [condition.value for condition in cls]
class CompoundAccessControlCondition(AccessControlCondition): class CompoundAccessControlCondition(AccessControlCondition):
AND_OPERATOR = "and" AND_OPERATOR = "and"
OR_OPERATOR = "or" OR_OPERATOR = "or"
OPERATORS = (AND_OPERATOR, OR_OPERATOR) OPERATORS = (AND_OPERATOR, OR_OPERATOR)
CONDITION_TYPE = ConditionType.COMPOUND.value
class Schema(CamelCaseSchema): class Schema(CamelCaseSchema):
SKIP_VALUES = (None,) SKIP_VALUES = (None,)
condition_type = fields.Str(
validate=validate.Equal(ConditionType.COMPOUND.value), required=True
)
name = fields.Str(required=False) name = fields.Str(required=False)
operator = fields.Str(required=True, validate=validate.OneOf(["and", "or"])) operator = fields.Str(required=True, validate=validate.OneOf(["and", "or"]))
operands = fields.List( operands = fields.List(
@ -82,6 +102,7 @@ class CompoundAccessControlCondition(AccessControlCondition):
self, self,
operator: str, operator: str,
operands: List[AccessControlCondition], operands: List[AccessControlCondition],
condition_type: str = CONDITION_TYPE,
name: Optional[str] = None, name: Optional[str] = None,
): ):
""" """
@ -90,10 +111,12 @@ class CompoundAccessControlCondition(AccessControlCondition):
"operands": [CONDITION*] "operands": [CONDITION*]
} }
""" """
self.condition_type = condition_type
if operator not in self.OPERATORS: if operator not in self.OPERATORS:
raise InvalidLogicalOperator(f"{operator} is not a valid operator") raise InvalidLogicalOperator(f"{operator} is not a valid operator")
self.operator = operator self.operator = operator
self.operands = operands self.operands = operands
self.condition_type = condition_type
self.name = name self.name = name
self.id = md5(bytes(self)).hexdigest()[:6] self.id = md5(bytes(self)).hexdigest()[:6]
@ -322,24 +345,19 @@ class ConditionLingo(_Serializable):
from nucypher.policy.conditions.time import TimeCondition from nucypher.policy.conditions.time import TimeCondition
# version logical adjustments can be made here as required # version logical adjustments can be made here as required
# Inspect
method = condition.get("method")
operator = condition.get("operator")
contract = condition.get("contractAddress")
# Resolve condition_type = condition.get("conditionType")
if method: for condition in (
if method == TimeCondition.METHOD: TimeCondition,
return TimeCondition ContractCondition,
elif contract: RPCCondition,
return ContractCondition CompoundAccessControlCondition,
elif method in RPCCondition.ALLOWED_METHODS: ):
return RPCCondition if condition.CONDITION_TYPE == condition_type:
elif operator: return condition
return CompoundAccessControlCondition
raise InvalidConditionLingo( raise InvalidConditionLingo(
f"Cannot resolve condition lingo type from data {condition}" f"Cannot resolve condition lingo with condition type {condition_type}"
) )
@classmethod @classmethod

View File

@ -1,18 +1,22 @@
from typing import Any, List, Optional from typing import Any, List, Optional
from marshmallow import fields, post_load from marshmallow import fields, post_load, validate
from nucypher.policy.conditions.evm import RPCCondition from nucypher.policy.conditions.evm import RPCCondition
from nucypher.policy.conditions.exceptions import InvalidCondition from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.lingo import ReturnValueTest from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
from nucypher.policy.conditions.utils import CamelCaseSchema from nucypher.policy.conditions.utils import CamelCaseSchema
class TimeCondition(RPCCondition): class TimeCondition(RPCCondition):
METHOD = "blocktime" METHOD = "blocktime"
CONDITION_TYPE = ConditionType.TIME.value
class Schema(CamelCaseSchema): class Schema(CamelCaseSchema):
SKIP_VALUES = (None,) SKIP_VALUES = (None,)
condition_type = fields.Str(
validate=validate.Equal(ConditionType.TIME.value), required=True
)
name = fields.Str(required=False) name = fields.Str(required=False)
chain = fields.Int(required=True) chain = fields.Int(required=True)
method = fields.Str(dump_default="blocktime", required=True) method = fields.Str(dump_default="blocktime", required=True)
@ -33,6 +37,7 @@ class TimeCondition(RPCCondition):
return_value_test: ReturnValueTest, return_value_test: ReturnValueTest,
chain: int, chain: int,
method: str = METHOD, method: str = METHOD,
condition_type: str = ConditionType.TIME.value,
name: Optional[str] = None, name: Optional[str] = None,
): ):
if method != self.METHOD: if method != self.METHOD:
@ -40,7 +45,11 @@ class TimeCondition(RPCCondition):
f"{self.__class__.__name__} must be instantiated with the {self.METHOD} method." f"{self.__class__.__name__} must be instantiated with the {self.METHOD} method."
) )
super().__init__( super().__init__(
chain=chain, method=method, return_value_test=return_value_test, name=name chain=chain,
method=method,
return_value_test=return_value_test,
name=name,
condition_type=condition_type,
) )
def validate_method(self, method): def validate_method(self, method):

View File

@ -39,6 +39,7 @@ class _AccessControlCondition(TypedDict):
class RPCConditionDict(_AccessControlCondition): class RPCConditionDict(_AccessControlCondition):
conditionType: str
chain: int chain: int
method: str method: str
parameters: NotRequired[List[Any]] parameters: NotRequired[List[Any]]
@ -63,6 +64,7 @@ class ContractConditionDict(RPCConditionDict):
# #
# #
class CompoundConditionDict(TypedDict): class CompoundConditionDict(TypedDict):
conditionType: str
operator: Literal["and", "or"] operator: Literal["and", "or"]
operands: List["Lingo"] operands: List["Lingo"]

View File

@ -21,6 +21,7 @@ PLAINTEXT = "peace at dawn"
CONDITIONS = { CONDITIONS = {
"version": ConditionLingo.VERSION, "version": ConditionLingo.VERSION,
"condition": { "condition": {
"conditionType": "time",
"returnValueTest": {"value": "0", "comparator": ">"}, "returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime", "method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,

View File

@ -521,14 +521,17 @@ def test_single_retrieve_with_onchain_conditions(enacted_policy, bob, ursulas):
conditions = { conditions = {
"version": ConditionLingo.VERSION, "version": ConditionLingo.VERSION,
"condition": { "condition": {
"conditionType": "compound",
"operator": "and", "operator": "and",
"operands": [ "operands": [
{ {
"conditionType": "time",
"returnValueTest": {"value": "0", "comparator": ">"}, "returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime", "method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,
}, },
{ {
"conditionType": "rpc",
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,
"method": "eth_getBalance", "method": "eth_getBalance",
"parameters": [bob.checksum_address, "latest"], "parameters": [bob.checksum_address, "latest"],

View File

@ -1,8 +1,13 @@
{ {
"customABIMultipleParameters" : { "customABIMultipleParameters": {
"conditionType": "contract",
"contractAddress": "0x01B67b1194C75264d06F808A921228a95C765dd7", "contractAddress": "0x01B67b1194C75264d06F808A921228a95C765dd7",
"method": "isSubscribedToToken", "method": "isSubscribedToToken",
"parameters": [":userAddress", "subscriptionCode", 4], "parameters": [
":userAddress",
"subscriptionCode",
4
],
"functionAbi": { "functionAbi": {
"inputs": [ "inputs": [
{ {
@ -72,6 +77,7 @@
} }
}, },
"TStaking": { "TStaking": {
"conditionType": "contract",
"contractAddress": "0x01B67b1194C75264d06F808A921228a95C765dd7", "contractAddress": "0x01B67b1194C75264d06F808A921228a95C765dd7",
"chain": 1, "chain": 1,
"method": "stakes", "method": "stakes",
@ -114,6 +120,7 @@
} }
}, },
"SubscriptionManagerPayment": { "SubscriptionManagerPayment": {
"conditionType": "contract",
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118", "contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
"chain": 137, "chain": 137,
"method": "isValidPolicy", "method": "isValidPolicy",
@ -126,6 +133,7 @@
} }
}, },
"ERC1155_balance": { "ERC1155_balance": {
"conditionType": "contract",
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118", "contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
"standardContractType": "ERC1155", "standardContractType": "ERC1155",
"chain": 1, "chain": 1,
@ -140,13 +148,24 @@
} }
}, },
"ERC1155_balance_batch": { "ERC1155_balance_batch": {
"conditionType": "contract",
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118", "contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
"standardContractType": "ERC1155", "standardContractType": "ERC1155",
"chain": 1, "chain": 1,
"method": "balanceOfBatch", "method": "balanceOfBatch",
"parameters": [ "parameters": [
[":userAddress",":userAddress",":userAddress",":userAddress"], [
[1,2,10003,10004] ":userAddress",
":userAddress",
":userAddress",
":userAddress"
],
[
1,
2,
10003,
10004
]
], ],
"returnValueTest": { "returnValueTest": {
"comparator": ">", "comparator": ">",
@ -154,6 +173,7 @@
} }
}, },
"ERC721_ownership": { "ERC721_ownership": {
"conditionType": "contract",
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118", "contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
"standardContractType": "ERC721", "standardContractType": "ERC721",
"chain": 1, "chain": 1,
@ -167,6 +187,7 @@
} }
}, },
"ERC721_balance": { "ERC721_balance": {
"conditionType": "contract",
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118", "contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
"standardContractType": "ERC721", "standardContractType": "ERC721",
"chain": 1, "chain": 1,
@ -180,6 +201,7 @@
} }
}, },
"ERC20_balance": { "ERC20_balance": {
"conditionType": "contract",
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118", "contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
"standardContractType": "ERC20", "standardContractType": "ERC20",
"chain": 1, "chain": 1,
@ -193,6 +215,7 @@
} }
}, },
"ETH_balance": { "ETH_balance": {
"conditionType": "contract",
"contractAddress": "", "contractAddress": "",
"standardContractType": "", "standardContractType": "",
"chain": 1, "chain": 1,
@ -207,6 +230,7 @@
} }
}, },
"specific_wallet_address": { "specific_wallet_address": {
"conditionType": "contract",
"contractAddress": "", "contractAddress": "",
"standardContractType": "", "standardContractType": "",
"chain": 1, "chain": 1,
@ -220,11 +244,14 @@
} }
}, },
"timestamp": { "timestamp": {
"conditionType": "contract",
"contractAddress": "", "contractAddress": "",
"standardContractType": "timestamp", "standardContractType": "timestamp",
"chain": 1, "chain": 1,
"method": "eth_getBlockByNumber", "method": "eth_getBlockByNumber",
"parameters": ["latest"], "parameters": [
"latest"
],
"returnValueTest": { "returnValueTest": {
"comparator": ">=", "comparator": ">=",
"value": 1234567890 "value": 1234567890

View File

@ -612,14 +612,17 @@ def compound_blocktime_lingo():
return { return {
"version": ConditionLingo.VERSION, "version": ConditionLingo.VERSION,
"condition": { "condition": {
"conditionType": "compound",
"operator": "and", "operator": "and",
"operands": [ "operands": [
{ {
"conditionType": "time",
"returnValueTest": {"value": "0", "comparator": ">"}, "returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime", "method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,
}, },
{ {
"conditionType": "time",
"returnValueTest": { "returnValueTest": {
"value": "99999999999999999", "value": "99999999999999999",
"comparator": "<", "comparator": "<",
@ -628,6 +631,7 @@ def compound_blocktime_lingo():
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,
}, },
{ {
"conditionType": "time",
"returnValueTest": {"value": "0", "comparator": ">"}, "returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime", "method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,

View File

@ -22,6 +22,7 @@ PLAINTEXT = "peace at dawn"
CONDITIONS = { CONDITIONS = {
"version": ConditionLingo.VERSION, "version": ConditionLingo.VERSION,
"condition": { "condition": {
"conditionType": "time",
"returnValueTest": {"value": "0", "comparator": ">"}, "returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime", "method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,

View File

@ -28,14 +28,17 @@ def test_single_retrieve_with_truthy_conditions(enacted_policy, bob, ursulas, mo
conditions = { conditions = {
"version": ConditionLingo.VERSION, "version": ConditionLingo.VERSION,
"condition": { "condition": {
"conditionType": "compound",
"operator": "and", "operator": "and",
"operands": [ "operands": [
{ {
"conditionType": "time",
"returnValueTest": {"value": 0, "comparator": ">"}, "returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime", "method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,
}, },
{ {
"conditionType": "time",
"returnValueTest": {"value": 99999999999999999, "comparator": "<"}, "returnValueTest": {"value": 99999999999999999, "comparator": "<"},
"method": "blocktime", "method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,
@ -69,6 +72,7 @@ def test_single_retrieve_with_falsy_conditions(enacted_policy, bob, ursulas, moc
{ {
"version": ConditionLingo.VERSION, "version": ConditionLingo.VERSION,
"condition": { "condition": {
"conditionType": "time",
"returnValueTest": {"value": 0, "comparator": ">"}, "returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime", "method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,
@ -130,6 +134,7 @@ def test_middleware_handling_of_failed_condition_responses(
{ {
"version": ConditionLingo.VERSION, "version": ConditionLingo.VERSION,
"condition": { "condition": {
"conditionType": "time",
"returnValueTest": {"value": 0, "comparator": ">"}, "returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime", "method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,

View File

@ -29,6 +29,7 @@ def _attempt_decryption(BobClass, plaintext):
definitely_false_condition = { definitely_false_condition = {
"version": ConditionLingo.VERSION, "version": ConditionLingo.VERSION,
"condition": { "condition": {
"conditionType": "time",
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,
"method": "blocktime", "method": "blocktime",
"returnValueTest": {"comparator": "<", "value": 0}, "returnValueTest": {"comparator": "<", "value": 0},

View File

@ -182,6 +182,7 @@ def test_ritualist(temp_dir_path, testerchain, dkg_public_key):
CONDITIONS = { CONDITIONS = {
"version": ConditionLingo.VERSION, "version": ConditionLingo.VERSION,
"condition": { "condition": {
"conditionType": "time",
"returnValueTest": {"value": "0", "comparator": ">"}, "returnValueTest": {"value": "0", "comparator": ">"},
"method": "blocktime", "method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,

View File

@ -7,23 +7,36 @@ import nucypher
from nucypher.blockchain.eth.constants import NULL_ADDRESS from nucypher.blockchain.eth.constants import NULL_ADDRESS
from nucypher.policy.conditions.context import USER_ADDRESS_CONTEXT from nucypher.policy.conditions.context import USER_ADDRESS_CONTEXT
from nucypher.policy.conditions.exceptions import InvalidConditionLingo from nucypher.policy.conditions.exceptions import InvalidConditionLingo
from nucypher.policy.conditions.lingo import ConditionLingo from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
from tests.constants import TESTERCHAIN_CHAIN_ID from tests.constants import TESTERCHAIN_CHAIN_ID
@pytest.fixture(scope='module') @pytest.fixture(scope="module")
def lingo(): def lingo_with_condition():
return {
"conditionType": "time",
"returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID,
}
@pytest.fixture(scope="module")
def lingo_with_compound_condition():
return { return {
"version": ConditionLingo.VERSION, "version": ConditionLingo.VERSION,
"condition": { "condition": {
"conditionType": "compound",
"operator": "and", "operator": "and",
"operands": [ "operands": [
{ {
"conditionType": "time",
"returnValueTest": {"value": 0, "comparator": ">"}, "returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime", "method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,
}, },
{ {
"conditionType": "time",
"returnValueTest": {"value": 99999999999999999, "comparator": "<"}, "returnValueTest": {"value": 99999999999999999, "comparator": "<"},
"method": "blocktime", "method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,
@ -55,9 +68,11 @@ def test_invalid_condition():
invalid_operator_position_lingo = { invalid_operator_position_lingo = {
"version": ConditionLingo.VERSION, "version": ConditionLingo.VERSION,
"condition": { "condition": {
"conditionType": "compound",
"operator": "and", "operator": "and",
"operands": [ "operands": [
{ {
"conditionType": "time",
"returnValueTest": {"value": 0, "comparator": ">"}, "returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime", "method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,
@ -87,6 +102,7 @@ def test_invalid_condition_version(case):
lingo_dict = { lingo_dict = {
"version": newer_version_string, "version": newer_version_string,
"condition": { "condition": {
"conditionType": "time",
"returnValueTest": {"value": 0, "comparator": ">"}, "returnValueTest": {"value": 0, "comparator": ">"},
"method": "blocktime", "method": "blocktime",
"chain": TESTERCHAIN_CHAIN_ID, "chain": TESTERCHAIN_CHAIN_ID,
@ -102,24 +118,24 @@ def test_invalid_condition_version(case):
_ = ConditionLingo.from_dict(lingo_dict) _ = ConditionLingo.from_dict(lingo_dict)
def test_condition_lingo_to_from_dict(lingo): def test_condition_lingo_to_from_dict(lingo_with_compound_condition):
clingo = ConditionLingo.from_dict(lingo) clingo = ConditionLingo.from_dict(lingo_with_compound_condition)
clingo_dict = clingo.to_dict() clingo_dict = clingo.to_dict()
assert clingo_dict == lingo assert clingo_dict == lingo_with_compound_condition
def test_condition_lingo_to_from_json(lingo): def test_condition_lingo_to_from_json(lingo_with_compound_condition):
# A bit more convoluted because fields aren't # A bit more convoluted because fields aren't
# necessarily ordered - so string comparison is tricky # necessarily ordered - so string comparison is tricky
clingo_from_dict = ConditionLingo.from_dict(lingo) clingo_from_dict = ConditionLingo.from_dict(lingo_with_compound_condition)
lingo_json = clingo_from_dict.to_json() lingo_json = clingo_from_dict.to_json()
clingo_from_json = ConditionLingo.from_json(lingo_json) clingo_from_json = ConditionLingo.from_json(lingo_json)
assert clingo_from_json.to_dict() == lingo assert clingo_from_json.to_dict() == lingo_with_compound_condition
def test_condition_lingo_repr(lingo): def test_condition_lingo_repr(lingo_with_compound_condition):
clingo = ConditionLingo.from_dict(lingo) clingo = ConditionLingo.from_dict(lingo_with_compound_condition)
clingo_string = f"{clingo}" clingo_string = f"{clingo}"
assert f"{clingo.__class__.__name__}" in clingo_string assert f"{clingo.__class__.__name__}" in clingo_string
assert f"version={ConditionLingo.VERSION}" in clingo_string assert f"version={ConditionLingo.VERSION}" in clingo_string
@ -144,3 +160,16 @@ def test_lingo_parameter_int_type_preservation(custom_abi_with_multiple_paramete
clingo = ConditionLingo.from_json(clingo_json) clingo = ConditionLingo.from_json(clingo_json)
conditions = clingo.to_dict() conditions = clingo.to_dict()
assert conditions["condition"]["parameters"][2] == 4 assert conditions["condition"]["parameters"][2] == 4
def test_lingo_resolves_condition_type(lingo_with_condition):
for condition_type in ConditionType.values():
lingo_with_condition["conditionType"] = condition_type
ConditionLingo.resolve_condition_class(lingo_with_condition)
def test_lingo_rejects_invalid_condition_type(lingo_with_condition):
for condition_type in ["invalid", "", None]:
lingo_with_condition["conditionType"] = condition_type
with pytest.raises(InvalidConditionLingo):
ConditionLingo.resolve_condition_class(lingo_with_condition)