mirror of https://github.com/nucypher/nucypher.git
Allow context variables to resolved even if they are a substring within a larger string, or a dictionary value etc.; all context variables (string, list, dict) can now resolved via the same method.
parent
4a9b391778
commit
b5e35a7188
|
@ -1,6 +1,6 @@
|
|||
import re
|
||||
from functools import partial
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from eth_typing import ChecksumAddress
|
||||
from eth_utils import to_checksum_address
|
||||
|
@ -90,12 +90,7 @@ _DIRECTIVES = {
|
|||
|
||||
|
||||
def is_context_variable(variable) -> bool:
|
||||
if isinstance(variable, str) and variable.startswith(CONTEXT_PREFIX):
|
||||
if CONTEXT_REGEX.fullmatch(variable):
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Context variable name '{variable}' is not valid.")
|
||||
return False
|
||||
return isinstance(variable, str) and CONTEXT_REGEX.fullmatch(variable)
|
||||
|
||||
|
||||
def get_context_value(context_variable: str, **context) -> Any:
|
||||
|
@ -116,20 +111,31 @@ def get_context_value(context_variable: str, **context) -> Any:
|
|||
return value
|
||||
|
||||
|
||||
def resolve_context_variable(param: Union[Any, List[Any]], **context):
|
||||
def resolve_any_context_variables(
|
||||
param: Union[Any, List[Any], Dict[Any, Any]], **context
|
||||
):
|
||||
if isinstance(param, list):
|
||||
return [resolve_context_variable(item, **context) for item in param]
|
||||
elif is_context_variable(param):
|
||||
return get_context_value(context_variable=param, **context)
|
||||
return [resolve_any_context_variables(item, **context) for item in param]
|
||||
elif isinstance(param, dict):
|
||||
result = {}
|
||||
for k, v in param.items():
|
||||
result[k] = resolve_any_context_variables(v, **context)
|
||||
return result
|
||||
elif isinstance(param, str):
|
||||
# either it is a context variable OR contains a context variable within it
|
||||
# TODO separating the two cases for now out of concern of regex searching
|
||||
# within strings (case 2)
|
||||
if is_context_variable(param):
|
||||
return get_context_value(context_variable=param, **context)
|
||||
else:
|
||||
matches = re.findall(CONTEXT_REGEX, param)
|
||||
for context_var in matches:
|
||||
# checking out of concern for faulty regex search within string
|
||||
if context_var in context:
|
||||
resolved_var = get_context_value(
|
||||
context_variable=context_var, **context
|
||||
)
|
||||
param = param.replace(context_var, str(resolved_var))
|
||||
return param
|
||||
else:
|
||||
return param
|
||||
|
||||
|
||||
def resolve_parameter_context_variables(parameters: Optional[List[Any]], **context):
|
||||
if not parameters:
|
||||
processed_parameters = [] # produce empty list
|
||||
else:
|
||||
processed_parameters = [
|
||||
resolve_context_variable(param, **context) for param in parameters
|
||||
]
|
||||
return processed_parameters
|
||||
|
|
|
@ -31,7 +31,7 @@ from nucypher.policy.conditions.base import (
|
|||
)
|
||||
from nucypher.policy.conditions.context import (
|
||||
is_context_variable,
|
||||
resolve_parameter_context_variables,
|
||||
resolve_any_context_variables,
|
||||
)
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
NoConnectionToChain,
|
||||
|
@ -169,9 +169,11 @@ class RPCCall(ExecutionCall):
|
|||
yield provider
|
||||
|
||||
def execute(self, providers: Dict[int, Set[HTTPProvider]], **context) -> Any:
|
||||
resolved_parameters = resolve_parameter_context_variables(
|
||||
self.parameters, **context
|
||||
)
|
||||
resolved_parameters = []
|
||||
if self.parameters:
|
||||
resolved_parameters = resolve_any_context_variables(
|
||||
self.parameters, **context
|
||||
)
|
||||
|
||||
endpoints = self._next_endpoint(providers=providers)
|
||||
latest_error = ""
|
||||
|
|
|
@ -29,7 +29,7 @@ from nucypher.policy.conditions.base import (
|
|||
)
|
||||
from nucypher.policy.conditions.context import (
|
||||
is_context_variable,
|
||||
resolve_context_variable,
|
||||
resolve_any_context_variables,
|
||||
)
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
InvalidCondition,
|
||||
|
@ -611,7 +611,7 @@ class ReturnValueTest:
|
|||
return result
|
||||
|
||||
def with_resolved_context(self, **context):
|
||||
value = resolve_context_variable(self.value, **context)
|
||||
value = resolve_any_context_variables(self.value, **context)
|
||||
return ReturnValueTest(self.comparator, value=value, index=self.index)
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import copy
|
||||
import itertools
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -10,8 +9,7 @@ from nucypher.policy.conditions.context import (
|
|||
_resolve_user_address,
|
||||
get_context_value,
|
||||
is_context_variable,
|
||||
resolve_context_variable,
|
||||
resolve_parameter_context_variables,
|
||||
resolve_any_context_variables,
|
||||
)
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
ContextVariableVerificationFailed,
|
||||
|
@ -67,16 +65,12 @@ def test_is_context_variable():
|
|||
assert not is_context_variable(variable)
|
||||
|
||||
for variable in INVALID_CONTEXT_PARAM_NAMES:
|
||||
expected_message = re.escape(
|
||||
f"Context variable name '{variable}' is not valid."
|
||||
)
|
||||
with pytest.raises(ValueError, match=expected_message):
|
||||
_ = is_context_variable(variable)
|
||||
assert not is_context_variable(variable)
|
||||
|
||||
|
||||
def test_resolve_context_variable():
|
||||
for value, resolution in VALUES_WITH_RESOLUTION:
|
||||
assert resolution == resolve_context_variable(value, **CONTEXT)
|
||||
assert resolution == resolve_any_context_variables(value, **CONTEXT)
|
||||
|
||||
|
||||
def test_resolve_any_context_variables():
|
||||
|
@ -86,7 +80,7 @@ def test_resolve_any_context_variables():
|
|||
params, resolved_params = params_with_resolution
|
||||
value, resolved_value = value_with_resolution
|
||||
return_value_test = ReturnValueTest(comparator="==", value=value)
|
||||
resolved_parameters = resolve_parameter_context_variables([params], **CONTEXT)
|
||||
resolved_parameters = resolve_any_context_variables([params], **CONTEXT)
|
||||
resolved_return_value = return_value_test.with_resolved_context(**CONTEXT)
|
||||
assert resolved_parameters == [resolved_params]
|
||||
assert resolved_return_value.comparator == return_value_test.comparator
|
||||
|
@ -94,6 +88,78 @@ def test_resolve_any_context_variables():
|
|||
assert resolved_return_value.value == resolved_value
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value, expected_resolution",
|
||||
[
|
||||
(
|
||||
"https://api.github.com/user/:foo/:bar",
|
||||
"https://api.github.com/user/1234/BAR",
|
||||
),
|
||||
(
|
||||
"The cost of :bar is $:foo; $:foo is too expensive for :bar",
|
||||
"The cost of BAR is $1234; $1234 is too expensive for BAR",
|
||||
),
|
||||
# graphql query
|
||||
(
|
||||
"""{
|
||||
organization(login: ":bar") {
|
||||
teams(first: :foo, userLogins: [":bar"]) {
|
||||
totalCount
|
||||
edges {
|
||||
node {
|
||||
id
|
||||
name
|
||||
description
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}""",
|
||||
"""{
|
||||
organization(login: "BAR") {
|
||||
teams(first: 1234, userLogins: ["BAR"]) {
|
||||
totalCount
|
||||
edges {
|
||||
node {
|
||||
id
|
||||
name
|
||||
description
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}""",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_resolve_context_variable_within_substring(value, expected_resolution):
|
||||
context = {":foo": 1234, ":bar": "BAR"}
|
||||
resolved_value = resolve_any_context_variables(value, **context)
|
||||
assert expected_resolution == resolved_value
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value, expected_resolution",
|
||||
[
|
||||
(
|
||||
{
|
||||
"book_name": ":bar",
|
||||
"price": "$:foo",
|
||||
"description": ":bar is a book about foo and bar.",
|
||||
},
|
||||
{
|
||||
"book_name": "BAR",
|
||||
"price": "$1234",
|
||||
"description": "BAR is a book about foo and bar.",
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_resolve_context_variable_within_dictionary(value, expected_resolution):
|
||||
context = {":foo": 1234, ":bar": "BAR"}
|
||||
resolved_value = resolve_any_context_variables(value, **context)
|
||||
assert expected_resolution == resolved_value
|
||||
|
||||
@pytest.mark.parametrize("expected_entry", ["address", "signature", "typedData"])
|
||||
@pytest.mark.parametrize(
|
||||
"context_variable_name, valid_user_address_fixture",
|
||||
|
|
|
@ -6,7 +6,7 @@ from typing import NamedTuple
|
|||
import pytest
|
||||
from hexbytes import HexBytes
|
||||
|
||||
from nucypher.policy.conditions.context import resolve_context_variable
|
||||
from nucypher.policy.conditions.context import resolve_any_context_variables
|
||||
from nucypher.policy.conditions.exceptions import ReturnValueEvaluationError
|
||||
from nucypher.policy.conditions.lingo import ReturnValueTest
|
||||
|
||||
|
@ -150,14 +150,14 @@ def test_return_value_test_with_resolved_context():
|
|||
resolved = test.with_resolved_context(**context)
|
||||
assert resolved.comparator == test.comparator
|
||||
assert resolved.index == test.index
|
||||
assert resolved.value == resolve_context_variable(test.value, **context)
|
||||
assert resolved.value == resolve_any_context_variables(test.value, **context)
|
||||
|
||||
test = ReturnValueTest(comparator="==", value=[42, ":foo"])
|
||||
|
||||
resolved = test.with_resolved_context(**context)
|
||||
assert resolved.comparator == test.comparator
|
||||
assert resolved.index == test.index
|
||||
assert resolved.value == resolve_context_variable(test.value, **context)
|
||||
assert resolved.value == resolve_any_context_variables(test.value, **context)
|
||||
|
||||
|
||||
def test_return_value_test_integer():
|
||||
|
|
Loading…
Reference in New Issue