mirror of https://github.com/nucypher/nucypher.git
Add deeper resolution of context variables
Now we can recursively resolve list of variablespull/3344/head
parent
4e6af23e81
commit
8fde668e8d
|
@ -1,5 +1,5 @@
|
|||
import re
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Union
|
||||
|
||||
from eth_account.account import Account
|
||||
from eth_account.messages import HexBytes, encode_structured_data
|
||||
|
@ -102,15 +102,21 @@ def get_context_value(context_variable: str, **context) -> Any:
|
|||
return value
|
||||
|
||||
|
||||
def resolve_any_context_variables(parameters: List[Any], return_value_test, **context):
|
||||
processed_parameters = []
|
||||
for p in parameters:
|
||||
# TODO needs additional support for ERC1155 which has lists of values
|
||||
# context variables can only be strings, but other types of parameters can be passed
|
||||
if is_context_variable(p):
|
||||
p = get_context_value(context_variable=p, **context)
|
||||
processed_parameters.append(p)
|
||||
def _resolve_context_variable(input: Union[Any, List[Any]], **context):
|
||||
if isinstance(input, list):
|
||||
return [_resolve_context_variable(item, **context) for item in input]
|
||||
elif is_context_variable(input):
|
||||
return get_context_value(context_variable=input, **context)
|
||||
else:
|
||||
return input
|
||||
|
||||
|
||||
def resolve_any_context_variables(parameters: List[Any], return_value_test, **context):
|
||||
# TODO needs additional support for ERC1155 which has lists of values
|
||||
# context variables can only be strings, but other types of parameters can be passed
|
||||
processed_parameters = [
|
||||
_resolve_context_variable(param, **context) for param in parameters
|
||||
]
|
||||
processed_return_value_test = return_value_test.with_resolved_context(**context)
|
||||
|
||||
return processed_parameters, processed_return_value_test
|
||||
|
|
|
@ -21,7 +21,10 @@ from marshmallow.validate import OneOf, Range
|
|||
from packaging.version import parse as parse_version
|
||||
|
||||
from nucypher.policy.conditions.base import AccessControlCondition, _Serializable
|
||||
from nucypher.policy.conditions.context import get_context_value, is_context_variable
|
||||
from nucypher.policy.conditions.context import (
|
||||
_resolve_context_variable,
|
||||
is_context_variable,
|
||||
)
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
InvalidCondition,
|
||||
InvalidConditionLingo,
|
||||
|
@ -316,9 +319,7 @@ class ReturnValueTest:
|
|||
return result
|
||||
|
||||
def with_resolved_context(self, **context):
|
||||
value = self.value
|
||||
if is_context_variable(value):
|
||||
value = get_context_value(context_variable=value, **context)
|
||||
value = _resolve_context_variable(self.value, **context)
|
||||
return ReturnValueTest(self.comparator, value=value, index=self.index)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue