From 8fde668e8dae1653686ed9b4ee468e7934000ddb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20N=C3=BA=C3=B1ez?= Date: Mon, 13 Nov 2023 16:51:55 +0100 Subject: [PATCH] Add deeper resolution of context variables Now we can recursively resolve list of variables --- nucypher/policy/conditions/context.py | 24 +++++++++++++++--------- nucypher/policy/conditions/lingo.py | 9 +++++---- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/nucypher/policy/conditions/context.py b/nucypher/policy/conditions/context.py index 5a107ece1..c742b1648 100644 --- a/nucypher/policy/conditions/context.py +++ b/nucypher/policy/conditions/context.py @@ -1,5 +1,5 @@ import re -from typing import Any, List +from typing import Any, List, Union from eth_account.account import Account from eth_account.messages import HexBytes, encode_structured_data @@ -102,15 +102,21 @@ def get_context_value(context_variable: str, **context) -> Any: return value -def resolve_any_context_variables(parameters: List[Any], return_value_test, **context): - processed_parameters = [] - for p in parameters: - # TODO needs additional support for ERC1155 which has lists of values - # context variables can only be strings, but other types of parameters can be passed - if is_context_variable(p): - p = get_context_value(context_variable=p, **context) - processed_parameters.append(p) +def _resolve_context_variable(input: Union[Any, List[Any]], **context): + if isinstance(input, list): + return [_resolve_context_variable(item, **context) for item in input] + elif is_context_variable(input): + return get_context_value(context_variable=input, **context) + else: + return input + +def resolve_any_context_variables(parameters: List[Any], return_value_test, **context): + # TODO needs additional support for ERC1155 which has lists of values + # context variables can only be strings, but other types of parameters can be passed + processed_parameters = [ + _resolve_context_variable(param, **context) for param in parameters + ] processed_return_value_test = return_value_test.with_resolved_context(**context) return processed_parameters, processed_return_value_test diff --git a/nucypher/policy/conditions/lingo.py b/nucypher/policy/conditions/lingo.py index 9b781b5e6..38079b765 100644 --- a/nucypher/policy/conditions/lingo.py +++ b/nucypher/policy/conditions/lingo.py @@ -21,7 +21,10 @@ from marshmallow.validate import OneOf, Range from packaging.version import parse as parse_version from nucypher.policy.conditions.base import AccessControlCondition, _Serializable -from nucypher.policy.conditions.context import get_context_value, is_context_variable +from nucypher.policy.conditions.context import ( + _resolve_context_variable, + is_context_variable, +) from nucypher.policy.conditions.exceptions import ( InvalidCondition, InvalidConditionLingo, @@ -316,9 +319,7 @@ class ReturnValueTest: return result def with_resolved_context(self, **context): - value = self.value - if is_context_variable(value): - value = get_context_value(context_variable=value, **context) + value = _resolve_context_variable(self.value, **context) return ReturnValueTest(self.comparator, value=value, index=self.index)