mirror of https://github.com/nucypher/nucypher.git
Merge pull request #3002 from theref/key-value
Add custom abi conditions and key/value return valuespull/3013/head
commit
131a7d4ced
|
@ -0,0 +1 @@
|
|||
Allow a key to be specified for evaluating the return value
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, Optional, Tuple, Type, Union, NamedTuple
|
||||
from typing import Dict, NamedTuple, Optional, Tuple, Type, Union
|
||||
|
||||
from marshmallow import Schema, post_dump
|
||||
from web3.providers import BaseProvider
|
||||
|
@ -30,6 +30,7 @@ from nucypher.policy.conditions.exceptions import (
|
|||
InvalidContextVariableData,
|
||||
NoConnectionToChain,
|
||||
RequiredContextVariable,
|
||||
ReturnValueEvaluationError,
|
||||
)
|
||||
from nucypher.utilities.logging import Logger
|
||||
|
||||
|
@ -129,6 +130,11 @@ def evaluate_conditions(
|
|||
if not result:
|
||||
# explicit condition failure
|
||||
error = ("Decryption conditions not satisfied", HTTPStatus.FORBIDDEN)
|
||||
except ReturnValueEvaluationError as e:
|
||||
error = (
|
||||
f"Unable to evaluate return value: {e}",
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
except InvalidCondition as e:
|
||||
error = (
|
||||
f"Incorrect value provided for condition: {e}",
|
||||
|
|
|
@ -108,7 +108,7 @@ def get_context_value(context_variable: str, **context) -> Any:
|
|||
# fallback for context variable without directive - assume key,value pair
|
||||
# handles the case for user customized context variables
|
||||
value = context.get(context_variable)
|
||||
if not value:
|
||||
if value is None:
|
||||
raise RequiredContextVariable(
|
||||
f'"No value provided for unrecognized context variable "{context_variable}"'
|
||||
)
|
||||
|
|
|
@ -24,6 +24,7 @@ from marshmallow import fields, post_load
|
|||
from web3 import Web3
|
||||
from web3.contract import ContractFunction
|
||||
from web3.providers import BaseProvider
|
||||
from web3.types import ABIFunction
|
||||
|
||||
from nucypher.policy.conditions import STANDARD_ABI_CONTRACT_TYPES, STANDARD_ABIS
|
||||
from nucypher.policy.conditions._utils import CamelCaseSchema
|
||||
|
@ -45,28 +46,39 @@ _CONDITION_CHAINS = (
|
|||
)
|
||||
|
||||
|
||||
def _resolve_abi(standard_contract_type: str, method: str, function_abi: List) -> List:
|
||||
def _resolve_abi(
|
||||
w3: Web3,
|
||||
method: str,
|
||||
standard_contract_type: Optional[str] = None,
|
||||
function_abi: Optional[ABIFunction] = None,
|
||||
) -> ABIFunction:
|
||||
"""Resolves the contract an/or function ABI from a standard contract name"""
|
||||
|
||||
if not (function_abi or standard_contract_type):
|
||||
# TODO: Is this protection needed?
|
||||
raise InvalidCondition(
|
||||
f"Ambiguous ABI - Supply either an ABI or a standard contract type ({STANDARD_ABI_CONTRACT_TYPES})."
|
||||
)
|
||||
|
||||
if standard_contract_type:
|
||||
try:
|
||||
function_abi = STANDARD_ABIS[standard_contract_type]
|
||||
# Lookup the standard ABI given it's ERC standard name (standard contract type)
|
||||
contract_abi = STANDARD_ABIS[standard_contract_type]
|
||||
except KeyError:
|
||||
raise InvalidCondition(
|
||||
f"Invalid standard contract type {standard_contract_type}; Must be one of {STANDARD_ABI_CONTRACT_TYPES}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Extract all function ABIs from the contract's ABI.
|
||||
# Will raise a ValueError if there is not exactly one match.
|
||||
function_abi = w3.eth.contract(abi=contract_abi).get_function_by_name(method).abi
|
||||
except ValueError as e:
|
||||
raise InvalidCondition(str(e))
|
||||
|
||||
if not function_abi:
|
||||
raise InvalidCondition(f"No function ABI supplied for '{method}'")
|
||||
|
||||
# TODO: Verify that the function and ABI pair match?
|
||||
# ABI(function_abi)
|
||||
return function_abi
|
||||
return ABIFunction(function_abi)
|
||||
|
||||
|
||||
def camel_case_to_snake(data: str) -> str:
|
||||
|
@ -86,14 +98,19 @@ def _resolve_any_context_variables(
|
|||
processed_parameters.append(p)
|
||||
|
||||
v = return_value_test.value
|
||||
k = return_value_test.key
|
||||
if is_context_variable(return_value_test.value):
|
||||
v = get_context_value(context_variable=v, **context)
|
||||
processed_return_value_test = ReturnValueTest(return_value_test.comparator, value=v)
|
||||
if is_context_variable(return_value_test.key):
|
||||
k = get_context_value(context_variable=k, **context)
|
||||
processed_return_value_test = ReturnValueTest(
|
||||
return_value_test.comparator, value=v, key=k
|
||||
)
|
||||
|
||||
return processed_parameters, processed_return_value_test
|
||||
|
||||
|
||||
def _validate_chain(chain: int):
|
||||
def _validate_chain(chain: int) -> None:
|
||||
if not isinstance(chain, int):
|
||||
raise ValueError(f'"The chain" field of c a condition must be the '
|
||||
f'integer of a chain ID (got "{chain}").')
|
||||
|
@ -216,17 +233,20 @@ class ContractCondition(RPCCondition):
|
|||
SKIP_VALUES = (None,)
|
||||
standard_contract_type = fields.Str(required=False)
|
||||
contract_address = fields.Str(required=True)
|
||||
function_abi = fields.Str(required=False)
|
||||
function_abi = fields.Dict(required=False)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return ContractCondition(**data)
|
||||
|
||||
def __init__(self,
|
||||
contract_address: ChecksumAddress,
|
||||
standard_contract_type: str = None,
|
||||
function_abi: List = None,
|
||||
*args, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
contract_address: ChecksumAddress,
|
||||
standard_contract_type: Optional[str] = None,
|
||||
function_abi: Optional[ABIFunction] = None,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
# internal
|
||||
super().__init__(*args, **kwargs)
|
||||
self.w3 = Web3() # used to instantiate contract function without a provider
|
||||
|
@ -234,6 +254,7 @@ class ContractCondition(RPCCondition):
|
|||
# preprocessing
|
||||
contract_address = to_checksum_address(contract_address)
|
||||
function_abi = _resolve_abi(
|
||||
w3=self.w3,
|
||||
standard_contract_type=standard_contract_type,
|
||||
method=self.method,
|
||||
function_abi=function_abi
|
||||
|
@ -262,7 +283,7 @@ class ContractCondition(RPCCondition):
|
|||
"""Gets an unbound contract function to evaluate for this condition"""
|
||||
try:
|
||||
contract = self.w3.eth.contract(
|
||||
address=self.contract_address, abi=self.function_abi
|
||||
address=self.contract_address, abi=[self.function_abi]
|
||||
)
|
||||
contract_function = getattr(contract.functions, self.method)
|
||||
return contract_function
|
||||
|
|
|
@ -26,6 +26,9 @@ class NoConnectionToChain(RuntimeError):
|
|||
super().__init__(message)
|
||||
|
||||
|
||||
class ReturnValueEvaluationError(Exception):
|
||||
"""Issue with Return Value and Key"""
|
||||
|
||||
# Context Variable
|
||||
class RequiredContextVariable(Exception):
|
||||
"""No value provided for context variable"""
|
||||
|
|
|
@ -21,7 +21,7 @@ import base64
|
|||
import json
|
||||
import operator as pyoperator
|
||||
from hashlib import md5
|
||||
from typing import Any, Dict, Iterator, List, Union
|
||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
||||
|
||||
from marshmallow import fields, post_load
|
||||
|
||||
|
@ -31,6 +31,7 @@ from nucypher.policy.conditions._utils import (
|
|||
)
|
||||
from nucypher.policy.conditions.base import ReencryptionCondition
|
||||
from nucypher.policy.conditions.context import is_context_variable
|
||||
from nucypher.policy.conditions.exceptions import ReturnValueEvaluationError
|
||||
|
||||
|
||||
class Operator:
|
||||
|
@ -83,19 +84,26 @@ class ReturnValueTest:
|
|||
COMPARATORS = tuple(_COMPARATOR_FUNCTIONS)
|
||||
|
||||
class ReturnValueTestSchema(CamelCaseSchema):
|
||||
SKIP_VALUES = (None,)
|
||||
comparator = fields.Str()
|
||||
value = fields.Raw(allow_none=False) # any valid type (excludes None)
|
||||
key = fields.Raw(allow_none=True)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return ReturnValueTest(**data)
|
||||
|
||||
def __init__(self, comparator: str, value: Any):
|
||||
def __init__(self, comparator: str, value: Any, key: Optional[Union[int, str]] = None):
|
||||
if comparator not in self.COMPARATORS:
|
||||
raise self.InvalidExpression(
|
||||
f'"{comparator}" is not a permitted comparator.'
|
||||
)
|
||||
|
||||
if not isinstance(key, (int, str)) and key is not None:
|
||||
raise self.InvalidExpression(
|
||||
f'"{key}" is not a permitted key. Must be a string or integer.'
|
||||
)
|
||||
|
||||
if not is_context_variable(value):
|
||||
# verify that value is valid, but don't set it here so as not to change the value;
|
||||
# it will be sanitized at eval time. Need to maintain serialization/deserialization
|
||||
|
@ -104,6 +112,7 @@ class ReturnValueTest:
|
|||
|
||||
self.comparator = comparator
|
||||
self.value = value
|
||||
self.key = key
|
||||
|
||||
def _sanitize_value(self, value):
|
||||
try:
|
||||
|
@ -111,14 +120,44 @@ class ReturnValueTest:
|
|||
except Exception:
|
||||
raise self.InvalidExpression(f'"{value}" is not a permitted value.')
|
||||
|
||||
def _process_data(self, data: Any) -> Any:
|
||||
"""
|
||||
If a key is specified, return the value at that key in the data if data is a dict or list-like.
|
||||
Otherwise, return the data.
|
||||
"""
|
||||
processed_data = data
|
||||
if self.key is not None:
|
||||
if isinstance(data, dict):
|
||||
try:
|
||||
processed_data = data[self.key]
|
||||
except KeyError:
|
||||
raise ReturnValueEvaluationError(
|
||||
f"Key '{self.key}' not found in return data."
|
||||
)
|
||||
elif isinstance(self.key, int) and isinstance(data, (list, tuple)):
|
||||
try:
|
||||
processed_data = data[self.key]
|
||||
except IndexError:
|
||||
raise ReturnValueEvaluationError(
|
||||
f"Index '{self.key}' not found in return data."
|
||||
)
|
||||
else:
|
||||
raise ReturnValueEvaluationError(
|
||||
f"Key: {self.key} and Value: {data} are not compatible types."
|
||||
)
|
||||
|
||||
return processed_data
|
||||
|
||||
def eval(self, data) -> bool:
|
||||
if is_context_variable(self.value):
|
||||
if is_context_variable(self.value) or is_context_variable(self.key):
|
||||
# programming error if we get here
|
||||
raise RuntimeError(
|
||||
f"'{self.value}' is an unprocessed context variable and is not valid "
|
||||
f"Return value comparator contains an unprocessed context variable (key={self.key}, value={self.value}) and is not valid "
|
||||
f"for condition evaluation."
|
||||
)
|
||||
left_operand = self._sanitize_value(data)
|
||||
|
||||
processed_data = self._process_data(data)
|
||||
left_operand = self._sanitize_value(processed_data)
|
||||
right_operand = self._sanitize_value(self.value)
|
||||
result = self._COMPARATOR_FUNCTIONS[self.comparator](left_operand, right_operand)
|
||||
return result
|
||||
|
|
|
@ -49,7 +49,9 @@ with open(VECTORS_FILE, 'r') as file:
|
|||
@pytest.fixture(autouse=True)
|
||||
def mock_condition_blockchains(mocker):
|
||||
"""adds testerchain to permitted conditional chains"""
|
||||
mocker.patch.object(nucypher.policy.conditions.evm, '_CONDITION_CHAINS', tuple([131277322940537]))
|
||||
mocker.patch.object(
|
||||
nucypher.policy.conditions.evm, "_CONDITION_CHAINS", tuple([131277322940537])
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
@ -106,7 +108,9 @@ def erc20_evm_condition(test_registry, agency):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_context_variable_erc20_condition(test_registry, agency, testerchain):
|
||||
def custom_context_variable_erc20_condition(
|
||||
test_registry, agency, testerchain, mock_condition_blockchains
|
||||
):
|
||||
token = ContractAgency.get_agent(NucypherTokenAgent, registry=test_registry)
|
||||
condition = ContractCondition(
|
||||
contract_address=token.contract.address,
|
||||
|
@ -181,7 +185,7 @@ def subscription_manager_is_active_policy_condition(test_registry, agency):
|
|||
)
|
||||
condition = ContractCondition(
|
||||
contract_address=subscription_manager.contract.address,
|
||||
function_abi=[subscription_manager.contract.find_functions_by_name('isPolicyActive')[0].abi],
|
||||
function_abi=subscription_manager.contract.get_function_by_name("isPolicyActive").abi,
|
||||
method="isPolicyActive",
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
return_value_test=ReturnValueTest("==", True),
|
||||
|
@ -199,7 +203,7 @@ def subscription_manager_get_policy_zeroized_policy_struct_condition(
|
|||
)
|
||||
condition = ContractCondition(
|
||||
contract_address=subscription_manager.contract.address,
|
||||
function_abi=[subscription_manager.contract.find_functions_by_name('getPolicy')[0].abi],
|
||||
function_abi=subscription_manager.contract.get_function_by_name("getPolicy").abi,
|
||||
method="getPolicy",
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
return_value_test=ReturnValueTest("==", ":expectedPolicyStruct"),
|
||||
|
|
|
@ -24,11 +24,17 @@ from unittest import mock
|
|||
import pytest
|
||||
from web3 import Web3
|
||||
|
||||
from nucypher.blockchain.eth.agents import ContractAgency, SubscriptionManagerAgent
|
||||
from nucypher.blockchain.eth.constants import NULL_ADDRESS
|
||||
from nucypher.policy.conditions.context import (
|
||||
USER_ADDRESS_CONTEXT,
|
||||
_recover_user_address,
|
||||
)
|
||||
from nucypher.policy.conditions.evm import RPCCondition, get_context_value
|
||||
from nucypher.policy.conditions.evm import (
|
||||
ContractCondition,
|
||||
RPCCondition,
|
||||
get_context_value,
|
||||
)
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
ContextVariableVerificationFailed,
|
||||
InvalidContextVariableData,
|
||||
|
@ -91,7 +97,7 @@ def _dont_validate_user_address(context_variable: str, **context):
|
|||
|
||||
|
||||
def test_required_context_variable(
|
||||
testerchain, custom_context_variable_erc20_condition, condition_providers
|
||||
custom_context_variable_erc20_condition, condition_providers
|
||||
):
|
||||
with pytest.raises(RequiredContextVariable):
|
||||
custom_context_variable_erc20_condition.verify(
|
||||
|
@ -356,11 +362,7 @@ def test_subscription_manager_get_policy_policy_struct_condition_evaluation(
|
|||
|
||||
# zeroized policy struct
|
||||
zeroized_policy_struct = (
|
||||
"0x0000000000000000000000000000000000000000",
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
"0x0000000000000000000000000000000000000000",
|
||||
NULL_ADDRESS, 0, 0, 0, NULL_ADDRESS,
|
||||
)
|
||||
context = {
|
||||
":hrac": bytes(enacted_blockchain_policy.hrac),
|
||||
|
@ -381,6 +383,182 @@ def test_subscription_manager_get_policy_policy_struct_condition_evaluation(
|
|||
assert condition_result is True # zeroized policy was indeed returned
|
||||
|
||||
|
||||
def test_subscription_manager_get_policy_policy_struct_condition_key_tuple_evaluation(
|
||||
testerchain,
|
||||
agency,
|
||||
test_registry,
|
||||
idle_blockchain_policy,
|
||||
enacted_blockchain_policy,
|
||||
condition_providers,
|
||||
):
|
||||
# enacted policy created from idle policy
|
||||
size = len(idle_blockchain_policy.kfrags)
|
||||
start = idle_blockchain_policy.commencement
|
||||
end = idle_blockchain_policy.expiration
|
||||
sponsor = idle_blockchain_policy.publisher.checksum_address
|
||||
|
||||
context = {
|
||||
":hrac": bytes(enacted_blockchain_policy.hrac),
|
||||
} # user-defined context vars
|
||||
subscription_manager = ContractAgency.get_agent(
|
||||
SubscriptionManagerAgent, registry=test_registry
|
||||
)
|
||||
|
||||
# test "sponsor" key (owner is the same as sponsor for this policy)
|
||||
condition = ContractCondition(
|
||||
contract_address=subscription_manager.contract.address,
|
||||
function_abi=subscription_manager.contract.get_function_by_name(
|
||||
"getPolicy"
|
||||
).abi,
|
||||
method="getPolicy",
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
return_value_test=ReturnValueTest(comparator="==", value=sponsor, key=0),
|
||||
parameters=[":hrac"],
|
||||
)
|
||||
condition_result, _ = condition.verify(providers=condition_providers, **context)
|
||||
assert condition_result
|
||||
|
||||
# test "sponsor" key not equal to correct value
|
||||
condition = ContractCondition(
|
||||
contract_address=subscription_manager.contract.address,
|
||||
function_abi=subscription_manager.contract.get_function_by_name(
|
||||
"getPolicy"
|
||||
).abi,
|
||||
method="getPolicy",
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
return_value_test=ReturnValueTest(comparator="!=", value=sponsor, key=0),
|
||||
parameters=[":hrac"],
|
||||
)
|
||||
condition_result, _ = condition.verify(providers=condition_providers, **context)
|
||||
assert not condition_result
|
||||
|
||||
# test "start" key
|
||||
condition = ContractCondition(
|
||||
contract_address=subscription_manager.contract.address,
|
||||
function_abi=subscription_manager.contract.get_function_by_name(
|
||||
"getPolicy"
|
||||
).abi,
|
||||
method="getPolicy",
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
return_value_test=ReturnValueTest(comparator="==", value=start, key=1),
|
||||
parameters=[":hrac"],
|
||||
)
|
||||
condition_result, _ = condition.verify(providers=condition_providers, **context)
|
||||
assert condition_result
|
||||
|
||||
# test "start" key not equal to correct value
|
||||
condition = ContractCondition(
|
||||
contract_address=subscription_manager.contract.address,
|
||||
function_abi=subscription_manager.contract.get_function_by_name(
|
||||
"getPolicy"
|
||||
).abi,
|
||||
method="getPolicy",
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
return_value_test=ReturnValueTest(comparator="!=", value=start, key=1),
|
||||
parameters=[":hrac"],
|
||||
)
|
||||
condition_result, _ = condition.verify(providers=condition_providers, **context)
|
||||
assert not condition_result
|
||||
|
||||
# test "end" key
|
||||
condition = ContractCondition(
|
||||
contract_address=subscription_manager.contract.address,
|
||||
function_abi=subscription_manager.contract.get_function_by_name(
|
||||
"getPolicy"
|
||||
).abi,
|
||||
method="getPolicy",
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
return_value_test=ReturnValueTest(comparator="==", value=end, key=2),
|
||||
parameters=[":hrac"],
|
||||
)
|
||||
condition_result, _ = condition.verify(providers=condition_providers, **context)
|
||||
assert condition_result
|
||||
|
||||
# test "size" key
|
||||
condition = ContractCondition(
|
||||
contract_address=subscription_manager.contract.address,
|
||||
function_abi=subscription_manager.contract.get_function_by_name(
|
||||
"getPolicy"
|
||||
).abi,
|
||||
method="getPolicy",
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
return_value_test=ReturnValueTest(comparator="==", value=size, key=3),
|
||||
parameters=[":hrac"],
|
||||
)
|
||||
condition_result, _ = condition.verify(providers=condition_providers, **context)
|
||||
assert condition_result
|
||||
|
||||
# test "owner" key (owner is sponsor, so owner is set to null address)
|
||||
condition = ContractCondition(
|
||||
contract_address=subscription_manager.contract.address,
|
||||
function_abi=subscription_manager.contract.get_function_by_name(
|
||||
"getPolicy"
|
||||
).abi,
|
||||
method="getPolicy",
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
return_value_test=ReturnValueTest(comparator="==", value=NULL_ADDRESS, key=4),
|
||||
parameters=[":hrac"],
|
||||
)
|
||||
condition_result, _ = condition.verify(providers=condition_providers, **context)
|
||||
assert condition_result
|
||||
|
||||
|
||||
def test_subscription_manager_get_policy_policy_struct_condition_key_context_var_evaluation(
|
||||
testerchain,
|
||||
agency,
|
||||
test_registry,
|
||||
idle_blockchain_policy,
|
||||
enacted_blockchain_policy,
|
||||
condition_providers,
|
||||
):
|
||||
# enacted policy created from idle policy
|
||||
sponsor = idle_blockchain_policy.publisher.checksum_address
|
||||
context = {
|
||||
":hrac": bytes(enacted_blockchain_policy.hrac),
|
||||
":sponsor": sponsor,
|
||||
":sponsorIndex": 0,
|
||||
} # user-defined context vars
|
||||
subscription_manager = ContractAgency.get_agent(
|
||||
SubscriptionManagerAgent, registry=test_registry
|
||||
)
|
||||
|
||||
# test "sponsor" key (owner is the same as sponsor for this policy)
|
||||
condition = ContractCondition(
|
||||
contract_address=subscription_manager.contract.address,
|
||||
function_abi=subscription_manager.contract.get_function_by_name(
|
||||
"getPolicy"
|
||||
).abi,
|
||||
method="getPolicy",
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
return_value_test=ReturnValueTest(
|
||||
comparator="==",
|
||||
value=sponsor, # don't use sponsor context var
|
||||
key=":sponsorIndex",
|
||||
),
|
||||
parameters=[":hrac"],
|
||||
)
|
||||
condition_result, _ = condition.verify(providers=condition_providers, **context)
|
||||
assert condition_result
|
||||
|
||||
# test "sponsor" key not equal to correct value
|
||||
condition = ContractCondition(
|
||||
contract_address=subscription_manager.contract.address,
|
||||
function_abi=subscription_manager.contract.get_function_by_name(
|
||||
"getPolicy"
|
||||
).abi,
|
||||
method="getPolicy",
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
return_value_test=ReturnValueTest(
|
||||
comparator="!=",
|
||||
value=":sponsor", # use sponsor sponsor context var
|
||||
key=":sponsorIndex",
|
||||
),
|
||||
parameters=[":hrac"],
|
||||
)
|
||||
condition_result, _ = condition.verify(providers=condition_providers, **context)
|
||||
assert not condition_result
|
||||
|
||||
|
||||
def test_time_condition_evaluation(testerchain, timelock_condition, condition_providers):
|
||||
condition_result, call_result = timelock_condition.verify(
|
||||
providers=condition_providers
|
||||
|
|
|
@ -16,15 +16,18 @@ along with nucypher. If not, see <https://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
import string
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import tempfile
|
||||
from random import SystemRandom
|
||||
|
||||
from web3 import Web3
|
||||
|
||||
from nucypher.blockchain.eth.token import NU
|
||||
from nucypher.config.constants import NUCYPHER_ENVVAR_KEYSTORE_PASSWORD, NUCYPHER_ENVVAR_OPERATOR_ETH_PASSWORD
|
||||
from nucypher.config.constants import (
|
||||
NUCYPHER_ENVVAR_KEYSTORE_PASSWORD,
|
||||
NUCYPHER_ENVVAR_OPERATOR_ETH_PASSWORD,
|
||||
)
|
||||
|
||||
#
|
||||
# Ursula
|
||||
|
|
|
@ -1,4 +1,46 @@
|
|||
{
|
||||
"TStaking": {
|
||||
"contractAddress": "0x01B67b1194C75264d06F808A921228a95C765dd7",
|
||||
"chain": 131277322940537,
|
||||
"method": "stakes",
|
||||
"functionAbi": {
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "stakingProvider",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"name": "stakes",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "uint96",
|
||||
"name": "tStake",
|
||||
"type": "uint96"
|
||||
},
|
||||
{
|
||||
"internalType": "uint96",
|
||||
"name": "keepInTStake",
|
||||
"type": "uint96"
|
||||
},
|
||||
{
|
||||
"internalType": "uint96",
|
||||
"name": "nuInTStake",
|
||||
"type": "uint96"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
"parameters": [
|
||||
":userAddress"
|
||||
],
|
||||
"returnValueTest": {
|
||||
"key": "tStake",
|
||||
"comparator": ">",
|
||||
"value": 0
|
||||
}
|
||||
},
|
||||
"SubscriptionManagerPayment": {
|
||||
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
|
||||
"chain": 137,
|
||||
|
|
|
@ -38,7 +38,14 @@ with open(VECTORS_FILE, 'r') as file:
|
|||
@pytest.fixture(autouse=True)
|
||||
def mock_condition_blockchains(mocker):
|
||||
"""adds testerchain to permitted conditional chains"""
|
||||
mocker.patch.object(nucypher.policy.conditions.evm, '_CONDITION_CHAINS', tuple([131277322940537]))
|
||||
mocker.patch.object(
|
||||
nucypher.policy.conditions.evm, "_CONDITION_CHAINS", tuple([131277322940537])
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def t_staking_data():
|
||||
return json.dumps(VECTORS["TStaking"])
|
||||
|
||||
|
||||
# ERC1155
|
||||
|
|
|
@ -89,12 +89,25 @@ def test_invalid_contract_condition():
|
|||
# invalid ABI
|
||||
with pytest.raises(InvalidCondition):
|
||||
_ = ContractCondition(
|
||||
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
|
||||
method="getPolicy",
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
function_abi=["rando ABI"],
|
||||
return_value_test=ReturnValueTest('!=', 0),
|
||||
parameters=[
|
||||
':hrac',
|
||||
]
|
||||
)
|
||||
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
|
||||
method="getPolicy",
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
function_abi={"rando": "ABI"},
|
||||
return_value_test=ReturnValueTest("!=", 0),
|
||||
parameters=[
|
||||
":hrac",
|
||||
],
|
||||
)
|
||||
|
||||
# method not in ABI
|
||||
with pytest.raises(InvalidCondition):
|
||||
_ = ContractCondition(
|
||||
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
|
||||
method="getPolicy",
|
||||
chain=TESTERCHAIN_CHAIN_ID,
|
||||
standard_contract_type="ERC20",
|
||||
return_value_test=ReturnValueTest("!=", 0),
|
||||
parameters=[
|
||||
":hrac",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -22,12 +22,20 @@ from nucypher.policy.conditions.evm import ContractCondition
|
|||
from nucypher.policy.conditions.lingo import ConditionLingo
|
||||
|
||||
|
||||
def test_evm_condition_function_abi(t_staking_data):
|
||||
original_data = t_staking_data
|
||||
condition = ContractCondition.from_json(original_data)
|
||||
serialized_data = condition.to_json()
|
||||
|
||||
deserialized_data = json.loads(serialized_data)
|
||||
assert deserialized_data["functionAbi"] == condition.function_abi
|
||||
|
||||
|
||||
def test_evm_condition_json_serializers(ERC1155_balance_condition_data):
|
||||
original_data = ERC1155_balance_condition_data
|
||||
condition = ContractCondition.from_json(original_data)
|
||||
serialized_data = condition.to_json()
|
||||
|
||||
# TODO functionAbi is present in serialized data
|
||||
deserialized_data = json.loads(serialized_data)
|
||||
deserialized_data.pop("functionAbi")
|
||||
|
||||
|
|
|
@ -20,9 +20,51 @@ import random
|
|||
|
||||
import pytest
|
||||
|
||||
from nucypher.policy.conditions.exceptions import ReturnValueEvaluationError
|
||||
from nucypher.policy.conditions.lingo import ReturnValueTest
|
||||
|
||||
|
||||
def test_return_value_key():
|
||||
test = ReturnValueTest(comparator=">", value="0", key="james")
|
||||
assert test.eval({"james": 1})
|
||||
assert not test.eval({"james": -1})
|
||||
|
||||
with pytest.raises(ReturnValueEvaluationError):
|
||||
test.eval({"bond": 1})
|
||||
|
||||
test = ReturnValueTest(comparator=">", value="0", key=4)
|
||||
assert test.eval({4: 1})
|
||||
assert not test.eval({4: -1})
|
||||
|
||||
with pytest.raises(ReturnValueEvaluationError):
|
||||
test.eval({5: 1})
|
||||
|
||||
|
||||
def test_return_value_index():
|
||||
test = ReturnValueTest(comparator=">", value="0", key=0)
|
||||
assert test.eval([1])
|
||||
assert not test.eval([-1])
|
||||
|
||||
test = ReturnValueTest(comparator="==", value='"james"', key=3)
|
||||
assert test.eval([0, 1, 2, '"james"'])
|
||||
|
||||
with pytest.raises(ReturnValueEvaluationError):
|
||||
test.eval([0, 1, 2])
|
||||
|
||||
|
||||
def test_return_value_index_tuple():
|
||||
test = ReturnValueTest(comparator=">", value="0", key=0)
|
||||
assert test.eval((1,))
|
||||
assert not test.eval((-1,))
|
||||
|
||||
|
||||
def test_return_value_with_context_variable_key_cant_run_eval():
|
||||
# known context variable
|
||||
test = ReturnValueTest(comparator="==", value="0", key=":userAddress")
|
||||
with pytest.raises(RuntimeError):
|
||||
test.eval({"0xaDD9D957170dF6F33982001E4c22eCCdd5539118": 0})
|
||||
|
||||
|
||||
def test_return_value_test_invalid_comparators():
|
||||
with pytest.raises(ReturnValueTest.InvalidExpression):
|
||||
_ = ReturnValueTest(comparator="eq", value=1)
|
||||
|
|
Loading…
Reference in New Issue