diff --git a/tests/unit/conditions/test_context.py b/tests/unit/conditions/test_context.py index 8dab9ddf7..fef1866d4 100644 --- a/tests/unit/conditions/test_context.py +++ b/tests/unit/conditions/test_context.py @@ -1,28 +1,83 @@ import itertools +import re import pytest -from nucypher.policy.conditions.context import resolve_any_context_variables +from nucypher.policy.conditions.context import ( + _resolve_context_variable, + is_context_variable, + resolve_any_context_variables, +) from nucypher.policy.conditions.lingo import ReturnValueTest INVALID_CONTEXT_PARAM_NAMES = [ ":", ":)", ":!", + ":3", ":superñoño", ":::::this//is 🍌 🍌 🍌 ", ":123 \"$%'+-?\n jarl!! cobarde!!", ] +VALID_CONTEXT_PARAM_NAMES = [ + ":foo", + ":_bar", + ":bar_", + ":_bar_", + ":VAR", + ":a1234", + ":snake_case", + ":camelCase", + ":_", # TODO: not sure if we should allow this one, tbh +] -@pytest.mark.parametrize( - "var1,var2", itertools.product(INVALID_CONTEXT_PARAM_NAMES, repeat=2) -) -def test_invalid_context_parameter(var1, var2): - context = {var1: 42, var2: 42} - # Check that parameters make sense, what about repeated variables? - parameters = [var1, 1, 2] +DEFINITELY_NOT_CONTEXT_PARAM_NAMES = ["1234", "foo", "", 123] - with pytest.raises(ValueError): - return_value_test = ReturnValueTest(comparator="==", value=var2) - _ = resolve_any_context_variables(parameters, return_value_test, **context) +CONTEXT = {":foo": 1234, ":bar": "'BAR'"} + +VALUES_WITH_RESOLUTION = [ + (42, 42), + (True, True), + ("'bar'", "'bar'"), + ([42, True, "'bar'"], [42, True, "'bar'"]), + (":foo", 1234), + ([":foo", True, "'bar'"], [1234, True, "'bar'"]), + ([":foo", ":foo", 5, [99, [":bar"]]], [1234, 1234, 5, [99, ["'BAR'"]]]), +] + + +def test_is_context_variable(): + for variable in VALID_CONTEXT_PARAM_NAMES: + assert is_context_variable(variable) + + for variable in DEFINITELY_NOT_CONTEXT_PARAM_NAMES: + assert not is_context_variable(variable) + + for variable in INVALID_CONTEXT_PARAM_NAMES: + expected_message = re.escape( + f"Context variable name '{variable}' is not valid." + ) + with pytest.raises(ValueError, match=expected_message): + _ = is_context_variable(variable) + + +def test_resolve_context_variable(): + for value, resolution in VALUES_WITH_RESOLUTION: + assert resolution == _resolve_context_variable(value, **CONTEXT) + + +def test_resolve_any_context_variables(): + for params_with_resolution, value_with_resolution in itertools.product( + VALUES_WITH_RESOLUTION, repeat=2 + ): + params, resolved_params = params_with_resolution + value, resolved_value = value_with_resolution + return_value_test = ReturnValueTest(comparator="==", value=value) + resolved_parameters, resolved_return_value = resolve_any_context_variables( + [params], return_value_test, **CONTEXT + ) + assert resolved_parameters == [resolved_params] + assert resolved_return_value.comparator == return_value_test.comparator + assert resolved_return_value.index == return_value_test.index + assert resolved_return_value.value == resolved_value diff --git a/tests/unit/conditions/test_return_value.py b/tests/unit/conditions/test_return_value.py index b07ae3f7a..0498a9707 100644 --- a/tests/unit/conditions/test_return_value.py +++ b/tests/unit/conditions/test_return_value.py @@ -6,6 +6,7 @@ from typing import NamedTuple import pytest from hexbytes import HexBytes +from nucypher.policy.conditions.context import _resolve_context_variable from nucypher.policy.conditions.exceptions import ReturnValueEvaluationError from nucypher.policy.conditions.lingo import ReturnValueTest @@ -142,6 +143,23 @@ def test_return_value_test_with_context_variable_cant_run_eval(): test.eval(0) +def test_return_value_test_with_resolved_context(): + test = ReturnValueTest(comparator="==", value=":foo") + context = {":foo": 1234} + + resolved = test.with_resolved_context(**context) + assert resolved.comparator == test.comparator + assert resolved.index == test.index + assert resolved.value == _resolve_context_variable(test.value, **context) + + test = ReturnValueTest(comparator="==", value=[42, ":foo"]) + + resolved = test.with_resolved_context(**context) + assert resolved.comparator == test.comparator + assert resolved.index == test.index + assert resolved.value == _resolve_context_variable(test.value, **context) + + def test_return_value_test_integer(): # > test = ReturnValueTest(comparator='>', value='0')