Merge pull request #3002 from theref/key-value

Add custom abi conditions and key/value return values
pull/3013/head
KPrasch 2022-11-14 22:11:52 +00:00 committed by GitHub
commit 131a7d4ced
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 414 additions and 47 deletions

View File

@ -0,0 +1 @@
Allow a key to be specified for evaluating the return value

View File

@ -17,7 +17,7 @@
import json
from http import HTTPStatus
from typing import Dict, Optional, Tuple, Type, Union, NamedTuple
from typing import Dict, NamedTuple, Optional, Tuple, Type, Union
from marshmallow import Schema, post_dump
from web3.providers import BaseProvider
@ -30,6 +30,7 @@ from nucypher.policy.conditions.exceptions import (
InvalidContextVariableData,
NoConnectionToChain,
RequiredContextVariable,
ReturnValueEvaluationError,
)
from nucypher.utilities.logging import Logger
@ -129,6 +130,11 @@ def evaluate_conditions(
if not result:
# explicit condition failure
error = ("Decryption conditions not satisfied", HTTPStatus.FORBIDDEN)
except ReturnValueEvaluationError as e:
error = (
f"Unable to evaluate return value: {e}",
HTTPStatus.BAD_REQUEST,
)
except InvalidCondition as e:
error = (
f"Incorrect value provided for condition: {e}",

View File

@ -108,7 +108,7 @@ def get_context_value(context_variable: str, **context) -> Any:
# fallback for context variable without directive - assume key,value pair
# handles the case for user customized context variables
value = context.get(context_variable)
if not value:
if value is None:
raise RequiredContextVariable(
f'"No value provided for unrecognized context variable "{context_variable}"'
)

View File

@ -24,6 +24,7 @@ from marshmallow import fields, post_load
from web3 import Web3
from web3.contract import ContractFunction
from web3.providers import BaseProvider
from web3.types import ABIFunction
from nucypher.policy.conditions import STANDARD_ABI_CONTRACT_TYPES, STANDARD_ABIS
from nucypher.policy.conditions._utils import CamelCaseSchema
@ -45,28 +46,39 @@ _CONDITION_CHAINS = (
)
def _resolve_abi(standard_contract_type: str, method: str, function_abi: List) -> List:
def _resolve_abi(
w3: Web3,
method: str,
standard_contract_type: Optional[str] = None,
function_abi: Optional[ABIFunction] = None,
) -> ABIFunction:
"""Resolves the contract an/or function ABI from a standard contract name"""
if not (function_abi or standard_contract_type):
# TODO: Is this protection needed?
raise InvalidCondition(
f"Ambiguous ABI - Supply either an ABI or a standard contract type ({STANDARD_ABI_CONTRACT_TYPES})."
)
if standard_contract_type:
try:
function_abi = STANDARD_ABIS[standard_contract_type]
# Lookup the standard ABI given it's ERC standard name (standard contract type)
contract_abi = STANDARD_ABIS[standard_contract_type]
except KeyError:
raise InvalidCondition(
f"Invalid standard contract type {standard_contract_type}; Must be one of {STANDARD_ABI_CONTRACT_TYPES}"
)
try:
# Extract all function ABIs from the contract's ABI.
# Will raise a ValueError if there is not exactly one match.
function_abi = w3.eth.contract(abi=contract_abi).get_function_by_name(method).abi
except ValueError as e:
raise InvalidCondition(str(e))
if not function_abi:
raise InvalidCondition(f"No function ABI supplied for '{method}'")
# TODO: Verify that the function and ABI pair match?
# ABI(function_abi)
return function_abi
return ABIFunction(function_abi)
def camel_case_to_snake(data: str) -> str:
@ -86,14 +98,19 @@ def _resolve_any_context_variables(
processed_parameters.append(p)
v = return_value_test.value
k = return_value_test.key
if is_context_variable(return_value_test.value):
v = get_context_value(context_variable=v, **context)
processed_return_value_test = ReturnValueTest(return_value_test.comparator, value=v)
if is_context_variable(return_value_test.key):
k = get_context_value(context_variable=k, **context)
processed_return_value_test = ReturnValueTest(
return_value_test.comparator, value=v, key=k
)
return processed_parameters, processed_return_value_test
def _validate_chain(chain: int):
def _validate_chain(chain: int) -> None:
if not isinstance(chain, int):
raise ValueError(f'"The chain" field of c a condition must be the '
f'integer of a chain ID (got "{chain}").')
@ -216,17 +233,20 @@ class ContractCondition(RPCCondition):
SKIP_VALUES = (None,)
standard_contract_type = fields.Str(required=False)
contract_address = fields.Str(required=True)
function_abi = fields.Str(required=False)
function_abi = fields.Dict(required=False)
@post_load
def make(self, data, **kwargs):
return ContractCondition(**data)
def __init__(self,
contract_address: ChecksumAddress,
standard_contract_type: str = None,
function_abi: List = None,
*args, **kwargs):
def __init__(
self,
contract_address: ChecksumAddress,
standard_contract_type: Optional[str] = None,
function_abi: Optional[ABIFunction] = None,
*args,
**kwargs
):
# internal
super().__init__(*args, **kwargs)
self.w3 = Web3() # used to instantiate contract function without a provider
@ -234,6 +254,7 @@ class ContractCondition(RPCCondition):
# preprocessing
contract_address = to_checksum_address(contract_address)
function_abi = _resolve_abi(
w3=self.w3,
standard_contract_type=standard_contract_type,
method=self.method,
function_abi=function_abi
@ -262,7 +283,7 @@ class ContractCondition(RPCCondition):
"""Gets an unbound contract function to evaluate for this condition"""
try:
contract = self.w3.eth.contract(
address=self.contract_address, abi=self.function_abi
address=self.contract_address, abi=[self.function_abi]
)
contract_function = getattr(contract.functions, self.method)
return contract_function

View File

@ -26,6 +26,9 @@ class NoConnectionToChain(RuntimeError):
super().__init__(message)
class ReturnValueEvaluationError(Exception):
"""Issue with Return Value and Key"""
# Context Variable
class RequiredContextVariable(Exception):
"""No value provided for context variable"""

View File

@ -21,7 +21,7 @@ import base64
import json
import operator as pyoperator
from hashlib import md5
from typing import Any, Dict, Iterator, List, Union
from typing import Any, Dict, Iterator, List, Optional, Union
from marshmallow import fields, post_load
@ -31,6 +31,7 @@ from nucypher.policy.conditions._utils import (
)
from nucypher.policy.conditions.base import ReencryptionCondition
from nucypher.policy.conditions.context import is_context_variable
from nucypher.policy.conditions.exceptions import ReturnValueEvaluationError
class Operator:
@ -83,19 +84,26 @@ class ReturnValueTest:
COMPARATORS = tuple(_COMPARATOR_FUNCTIONS)
class ReturnValueTestSchema(CamelCaseSchema):
SKIP_VALUES = (None,)
comparator = fields.Str()
value = fields.Raw(allow_none=False) # any valid type (excludes None)
key = fields.Raw(allow_none=True)
@post_load
def make(self, data, **kwargs):
return ReturnValueTest(**data)
def __init__(self, comparator: str, value: Any):
def __init__(self, comparator: str, value: Any, key: Optional[Union[int, str]] = None):
if comparator not in self.COMPARATORS:
raise self.InvalidExpression(
f'"{comparator}" is not a permitted comparator.'
)
if not isinstance(key, (int, str)) and key is not None:
raise self.InvalidExpression(
f'"{key}" is not a permitted key. Must be a string or integer.'
)
if not is_context_variable(value):
# verify that value is valid, but don't set it here so as not to change the value;
# it will be sanitized at eval time. Need to maintain serialization/deserialization
@ -104,6 +112,7 @@ class ReturnValueTest:
self.comparator = comparator
self.value = value
self.key = key
def _sanitize_value(self, value):
try:
@ -111,14 +120,44 @@ class ReturnValueTest:
except Exception:
raise self.InvalidExpression(f'"{value}" is not a permitted value.')
def _process_data(self, data: Any) -> Any:
"""
If a key is specified, return the value at that key in the data if data is a dict or list-like.
Otherwise, return the data.
"""
processed_data = data
if self.key is not None:
if isinstance(data, dict):
try:
processed_data = data[self.key]
except KeyError:
raise ReturnValueEvaluationError(
f"Key '{self.key}' not found in return data."
)
elif isinstance(self.key, int) and isinstance(data, (list, tuple)):
try:
processed_data = data[self.key]
except IndexError:
raise ReturnValueEvaluationError(
f"Index '{self.key}' not found in return data."
)
else:
raise ReturnValueEvaluationError(
f"Key: {self.key} and Value: {data} are not compatible types."
)
return processed_data
def eval(self, data) -> bool:
if is_context_variable(self.value):
if is_context_variable(self.value) or is_context_variable(self.key):
# programming error if we get here
raise RuntimeError(
f"'{self.value}' is an unprocessed context variable and is not valid "
f"Return value comparator contains an unprocessed context variable (key={self.key}, value={self.value}) and is not valid "
f"for condition evaluation."
)
left_operand = self._sanitize_value(data)
processed_data = self._process_data(data)
left_operand = self._sanitize_value(processed_data)
right_operand = self._sanitize_value(self.value)
result = self._COMPARATOR_FUNCTIONS[self.comparator](left_operand, right_operand)
return result

View File

@ -49,7 +49,9 @@ with open(VECTORS_FILE, 'r') as file:
@pytest.fixture(autouse=True)
def mock_condition_blockchains(mocker):
"""adds testerchain to permitted conditional chains"""
mocker.patch.object(nucypher.policy.conditions.evm, '_CONDITION_CHAINS', tuple([131277322940537]))
mocker.patch.object(
nucypher.policy.conditions.evm, "_CONDITION_CHAINS", tuple([131277322940537])
)
@pytest.fixture()
@ -106,7 +108,9 @@ def erc20_evm_condition(test_registry, agency):
@pytest.fixture
def custom_context_variable_erc20_condition(test_registry, agency, testerchain):
def custom_context_variable_erc20_condition(
test_registry, agency, testerchain, mock_condition_blockchains
):
token = ContractAgency.get_agent(NucypherTokenAgent, registry=test_registry)
condition = ContractCondition(
contract_address=token.contract.address,
@ -181,7 +185,7 @@ def subscription_manager_is_active_policy_condition(test_registry, agency):
)
condition = ContractCondition(
contract_address=subscription_manager.contract.address,
function_abi=[subscription_manager.contract.find_functions_by_name('isPolicyActive')[0].abi],
function_abi=subscription_manager.contract.get_function_by_name("isPolicyActive").abi,
method="isPolicyActive",
chain=TESTERCHAIN_CHAIN_ID,
return_value_test=ReturnValueTest("==", True),
@ -199,7 +203,7 @@ def subscription_manager_get_policy_zeroized_policy_struct_condition(
)
condition = ContractCondition(
contract_address=subscription_manager.contract.address,
function_abi=[subscription_manager.contract.find_functions_by_name('getPolicy')[0].abi],
function_abi=subscription_manager.contract.get_function_by_name("getPolicy").abi,
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
return_value_test=ReturnValueTest("==", ":expectedPolicyStruct"),

View File

@ -24,11 +24,17 @@ from unittest import mock
import pytest
from web3 import Web3
from nucypher.blockchain.eth.agents import ContractAgency, SubscriptionManagerAgent
from nucypher.blockchain.eth.constants import NULL_ADDRESS
from nucypher.policy.conditions.context import (
USER_ADDRESS_CONTEXT,
_recover_user_address,
)
from nucypher.policy.conditions.evm import RPCCondition, get_context_value
from nucypher.policy.conditions.evm import (
ContractCondition,
RPCCondition,
get_context_value,
)
from nucypher.policy.conditions.exceptions import (
ContextVariableVerificationFailed,
InvalidContextVariableData,
@ -91,7 +97,7 @@ def _dont_validate_user_address(context_variable: str, **context):
def test_required_context_variable(
testerchain, custom_context_variable_erc20_condition, condition_providers
custom_context_variable_erc20_condition, condition_providers
):
with pytest.raises(RequiredContextVariable):
custom_context_variable_erc20_condition.verify(
@ -356,11 +362,7 @@ def test_subscription_manager_get_policy_policy_struct_condition_evaluation(
# zeroized policy struct
zeroized_policy_struct = (
"0x0000000000000000000000000000000000000000",
0,
0,
0,
"0x0000000000000000000000000000000000000000",
NULL_ADDRESS, 0, 0, 0, NULL_ADDRESS,
)
context = {
":hrac": bytes(enacted_blockchain_policy.hrac),
@ -381,6 +383,182 @@ def test_subscription_manager_get_policy_policy_struct_condition_evaluation(
assert condition_result is True # zeroized policy was indeed returned
def test_subscription_manager_get_policy_policy_struct_condition_key_tuple_evaluation(
testerchain,
agency,
test_registry,
idle_blockchain_policy,
enacted_blockchain_policy,
condition_providers,
):
# enacted policy created from idle policy
size = len(idle_blockchain_policy.kfrags)
start = idle_blockchain_policy.commencement
end = idle_blockchain_policy.expiration
sponsor = idle_blockchain_policy.publisher.checksum_address
context = {
":hrac": bytes(enacted_blockchain_policy.hrac),
} # user-defined context vars
subscription_manager = ContractAgency.get_agent(
SubscriptionManagerAgent, registry=test_registry
)
# test "sponsor" key (owner is the same as sponsor for this policy)
condition = ContractCondition(
contract_address=subscription_manager.contract.address,
function_abi=subscription_manager.contract.get_function_by_name(
"getPolicy"
).abi,
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
return_value_test=ReturnValueTest(comparator="==", value=sponsor, key=0),
parameters=[":hrac"],
)
condition_result, _ = condition.verify(providers=condition_providers, **context)
assert condition_result
# test "sponsor" key not equal to correct value
condition = ContractCondition(
contract_address=subscription_manager.contract.address,
function_abi=subscription_manager.contract.get_function_by_name(
"getPolicy"
).abi,
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
return_value_test=ReturnValueTest(comparator="!=", value=sponsor, key=0),
parameters=[":hrac"],
)
condition_result, _ = condition.verify(providers=condition_providers, **context)
assert not condition_result
# test "start" key
condition = ContractCondition(
contract_address=subscription_manager.contract.address,
function_abi=subscription_manager.contract.get_function_by_name(
"getPolicy"
).abi,
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
return_value_test=ReturnValueTest(comparator="==", value=start, key=1),
parameters=[":hrac"],
)
condition_result, _ = condition.verify(providers=condition_providers, **context)
assert condition_result
# test "start" key not equal to correct value
condition = ContractCondition(
contract_address=subscription_manager.contract.address,
function_abi=subscription_manager.contract.get_function_by_name(
"getPolicy"
).abi,
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
return_value_test=ReturnValueTest(comparator="!=", value=start, key=1),
parameters=[":hrac"],
)
condition_result, _ = condition.verify(providers=condition_providers, **context)
assert not condition_result
# test "end" key
condition = ContractCondition(
contract_address=subscription_manager.contract.address,
function_abi=subscription_manager.contract.get_function_by_name(
"getPolicy"
).abi,
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
return_value_test=ReturnValueTest(comparator="==", value=end, key=2),
parameters=[":hrac"],
)
condition_result, _ = condition.verify(providers=condition_providers, **context)
assert condition_result
# test "size" key
condition = ContractCondition(
contract_address=subscription_manager.contract.address,
function_abi=subscription_manager.contract.get_function_by_name(
"getPolicy"
).abi,
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
return_value_test=ReturnValueTest(comparator="==", value=size, key=3),
parameters=[":hrac"],
)
condition_result, _ = condition.verify(providers=condition_providers, **context)
assert condition_result
# test "owner" key (owner is sponsor, so owner is set to null address)
condition = ContractCondition(
contract_address=subscription_manager.contract.address,
function_abi=subscription_manager.contract.get_function_by_name(
"getPolicy"
).abi,
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
return_value_test=ReturnValueTest(comparator="==", value=NULL_ADDRESS, key=4),
parameters=[":hrac"],
)
condition_result, _ = condition.verify(providers=condition_providers, **context)
assert condition_result
def test_subscription_manager_get_policy_policy_struct_condition_key_context_var_evaluation(
testerchain,
agency,
test_registry,
idle_blockchain_policy,
enacted_blockchain_policy,
condition_providers,
):
# enacted policy created from idle policy
sponsor = idle_blockchain_policy.publisher.checksum_address
context = {
":hrac": bytes(enacted_blockchain_policy.hrac),
":sponsor": sponsor,
":sponsorIndex": 0,
} # user-defined context vars
subscription_manager = ContractAgency.get_agent(
SubscriptionManagerAgent, registry=test_registry
)
# test "sponsor" key (owner is the same as sponsor for this policy)
condition = ContractCondition(
contract_address=subscription_manager.contract.address,
function_abi=subscription_manager.contract.get_function_by_name(
"getPolicy"
).abi,
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
return_value_test=ReturnValueTest(
comparator="==",
value=sponsor, # don't use sponsor context var
key=":sponsorIndex",
),
parameters=[":hrac"],
)
condition_result, _ = condition.verify(providers=condition_providers, **context)
assert condition_result
# test "sponsor" key not equal to correct value
condition = ContractCondition(
contract_address=subscription_manager.contract.address,
function_abi=subscription_manager.contract.get_function_by_name(
"getPolicy"
).abi,
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
return_value_test=ReturnValueTest(
comparator="!=",
value=":sponsor", # use sponsor sponsor context var
key=":sponsorIndex",
),
parameters=[":hrac"],
)
condition_result, _ = condition.verify(providers=condition_providers, **context)
assert not condition_result
def test_time_condition_evaluation(testerchain, timelock_condition, condition_providers):
condition_result, call_result = timelock_condition.verify(
providers=condition_providers

View File

@ -16,15 +16,18 @@ along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
import string
import tempfile
from datetime import datetime
from pathlib import Path
import tempfile
from random import SystemRandom
from web3 import Web3
from nucypher.blockchain.eth.token import NU
from nucypher.config.constants import NUCYPHER_ENVVAR_KEYSTORE_PASSWORD, NUCYPHER_ENVVAR_OPERATOR_ETH_PASSWORD
from nucypher.config.constants import (
NUCYPHER_ENVVAR_KEYSTORE_PASSWORD,
NUCYPHER_ENVVAR_OPERATOR_ETH_PASSWORD,
)
#
# Ursula

View File

@ -1,4 +1,46 @@
{
"TStaking": {
"contractAddress": "0x01B67b1194C75264d06F808A921228a95C765dd7",
"chain": 131277322940537,
"method": "stakes",
"functionAbi": {
"inputs": [
{
"internalType": "address",
"name": "stakingProvider",
"type": "address"
}
],
"name": "stakes",
"outputs": [
{
"internalType": "uint96",
"name": "tStake",
"type": "uint96"
},
{
"internalType": "uint96",
"name": "keepInTStake",
"type": "uint96"
},
{
"internalType": "uint96",
"name": "nuInTStake",
"type": "uint96"
}
],
"stateMutability": "view",
"type": "function"
},
"parameters": [
":userAddress"
],
"returnValueTest": {
"key": "tStake",
"comparator": ">",
"value": 0
}
},
"SubscriptionManagerPayment": {
"contractAddress": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
"chain": 137,

View File

@ -38,7 +38,14 @@ with open(VECTORS_FILE, 'r') as file:
@pytest.fixture(autouse=True)
def mock_condition_blockchains(mocker):
"""adds testerchain to permitted conditional chains"""
mocker.patch.object(nucypher.policy.conditions.evm, '_CONDITION_CHAINS', tuple([131277322940537]))
mocker.patch.object(
nucypher.policy.conditions.evm, "_CONDITION_CHAINS", tuple([131277322940537])
)
@pytest.fixture()
def t_staking_data():
return json.dumps(VECTORS["TStaking"])
# ERC1155

View File

@ -89,12 +89,25 @@ def test_invalid_contract_condition():
# invalid ABI
with pytest.raises(InvalidCondition):
_ = ContractCondition(
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
function_abi=["rando ABI"],
return_value_test=ReturnValueTest('!=', 0),
parameters=[
':hrac',
]
)
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
function_abi={"rando": "ABI"},
return_value_test=ReturnValueTest("!=", 0),
parameters=[
":hrac",
],
)
# method not in ABI
with pytest.raises(InvalidCondition):
_ = ContractCondition(
contract_address="0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
method="getPolicy",
chain=TESTERCHAIN_CHAIN_ID,
standard_contract_type="ERC20",
return_value_test=ReturnValueTest("!=", 0),
parameters=[
":hrac",
],
)

View File

@ -22,12 +22,20 @@ from nucypher.policy.conditions.evm import ContractCondition
from nucypher.policy.conditions.lingo import ConditionLingo
def test_evm_condition_function_abi(t_staking_data):
original_data = t_staking_data
condition = ContractCondition.from_json(original_data)
serialized_data = condition.to_json()
deserialized_data = json.loads(serialized_data)
assert deserialized_data["functionAbi"] == condition.function_abi
def test_evm_condition_json_serializers(ERC1155_balance_condition_data):
original_data = ERC1155_balance_condition_data
condition = ContractCondition.from_json(original_data)
serialized_data = condition.to_json()
# TODO functionAbi is present in serialized data
deserialized_data = json.loads(serialized_data)
deserialized_data.pop("functionAbi")

View File

@ -20,9 +20,51 @@ import random
import pytest
from nucypher.policy.conditions.exceptions import ReturnValueEvaluationError
from nucypher.policy.conditions.lingo import ReturnValueTest
def test_return_value_key():
test = ReturnValueTest(comparator=">", value="0", key="james")
assert test.eval({"james": 1})
assert not test.eval({"james": -1})
with pytest.raises(ReturnValueEvaluationError):
test.eval({"bond": 1})
test = ReturnValueTest(comparator=">", value="0", key=4)
assert test.eval({4: 1})
assert not test.eval({4: -1})
with pytest.raises(ReturnValueEvaluationError):
test.eval({5: 1})
def test_return_value_index():
test = ReturnValueTest(comparator=">", value="0", key=0)
assert test.eval([1])
assert not test.eval([-1])
test = ReturnValueTest(comparator="==", value='"james"', key=3)
assert test.eval([0, 1, 2, '"james"'])
with pytest.raises(ReturnValueEvaluationError):
test.eval([0, 1, 2])
def test_return_value_index_tuple():
test = ReturnValueTest(comparator=">", value="0", key=0)
assert test.eval((1,))
assert not test.eval((-1,))
def test_return_value_with_context_variable_key_cant_run_eval():
# known context variable
test = ReturnValueTest(comparator="==", value="0", key=":userAddress")
with pytest.raises(RuntimeError):
test.eval({"0xaDD9D957170dF6F33982001E4c22eCCdd5539118": 0})
def test_return_value_test_invalid_comparators():
with pytest.raises(ReturnValueTest.InvalidExpression):
_ = ReturnValueTest(comparator="eq", value=1)