Resolve circular dependency with ReturnValueTest

pull/3344/head
David Núñez 2023-11-07 11:03:08 +01:00
parent b43eb85cca
commit c52a41b0a8
2 changed files with 9 additions and 12 deletions

View File

@ -10,7 +10,6 @@ from nucypher.policy.conditions.exceptions import (
InvalidContextVariableData, InvalidContextVariableData,
RequiredContextVariable, RequiredContextVariable,
) )
from nucypher.policy.conditions.lingo import ReturnValueTest
USER_ADDRESS_CONTEXT = ":userAddress" USER_ADDRESS_CONTEXT = ":userAddress"
@ -97,9 +96,7 @@ def get_context_value(context_variable: str, **context) -> Any:
return value return value
def resolve_any_context_variables( def resolve_any_context_variables(parameters: List[Any], return_value_test, **context):
parameters: List[Any], return_value_test: ReturnValueTest, **context
):
processed_parameters = [] processed_parameters = []
for p in parameters: for p in parameters:
# TODO needs additional support for ERC1155 which has lists of values # TODO needs additional support for ERC1155 which has lists of values
@ -108,12 +105,6 @@ def resolve_any_context_variables(
p = get_context_value(context_variable=p, **context) p = get_context_value(context_variable=p, **context)
processed_parameters.append(p) processed_parameters.append(p)
v = return_value_test.value processed_return_value_test = return_value_test.with_resolved_context(**context)
if is_context_variable(v):
v = get_context_value(context_variable=v, **context)
i = return_value_test.index
processed_return_value_test = ReturnValueTest(
return_value_test.comparator, value=v, index=i
)
return processed_parameters, processed_return_value_test return processed_parameters, processed_return_value_test

View File

@ -21,7 +21,7 @@ from marshmallow.validate import OneOf, Range
from packaging.version import parse as parse_version from packaging.version import parse as parse_version
from nucypher.policy.conditions.base import AccessControlCondition, _Serializable from nucypher.policy.conditions.base import AccessControlCondition, _Serializable
from nucypher.policy.conditions.context import is_context_variable from nucypher.policy.conditions.context import get_context_value, is_context_variable
from nucypher.policy.conditions.exceptions import ( from nucypher.policy.conditions.exceptions import (
InvalidCondition, InvalidCondition,
InvalidConditionLingo, InvalidConditionLingo,
@ -315,6 +315,12 @@ class ReturnValueTest:
result = _COMPARATOR_FUNCTIONS[self.comparator](left_operand, right_operand) result = _COMPARATOR_FUNCTIONS[self.comparator](left_operand, right_operand)
return result return result
def with_resolved_context(self, **context):
value = self.value
if is_context_variable(value):
value = get_context_value(context_variable=value, **context)
return ReturnValueTest(self.comparator, value=value, index=self.index)
class ConditionLingo(_Serializable): class ConditionLingo(_Serializable):
VERSION = "1.0.0" VERSION = "1.0.0"