mirror of https://github.com/nucypher/nucypher.git
Merge pull request #3201 from nucypher/condition-type
Add `condition_type` to condition schemaspull/3204/head
commit
f6ee932dac
|
@ -0,0 +1 @@
|
|||
Add a mandatory condition_type field to condition schemas
|
|
@ -2,7 +2,7 @@ from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
|
|||
|
||||
from eth_typing import ChecksumAddress
|
||||
from eth_utils import to_checksum_address
|
||||
from marshmallow import fields, post_load, validates_schema
|
||||
from marshmallow import fields, post_load, validate, validates_schema
|
||||
from web3 import HTTPProvider, Web3
|
||||
from web3.contract.contract import ContractFunction
|
||||
from web3.middleware import geth_poa_middleware
|
||||
|
@ -18,7 +18,7 @@ from nucypher.policy.conditions.exceptions import (
|
|||
NoConnectionToChain,
|
||||
RPCExecutionFailed,
|
||||
)
|
||||
from nucypher.policy.conditions.lingo import ReturnValueTest
|
||||
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
|
||||
from nucypher.policy.conditions.utils import CamelCaseSchema, camel_case_to_snake
|
||||
|
||||
# TODO: Move this to a more appropriate location,
|
||||
|
@ -108,17 +108,19 @@ def _validate_chain(chain: int) -> None:
|
|||
|
||||
class RPCCondition(AccessControlCondition):
|
||||
ETH_PREFIX = "eth_"
|
||||
|
||||
ALLOWED_METHODS = (
|
||||
# RPC
|
||||
"eth_getBalance",
|
||||
) # TODO other allowed methods (tDEC #64)
|
||||
|
||||
LOG = logging.Logger(__name__)
|
||||
CONDITION_TYPE = ConditionType.RPC.value
|
||||
|
||||
class Schema(CamelCaseSchema):
|
||||
SKIP_VALUES = (None,)
|
||||
name = fields.Str(required=False)
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.RPC.value), required=True
|
||||
)
|
||||
chain = fields.Int(required=True)
|
||||
method = fields.Str(required=True)
|
||||
parameters = fields.List(fields.Field, attribute="parameters", required=False)
|
||||
|
@ -139,6 +141,7 @@ class RPCCondition(AccessControlCondition):
|
|||
chain: int,
|
||||
method: str,
|
||||
return_value_test: ReturnValueTest,
|
||||
condition_type: str = CONDITION_TYPE,
|
||||
name: Optional[str] = None,
|
||||
parameters: Optional[List[Any]] = None,
|
||||
):
|
||||
|
@ -147,6 +150,7 @@ class RPCCondition(AccessControlCondition):
|
|||
_validate_chain(chain=chain)
|
||||
|
||||
# internal
|
||||
self.condition_type = condition_type
|
||||
self.name = name
|
||||
self.chain = chain
|
||||
self.provider: Optional[BaseProvider] = None # set in _configure_provider
|
||||
|
@ -258,7 +262,12 @@ class RPCCondition(AccessControlCondition):
|
|||
|
||||
|
||||
class ContractCondition(RPCCondition):
|
||||
CONDITION_TYPE = ConditionType.CONTRACT.value
|
||||
|
||||
class Schema(RPCCondition.Schema):
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.CONTRACT.value), required=True
|
||||
)
|
||||
standard_contract_type = fields.Str(required=False)
|
||||
contract_address = fields.Str(required=True)
|
||||
function_abi = fields.Dict(required=False)
|
||||
|
@ -279,6 +288,7 @@ class ContractCondition(RPCCondition):
|
|||
def __init__(
|
||||
self,
|
||||
contract_address: ChecksumAddress,
|
||||
condition_type: str = CONDITION_TYPE,
|
||||
standard_contract_type: Optional[str] = None,
|
||||
function_abi: Optional[ABIFunction] = None,
|
||||
*args,
|
||||
|
@ -298,6 +308,7 @@ class ContractCondition(RPCCondition):
|
|||
|
||||
# spec
|
||||
self.contract_address = contract_address
|
||||
self.condition_type = condition_type
|
||||
self.standard_contract_type = standard_contract_type
|
||||
self.function_abi = function_abi
|
||||
self.contract_function = self._get_unbound_contract_function()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import ast
|
||||
import base64
|
||||
import operator as pyoperator
|
||||
from enum import Enum
|
||||
from hashlib import md5
|
||||
from typing import Any, List, Optional, Tuple, Type
|
||||
|
||||
|
@ -57,13 +58,32 @@ class _ConditionField(fields.Dict):
|
|||
# }
|
||||
|
||||
|
||||
class ConditionType(Enum):
|
||||
"""
|
||||
Defines the types of conditions that can be evaluated.
|
||||
"""
|
||||
|
||||
TIME = "time"
|
||||
CONTRACT = "contract"
|
||||
RPC = "rpc"
|
||||
COMPOUND = "compound"
|
||||
|
||||
@classmethod
|
||||
def values(cls) -> List[str]:
|
||||
return [condition.value for condition in cls]
|
||||
|
||||
|
||||
class CompoundAccessControlCondition(AccessControlCondition):
|
||||
AND_OPERATOR = "and"
|
||||
OR_OPERATOR = "or"
|
||||
OPERATORS = (AND_OPERATOR, OR_OPERATOR)
|
||||
CONDITION_TYPE = ConditionType.COMPOUND.value
|
||||
|
||||
class Schema(CamelCaseSchema):
|
||||
SKIP_VALUES = (None,)
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.COMPOUND.value), required=True
|
||||
)
|
||||
name = fields.Str(required=False)
|
||||
operator = fields.Str(required=True, validate=validate.OneOf(["and", "or"]))
|
||||
operands = fields.List(
|
||||
|
@ -82,6 +102,7 @@ class CompoundAccessControlCondition(AccessControlCondition):
|
|||
self,
|
||||
operator: str,
|
||||
operands: List[AccessControlCondition],
|
||||
condition_type: str = CONDITION_TYPE,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
|
@ -90,10 +111,12 @@ class CompoundAccessControlCondition(AccessControlCondition):
|
|||
"operands": [CONDITION*]
|
||||
}
|
||||
"""
|
||||
self.condition_type = condition_type
|
||||
if operator not in self.OPERATORS:
|
||||
raise InvalidLogicalOperator(f"{operator} is not a valid operator")
|
||||
self.operator = operator
|
||||
self.operands = operands
|
||||
self.condition_type = condition_type
|
||||
self.name = name
|
||||
self.id = md5(bytes(self)).hexdigest()[:6]
|
||||
|
||||
|
@ -322,24 +345,19 @@ class ConditionLingo(_Serializable):
|
|||
from nucypher.policy.conditions.time import TimeCondition
|
||||
|
||||
# version logical adjustments can be made here as required
|
||||
# Inspect
|
||||
method = condition.get("method")
|
||||
operator = condition.get("operator")
|
||||
contract = condition.get("contractAddress")
|
||||
|
||||
# Resolve
|
||||
if method:
|
||||
if method == TimeCondition.METHOD:
|
||||
return TimeCondition
|
||||
elif contract:
|
||||
return ContractCondition
|
||||
elif method in RPCCondition.ALLOWED_METHODS:
|
||||
return RPCCondition
|
||||
elif operator:
|
||||
return CompoundAccessControlCondition
|
||||
condition_type = condition.get("conditionType")
|
||||
for condition in (
|
||||
TimeCondition,
|
||||
ContractCondition,
|
||||
RPCCondition,
|
||||
CompoundAccessControlCondition,
|
||||
):
|
||||
if condition.CONDITION_TYPE == condition_type:
|
||||
return condition
|
||||
|
||||
raise InvalidConditionLingo(
|
||||
f"Cannot resolve condition lingo type from data {condition}"
|
||||
f"Cannot resolve condition lingo with condition type {condition_type}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -1,18 +1,22 @@
|
|||
from typing import Any, List, Optional
|
||||
|
||||
from marshmallow import fields, post_load
|
||||
from marshmallow import fields, post_load, validate
|
||||
|
||||
from nucypher.policy.conditions.evm import RPCCondition
|
||||
from nucypher.policy.conditions.exceptions import InvalidCondition
|
||||
from nucypher.policy.conditions.lingo import ReturnValueTest
|
||||
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
|
||||
from nucypher.policy.conditions.utils import CamelCaseSchema
|
||||
|
||||
|
||||
class TimeCondition(RPCCondition):
|
||||
METHOD = "blocktime"
|
||||
CONDITION_TYPE = ConditionType.TIME.value
|
||||
|
||||
class Schema(CamelCaseSchema):
|
||||
SKIP_VALUES = (None,)
|
||||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.TIME.value), required=True
|
||||
)
|
||||
name = fields.Str(required=False)
|
||||
chain = fields.Int(required=True)
|
||||
method = fields.Str(dump_default="blocktime", required=True)
|
||||
|
@ -33,6 +37,7 @@ class TimeCondition(RPCCondition):
|
|||
return_value_test: ReturnValueTest,
|
||||
chain: int,
|
||||
method: str = METHOD,
|
||||
condition_type: str = ConditionType.TIME.value,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
if method != self.METHOD:
|
||||
|
@ -40,7 +45,11 @@ class TimeCondition(RPCCondition):
|
|||
f"{self.__class__.__name__} must be instantiated with the {self.METHOD} method."
|
||||
)
|
||||
super().__init__(
|
||||
chain=chain, method=method, return_value_test=return_value_test, name=name
|
||||
chain=chain,
|
||||
method=method,
|
||||
return_value_test=return_value_test,
|
||||
name=name,
|
||||
condition_type=condition_type,
|
||||
)
|
||||
|
||||
def validate_method(self, method):
|
||||
|
|
|
@ -39,6 +39,7 @@ class _AccessControlCondition(TypedDict):
|
|||
|
||||
|
||||
class RPCConditionDict(_AccessControlCondition):
|
||||
conditionType: str
|
||||
chain: int
|
||||
method: str
|
||||
parameters: NotRequired[List[Any]]
|
||||
|
@ -63,6 +64,7 @@ class ContractConditionDict(RPCConditionDict):
|
|||
#
|
||||
#
|
||||
class CompoundConditionDict(TypedDict):
|
||||
conditionType: str
|
||||
operator: Literal["and", "or"]
|
||||
operands: List["Lingo"]
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ PLAINTEXT = "peace at dawn"
|
|||
CONDITIONS = {
|
||||
"version": ConditionLingo.VERSION,
|
||||
"condition": {
|
||||
"conditionType": "time",
|
||||
"returnValueTest": {"value": "0", "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
|
|
|
@ -521,14 +521,17 @@ def test_single_retrieve_with_onchain_conditions(enacted_policy, bob, ursulas):
|
|||
conditions = {
|
||||
"version": ConditionLingo.VERSION,
|
||||
"condition": {
|
||||
"conditionType": "compound",
|
||||
"operator": "and",
|
||||
"operands": [
|
||||
{
|
||||
"conditionType": "time",
|
||||
"returnValueTest": {"value": "0", "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
},
|
||||
{
|
||||
"conditionType": "rpc",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
"method": "eth_getBalance",
|
||||
"parameters": [bob.checksum_address, "latest"],
|
||||
|
|
|
@ -1,8 +1,13 @@
|
|||
{
|
||||
"customABIMultipleParameters" : {
|
||||
"customABIMultipleParameters": {
|
||||
"conditionType": "contract",
|
||||
"contractAddress": "0x01B67b1194C75264d06F808A921228a95C765dd7",
|
||||
"method": "isSubscribedToToken",
|
||||
"parameters": [":userAddress", "subscriptionCode", 4],
|
||||
"parameters": [
|
||||
":userAddress",
|
||||
"subscriptionCode",
|
||||
4
|
||||
],
|
||||
"functionAbi": {
|
||||
"inputs": [
|
||||
{
|
||||
|
@ -72,6 +77,7 @@
|
|||
}
|
||||
},
|
||||
"TStaking": {
|
||||
"conditionType": "contract",
|
||||
"contractAddress": "0x01B67b1194C75264d06F808A921228a95C765dd7",
|
||||
"chain": 1,
|
||||
"method": "stakes",
|
||||
|
@ -114,6 +120,7 @@
|
|||
}
|
||||
},
|
||||
"SubscriptionManagerPayment": {
|
||||
"conditionType": "contract",
|
||||
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
|
||||
"chain": 137,
|
||||
"method": "isValidPolicy",
|
||||
|
@ -126,6 +133,7 @@
|
|||
}
|
||||
},
|
||||
"ERC1155_balance": {
|
||||
"conditionType": "contract",
|
||||
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
|
||||
"standardContractType": "ERC1155",
|
||||
"chain": 1,
|
||||
|
@ -140,20 +148,32 @@
|
|||
}
|
||||
},
|
||||
"ERC1155_balance_batch": {
|
||||
"conditionType": "contract",
|
||||
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
|
||||
"standardContractType": "ERC1155",
|
||||
"chain": 1,
|
||||
"method": "balanceOfBatch",
|
||||
"parameters": [
|
||||
[":userAddress",":userAddress",":userAddress",":userAddress"],
|
||||
[1,2,10003,10004]
|
||||
[
|
||||
":userAddress",
|
||||
":userAddress",
|
||||
":userAddress",
|
||||
":userAddress"
|
||||
],
|
||||
[
|
||||
1,
|
||||
2,
|
||||
10003,
|
||||
10004
|
||||
]
|
||||
],
|
||||
"returnValueTest": {
|
||||
"comparator": ">",
|
||||
"value": 0
|
||||
}
|
||||
},
|
||||
"ERC721_ownership": {
|
||||
"ERC721_ownership": {
|
||||
"conditionType": "contract",
|
||||
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
|
||||
"standardContractType": "ERC721",
|
||||
"chain": 1,
|
||||
|
@ -166,7 +186,8 @@
|
|||
"value": ":userAddress"
|
||||
}
|
||||
},
|
||||
"ERC721_balance": {
|
||||
"ERC721_balance": {
|
||||
"conditionType": "contract",
|
||||
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
|
||||
"standardContractType": "ERC721",
|
||||
"chain": 1,
|
||||
|
@ -179,7 +200,8 @@
|
|||
"value": 0
|
||||
}
|
||||
},
|
||||
"ERC20_balance": {
|
||||
"ERC20_balance": {
|
||||
"conditionType": "contract",
|
||||
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
|
||||
"standardContractType": "ERC20",
|
||||
"chain": 1,
|
||||
|
@ -192,7 +214,8 @@
|
|||
"value": 0
|
||||
}
|
||||
},
|
||||
"ETH_balance": {
|
||||
"ETH_balance": {
|
||||
"conditionType": "contract",
|
||||
"contractAddress": "",
|
||||
"standardContractType": "",
|
||||
"chain": 1,
|
||||
|
@ -206,7 +229,8 @@
|
|||
"value": 10000000000000
|
||||
}
|
||||
},
|
||||
"specific_wallet_address": {
|
||||
"specific_wallet_address": {
|
||||
"conditionType": "contract",
|
||||
"contractAddress": "",
|
||||
"standardContractType": "",
|
||||
"chain": 1,
|
||||
|
@ -219,12 +243,15 @@
|
|||
"value": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118"
|
||||
}
|
||||
},
|
||||
"timestamp": {
|
||||
"timestamp": {
|
||||
"conditionType": "contract",
|
||||
"contractAddress": "",
|
||||
"standardContractType": "timestamp",
|
||||
"chain": 1,
|
||||
"method": "eth_getBlockByNumber",
|
||||
"parameters": ["latest"],
|
||||
"parameters": [
|
||||
"latest"
|
||||
],
|
||||
"returnValueTest": {
|
||||
"comparator": ">=",
|
||||
"value": 1234567890
|
||||
|
|
|
@ -612,14 +612,17 @@ def compound_blocktime_lingo():
|
|||
return {
|
||||
"version": ConditionLingo.VERSION,
|
||||
"condition": {
|
||||
"conditionType": "compound",
|
||||
"operator": "and",
|
||||
"operands": [
|
||||
{
|
||||
"conditionType": "time",
|
||||
"returnValueTest": {"value": "0", "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
},
|
||||
{
|
||||
"conditionType": "time",
|
||||
"returnValueTest": {
|
||||
"value": "99999999999999999",
|
||||
"comparator": "<",
|
||||
|
@ -628,6 +631,7 @@ def compound_blocktime_lingo():
|
|||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
},
|
||||
{
|
||||
"conditionType": "time",
|
||||
"returnValueTest": {"value": "0", "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
|
|
|
@ -22,6 +22,7 @@ PLAINTEXT = "peace at dawn"
|
|||
CONDITIONS = {
|
||||
"version": ConditionLingo.VERSION,
|
||||
"condition": {
|
||||
"conditionType": "time",
|
||||
"returnValueTest": {"value": "0", "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
|
|
|
@ -28,14 +28,17 @@ def test_single_retrieve_with_truthy_conditions(enacted_policy, bob, ursulas, mo
|
|||
conditions = {
|
||||
"version": ConditionLingo.VERSION,
|
||||
"condition": {
|
||||
"conditionType": "compound",
|
||||
"operator": "and",
|
||||
"operands": [
|
||||
{
|
||||
"conditionType": "time",
|
||||
"returnValueTest": {"value": 0, "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
},
|
||||
{
|
||||
"conditionType": "time",
|
||||
"returnValueTest": {"value": 99999999999999999, "comparator": "<"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
|
@ -69,6 +72,7 @@ def test_single_retrieve_with_falsy_conditions(enacted_policy, bob, ursulas, moc
|
|||
{
|
||||
"version": ConditionLingo.VERSION,
|
||||
"condition": {
|
||||
"conditionType": "time",
|
||||
"returnValueTest": {"value": 0, "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
|
@ -130,6 +134,7 @@ def test_middleware_handling_of_failed_condition_responses(
|
|||
{
|
||||
"version": ConditionLingo.VERSION,
|
||||
"condition": {
|
||||
"conditionType": "time",
|
||||
"returnValueTest": {"value": 0, "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
|
|
|
@ -29,6 +29,7 @@ def _attempt_decryption(BobClass, plaintext):
|
|||
definitely_false_condition = {
|
||||
"version": ConditionLingo.VERSION,
|
||||
"condition": {
|
||||
"conditionType": "time",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
"method": "blocktime",
|
||||
"returnValueTest": {"comparator": "<", "value": 0},
|
||||
|
|
|
@ -182,6 +182,7 @@ def test_ritualist(temp_dir_path, testerchain, dkg_public_key):
|
|||
CONDITIONS = {
|
||||
"version": ConditionLingo.VERSION,
|
||||
"condition": {
|
||||
"conditionType": "time",
|
||||
"returnValueTest": {"value": "0", "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
|
|
|
@ -7,23 +7,36 @@ import nucypher
|
|||
from nucypher.blockchain.eth.constants import NULL_ADDRESS
|
||||
from nucypher.policy.conditions.context import USER_ADDRESS_CONTEXT
|
||||
from nucypher.policy.conditions.exceptions import InvalidConditionLingo
|
||||
from nucypher.policy.conditions.lingo import ConditionLingo
|
||||
from nucypher.policy.conditions.lingo import ConditionLingo, ConditionType
|
||||
from tests.constants import TESTERCHAIN_CHAIN_ID
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def lingo():
|
||||
@pytest.fixture(scope="module")
|
||||
def lingo_with_condition():
|
||||
return {
|
||||
"conditionType": "time",
|
||||
"returnValueTest": {"value": 0, "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def lingo_with_compound_condition():
|
||||
return {
|
||||
"version": ConditionLingo.VERSION,
|
||||
"condition": {
|
||||
"conditionType": "compound",
|
||||
"operator": "and",
|
||||
"operands": [
|
||||
{
|
||||
"conditionType": "time",
|
||||
"returnValueTest": {"value": 0, "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
},
|
||||
{
|
||||
"conditionType": "time",
|
||||
"returnValueTest": {"value": 99999999999999999, "comparator": "<"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
|
@ -55,9 +68,11 @@ def test_invalid_condition():
|
|||
invalid_operator_position_lingo = {
|
||||
"version": ConditionLingo.VERSION,
|
||||
"condition": {
|
||||
"conditionType": "compound",
|
||||
"operator": "and",
|
||||
"operands": [
|
||||
{
|
||||
"conditionType": "time",
|
||||
"returnValueTest": {"value": 0, "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
|
@ -87,6 +102,7 @@ def test_invalid_condition_version(case):
|
|||
lingo_dict = {
|
||||
"version": newer_version_string,
|
||||
"condition": {
|
||||
"conditionType": "time",
|
||||
"returnValueTest": {"value": 0, "comparator": ">"},
|
||||
"method": "blocktime",
|
||||
"chain": TESTERCHAIN_CHAIN_ID,
|
||||
|
@ -102,24 +118,24 @@ def test_invalid_condition_version(case):
|
|||
_ = ConditionLingo.from_dict(lingo_dict)
|
||||
|
||||
|
||||
def test_condition_lingo_to_from_dict(lingo):
|
||||
clingo = ConditionLingo.from_dict(lingo)
|
||||
def test_condition_lingo_to_from_dict(lingo_with_compound_condition):
|
||||
clingo = ConditionLingo.from_dict(lingo_with_compound_condition)
|
||||
clingo_dict = clingo.to_dict()
|
||||
assert clingo_dict == lingo
|
||||
assert clingo_dict == lingo_with_compound_condition
|
||||
|
||||
|
||||
def test_condition_lingo_to_from_json(lingo):
|
||||
def test_condition_lingo_to_from_json(lingo_with_compound_condition):
|
||||
# A bit more convoluted because fields aren't
|
||||
# necessarily ordered - so string comparison is tricky
|
||||
clingo_from_dict = ConditionLingo.from_dict(lingo)
|
||||
clingo_from_dict = ConditionLingo.from_dict(lingo_with_compound_condition)
|
||||
lingo_json = clingo_from_dict.to_json()
|
||||
|
||||
clingo_from_json = ConditionLingo.from_json(lingo_json)
|
||||
assert clingo_from_json.to_dict() == lingo
|
||||
assert clingo_from_json.to_dict() == lingo_with_compound_condition
|
||||
|
||||
|
||||
def test_condition_lingo_repr(lingo):
|
||||
clingo = ConditionLingo.from_dict(lingo)
|
||||
def test_condition_lingo_repr(lingo_with_compound_condition):
|
||||
clingo = ConditionLingo.from_dict(lingo_with_compound_condition)
|
||||
clingo_string = f"{clingo}"
|
||||
assert f"{clingo.__class__.__name__}" in clingo_string
|
||||
assert f"version={ConditionLingo.VERSION}" in clingo_string
|
||||
|
@ -144,3 +160,16 @@ def test_lingo_parameter_int_type_preservation(custom_abi_with_multiple_paramete
|
|||
clingo = ConditionLingo.from_json(clingo_json)
|
||||
conditions = clingo.to_dict()
|
||||
assert conditions["condition"]["parameters"][2] == 4
|
||||
|
||||
|
||||
def test_lingo_resolves_condition_type(lingo_with_condition):
|
||||
for condition_type in ConditionType.values():
|
||||
lingo_with_condition["conditionType"] = condition_type
|
||||
ConditionLingo.resolve_condition_class(lingo_with_condition)
|
||||
|
||||
|
||||
def test_lingo_rejects_invalid_condition_type(lingo_with_condition):
|
||||
for condition_type in ["invalid", "", None]:
|
||||
lingo_with_condition["conditionType"] = condition_type
|
||||
with pytest.raises(InvalidConditionLingo):
|
||||
ConditionLingo.resolve_condition_class(lingo_with_condition)
|
||||
|
|
Loading…
Reference in New Issue