Always raise ConditionEvalError (previously EvalError) instead of returning then raising.

Update tests.
pull/3360/head
derekpierre 2023-12-01 13:59:03 -05:00
parent c3d126f522
commit 363eb3975c
4 changed files with 55 additions and 51 deletions

View File

@ -649,14 +649,12 @@ class Operator(BaseActor):
"No conditions present for ciphertext.",
)
# evaluate the conditions for this ciphertext
error = evaluate_condition_lingo(
# evaluate the conditions for this ciphertext; raises if it fails
evaluate_condition_lingo(
condition_lingo=condition_lingo,
context=context,
providers=self.condition_providers,
)
if error:
raise error
def _verify_decryption_request_authorization(
self, decryption_request: ThresholdDecryptionRequest

View File

@ -21,7 +21,10 @@ from nucypher.crypto.keypairs import DecryptingKeypair
from nucypher.crypto.signing import InvalidSignature
from nucypher.network.nodes import NodeSprout
from nucypher.network.protocols import InterfaceInfo
from nucypher.policy.conditions.utils import EvalError, evaluate_condition_lingo
from nucypher.policy.conditions.utils import (
ConditionEvalError,
evaluate_condition_lingo,
)
from nucypher.utilities.logging import Logger
HERE = BASE_DIR = Path(__file__).parent
@ -164,7 +167,7 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
return Response("Ritual not found", status=HTTPStatus.NOT_FOUND)
except this_node.UnauthorizedRequest as e:
return Response(str(e), status=HTTPStatus.UNAUTHORIZED)
except EvalError as e:
except ConditionEvalError as e:
return Response(e.message, status=e.status_code)
except this_node.DecryptionFailure as e:
return Response(str(e), status=HTTPStatus.INTERNAL_SERVER_ERROR)
@ -245,12 +248,13 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
capsules_to_process = list()
for capsule, condition_lingo in packets:
if condition_lingo:
error = evaluate_condition_lingo(
condition_lingo=condition_lingo,
providers=this_node.condition_providers,
context=context
)
if error:
try:
evaluate_condition_lingo(
condition_lingo=condition_lingo,
providers=this_node.condition_providers,
context=context,
)
except ConditionEvalError as error:
# TODO: This response short-circuits the entire request on falsy condition
# even if other unrelated capsules (message kits) are present.
return Response(error.message, status=error.status_code)

View File

@ -1,6 +1,6 @@
import re
from http import HTTPStatus
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Set, Tuple
from marshmallow import Schema, post_dump
from web3.providers import BaseProvider
@ -21,7 +21,8 @@ from nucypher.utilities.logging import Logger
__LOGGER = Logger("condition-eval")
class EvalError(Exception):
class ConditionEvalError(Exception):
"""Exception when execution condition evaluation."""
def __init__(self, message: str, status_code: int):
self.message = message
self.status_code = status_code
@ -56,10 +57,10 @@ class CamelCaseSchema(Schema):
def evaluate_condition_lingo(
condition_lingo: Lingo,
providers: Optional[Dict[int, BaseProvider]] = None,
providers: Optional[Dict[int, Set[BaseProvider]]] = None,
context: Optional[ContextDict] = None,
log: Logger = __LOGGER,
) -> Optional[EvalError]:
):
"""
Evaluates condition lingo with the give providers and user supplied context.
If all conditions are satisfied this function returns None.
@ -83,44 +84,46 @@ def evaluate_condition_lingo(
result = lingo.eval(providers=providers, **context)
if not result:
# explicit condition failure
error = EvalError(
error = ConditionEvalError(
"Decryption conditions not satisfied", HTTPStatus.FORBIDDEN
)
except ReturnValueEvaluationError as e:
error = EvalError(
error = ConditionEvalError(
f"Unable to evaluate return value: {e}",
HTTPStatus.BAD_REQUEST,
)
except InvalidConditionLingo as e:
error = EvalError(
error = ConditionEvalError(
f"Invalid condition grammar: {e}",
HTTPStatus.BAD_REQUEST,
)
except InvalidCondition as e:
error = EvalError(
error = ConditionEvalError(
f"Incorrect value provided for condition: {e}",
HTTPStatus.BAD_REQUEST,
)
except RequiredContextVariable as e:
# TODO: be more specific and name the missing inputs, etc
error = EvalError(f"Missing required inputs: {e}", HTTPStatus.BAD_REQUEST)
error = ConditionEvalError(
f"Missing required inputs: {e}", HTTPStatus.BAD_REQUEST
)
except InvalidContextVariableData as e:
error = EvalError(
error = ConditionEvalError(
f"Invalid data provided for context variable: {e}",
HTTPStatus.BAD_REQUEST,
)
except ContextVariableVerificationFailed as e:
error = EvalError(
error = ConditionEvalError(
f"Context variable data could not be verified: {e}",
HTTPStatus.FORBIDDEN,
)
except NoConnectionToChain as e:
error = EvalError(
error = ConditionEvalError(
f"Node does not have a connection to chain ID {e.chain}",
HTTPStatus.NOT_IMPLEMENTED,
)
except ConditionEvaluationFailed as e:
error = EvalError(
error = ConditionEvalError(
f"Decryption condition not evaluated: {e}", HTTPStatus.BAD_REQUEST
)
except Exception as e:
@ -129,10 +132,9 @@ def evaluate_condition_lingo(
f"Unexpected exception while evaluating "
f"decryption condition ({e.__class__.__name__}): {e}"
)
error = EvalError(message, HTTPStatus.INTERNAL_SERVER_ERROR)
error = ConditionEvalError(message, HTTPStatus.INTERNAL_SERVER_ERROR)
log.warn(message)
if error:
log.info(error.message) # log error message
return error
raise error

View File

@ -36,6 +36,7 @@ from nucypher.policy.conditions.exceptions import (
from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.conditions.utils import (
CamelCaseSchema,
ConditionEvalError,
camel_case_to_snake,
evaluate_condition_lingo,
to_camelcase,
@ -70,22 +71,23 @@ def test_evaluate_condition_exception_cases(
) as mocked_from_dict:
mocked_from_dict.return_value = condition_lingo
eval_error = evaluate_condition_lingo(
condition_lingo=condition_lingo
) # provider and context default to empty dicts
assert eval_error
assert eval_error.status_code == expected_status_code
with pytest.raises(ConditionEvalError) as eval_error:
evaluate_condition_lingo(
condition_lingo=condition_lingo
) # provider and context default to empty dicts
assert eval_error.value.status_code == expected_status_code
def test_evaluate_condition_invalid_lingo():
eval_error = evaluate_condition_lingo(
condition_lingo={
"version": ConditionLingo.VERSION,
"condition": {"dont_mind_me": "nothing_to_see_here"},
}
) # provider and context default to empty dicts
assert "Invalid condition grammar" in eval_error.message
assert eval_error.status_code == HTTPStatus.BAD_REQUEST
with pytest.raises(ConditionEvalError) as eval_error:
evaluate_condition_lingo(
condition_lingo={
"version": ConditionLingo.VERSION,
"condition": {"dont_mind_me": "nothing_to_see_here"},
}
) # provider and context default to empty dicts
assert "Invalid condition grammar" in eval_error.value.message
assert eval_error.value.status_code == HTTPStatus.BAD_REQUEST
def test_evaluate_condition_eval_returns_false():
@ -97,13 +99,13 @@ def test_evaluate_condition_eval_returns_false():
) as mocked_from_dict:
mocked_from_dict.return_value = condition_lingo
eval_error = evaluate_condition_lingo(
condition_lingo=condition_lingo,
providers={1: Mock(spec=BaseProvider)}, # fake provider
context={"key": "value"}, # fake context
)
assert eval_error
assert eval_error.status_code == HTTPStatus.FORBIDDEN
with pytest.raises(ConditionEvalError) as eval_error:
evaluate_condition_lingo(
condition_lingo=condition_lingo,
providers={1: Mock(spec=BaseProvider)}, # fake provider
context={"key": "value"}, # fake context
)
assert eval_error.value.status_code == HTTPStatus.FORBIDDEN
def test_evaluate_condition_eval_returns_true():
@ -115,7 +117,7 @@ def test_evaluate_condition_eval_returns_true():
) as mocked_from_dict:
mocked_from_dict.return_value = condition_lingo
eval_error = evaluate_condition_lingo(
evaluate_condition_lingo(
condition_lingo=condition_lingo,
providers={
1: Mock(spec=BaseProvider),
@ -127,8 +129,6 @@ def test_evaluate_condition_eval_returns_true():
}, # multiple values in fake context
)
assert eval_error is None
@pytest.mark.parametrize(
"test_case",