Allow context variables to resolved even if they are a substring within a larger string, or a dictionary value etc.; all context variables (string, list, dict) can now resolved via the same method.

pull/3560/head
derekpierre 2024-10-25 11:15:00 -04:00
parent 4a9b391778
commit b5e35a7188
No known key found for this signature in database
5 changed files with 114 additions and 40 deletions

View File

@ -1,6 +1,6 @@
import re
from functools import partial
from typing import Any, List, Optional, Union
from typing import Any, Dict, List, Union
from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address
@ -90,12 +90,7 @@ _DIRECTIVES = {
def is_context_variable(variable) -> bool:
if isinstance(variable, str) and variable.startswith(CONTEXT_PREFIX):
if CONTEXT_REGEX.fullmatch(variable):
return True
else:
raise ValueError(f"Context variable name '{variable}' is not valid.")
return False
return isinstance(variable, str) and CONTEXT_REGEX.fullmatch(variable)
def get_context_value(context_variable: str, **context) -> Any:
@ -116,20 +111,31 @@ def get_context_value(context_variable: str, **context) -> Any:
return value
def resolve_context_variable(param: Union[Any, List[Any]], **context):
def resolve_any_context_variables(
param: Union[Any, List[Any], Dict[Any, Any]], **context
):
if isinstance(param, list):
return [resolve_context_variable(item, **context) for item in param]
elif is_context_variable(param):
return get_context_value(context_variable=param, **context)
return [resolve_any_context_variables(item, **context) for item in param]
elif isinstance(param, dict):
result = {}
for k, v in param.items():
result[k] = resolve_any_context_variables(v, **context)
return result
elif isinstance(param, str):
# either it is a context variable OR contains a context variable within it
# TODO separating the two cases for now out of concern of regex searching
# within strings (case 2)
if is_context_variable(param):
return get_context_value(context_variable=param, **context)
else:
matches = re.findall(CONTEXT_REGEX, param)
for context_var in matches:
# checking out of concern for faulty regex search within string
if context_var in context:
resolved_var = get_context_value(
context_variable=context_var, **context
)
param = param.replace(context_var, str(resolved_var))
return param
else:
return param
def resolve_parameter_context_variables(parameters: Optional[List[Any]], **context):
if not parameters:
processed_parameters = [] # produce empty list
else:
processed_parameters = [
resolve_context_variable(param, **context) for param in parameters
]
return processed_parameters

View File

@ -31,7 +31,7 @@ from nucypher.policy.conditions.base import (
)
from nucypher.policy.conditions.context import (
is_context_variable,
resolve_parameter_context_variables,
resolve_any_context_variables,
)
from nucypher.policy.conditions.exceptions import (
NoConnectionToChain,
@ -169,9 +169,11 @@ class RPCCall(ExecutionCall):
yield provider
def execute(self, providers: Dict[int, Set[HTTPProvider]], **context) -> Any:
resolved_parameters = resolve_parameter_context_variables(
self.parameters, **context
)
resolved_parameters = []
if self.parameters:
resolved_parameters = resolve_any_context_variables(
self.parameters, **context
)
endpoints = self._next_endpoint(providers=providers)
latest_error = ""

View File

@ -29,7 +29,7 @@ from nucypher.policy.conditions.base import (
)
from nucypher.policy.conditions.context import (
is_context_variable,
resolve_context_variable,
resolve_any_context_variables,
)
from nucypher.policy.conditions.exceptions import (
InvalidCondition,
@ -611,7 +611,7 @@ class ReturnValueTest:
return result
def with_resolved_context(self, **context):
value = resolve_context_variable(self.value, **context)
value = resolve_any_context_variables(self.value, **context)
return ReturnValueTest(self.comparator, value=value, index=self.index)

View File

@ -1,6 +1,5 @@
import copy
import itertools
import re
import pytest
@ -10,8 +9,7 @@ from nucypher.policy.conditions.context import (
_resolve_user_address,
get_context_value,
is_context_variable,
resolve_context_variable,
resolve_parameter_context_variables,
resolve_any_context_variables,
)
from nucypher.policy.conditions.exceptions import (
ContextVariableVerificationFailed,
@ -67,16 +65,12 @@ def test_is_context_variable():
assert not is_context_variable(variable)
for variable in INVALID_CONTEXT_PARAM_NAMES:
expected_message = re.escape(
f"Context variable name '{variable}' is not valid."
)
with pytest.raises(ValueError, match=expected_message):
_ = is_context_variable(variable)
assert not is_context_variable(variable)
def test_resolve_context_variable():
for value, resolution in VALUES_WITH_RESOLUTION:
assert resolution == resolve_context_variable(value, **CONTEXT)
assert resolution == resolve_any_context_variables(value, **CONTEXT)
def test_resolve_any_context_variables():
@ -86,7 +80,7 @@ def test_resolve_any_context_variables():
params, resolved_params = params_with_resolution
value, resolved_value = value_with_resolution
return_value_test = ReturnValueTest(comparator="==", value=value)
resolved_parameters = resolve_parameter_context_variables([params], **CONTEXT)
resolved_parameters = resolve_any_context_variables([params], **CONTEXT)
resolved_return_value = return_value_test.with_resolved_context(**CONTEXT)
assert resolved_parameters == [resolved_params]
assert resolved_return_value.comparator == return_value_test.comparator
@ -94,6 +88,78 @@ def test_resolve_any_context_variables():
assert resolved_return_value.value == resolved_value
@pytest.mark.parametrize(
"value, expected_resolution",
[
(
"https://api.github.com/user/:foo/:bar",
"https://api.github.com/user/1234/BAR",
),
(
"The cost of :bar is $:foo; $:foo is too expensive for :bar",
"The cost of BAR is $1234; $1234 is too expensive for BAR",
),
# graphql query
(
"""{
organization(login: ":bar") {
teams(first: :foo, userLogins: [":bar"]) {
totalCount
edges {
node {
id
name
description
}
}
}
}
}""",
"""{
organization(login: "BAR") {
teams(first: 1234, userLogins: ["BAR"]) {
totalCount
edges {
node {
id
name
description
}
}
}
}
}""",
),
],
)
def test_resolve_context_variable_within_substring(value, expected_resolution):
context = {":foo": 1234, ":bar": "BAR"}
resolved_value = resolve_any_context_variables(value, **context)
assert expected_resolution == resolved_value
@pytest.mark.parametrize(
"value, expected_resolution",
[
(
{
"book_name": ":bar",
"price": "$:foo",
"description": ":bar is a book about foo and bar.",
},
{
"book_name": "BAR",
"price": "$1234",
"description": "BAR is a book about foo and bar.",
},
)
],
)
def test_resolve_context_variable_within_dictionary(value, expected_resolution):
context = {":foo": 1234, ":bar": "BAR"}
resolved_value = resolve_any_context_variables(value, **context)
assert expected_resolution == resolved_value
@pytest.mark.parametrize("expected_entry", ["address", "signature", "typedData"])
@pytest.mark.parametrize(
"context_variable_name, valid_user_address_fixture",

View File

@ -6,7 +6,7 @@ from typing import NamedTuple
import pytest
from hexbytes import HexBytes
from nucypher.policy.conditions.context import resolve_context_variable
from nucypher.policy.conditions.context import resolve_any_context_variables
from nucypher.policy.conditions.exceptions import ReturnValueEvaluationError
from nucypher.policy.conditions.lingo import ReturnValueTest
@ -150,14 +150,14 @@ def test_return_value_test_with_resolved_context():
resolved = test.with_resolved_context(**context)
assert resolved.comparator == test.comparator
assert resolved.index == test.index
assert resolved.value == resolve_context_variable(test.value, **context)
assert resolved.value == resolve_any_context_variables(test.value, **context)
test = ReturnValueTest(comparator="==", value=[42, ":foo"])
resolved = test.with_resolved_context(**context)
assert resolved.comparator == test.comparator
assert resolved.index == test.index
assert resolved.value == resolve_context_variable(test.value, **context)
assert resolved.value == resolve_any_context_variables(test.value, **context)
def test_return_value_test_integer():