mirror of https://github.com/nucypher/nucypher.git
Always raise ConditionEvalError (previously EvalError) instead of returning then raising.
Update tests.pull/3360/head
parent
c3d126f522
commit
363eb3975c
|
@ -649,14 +649,12 @@ class Operator(BaseActor):
|
|||
"No conditions present for ciphertext.",
|
||||
)
|
||||
|
||||
# evaluate the conditions for this ciphertext
|
||||
error = evaluate_condition_lingo(
|
||||
# evaluate the conditions for this ciphertext; raises if it fails
|
||||
evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo,
|
||||
context=context,
|
||||
providers=self.condition_providers,
|
||||
)
|
||||
if error:
|
||||
raise error
|
||||
|
||||
def _verify_decryption_request_authorization(
|
||||
self, decryption_request: ThresholdDecryptionRequest
|
||||
|
|
|
@ -21,7 +21,10 @@ from nucypher.crypto.keypairs import DecryptingKeypair
|
|||
from nucypher.crypto.signing import InvalidSignature
|
||||
from nucypher.network.nodes import NodeSprout
|
||||
from nucypher.network.protocols import InterfaceInfo
|
||||
from nucypher.policy.conditions.utils import EvalError, evaluate_condition_lingo
|
||||
from nucypher.policy.conditions.utils import (
|
||||
ConditionEvalError,
|
||||
evaluate_condition_lingo,
|
||||
)
|
||||
from nucypher.utilities.logging import Logger
|
||||
|
||||
HERE = BASE_DIR = Path(__file__).parent
|
||||
|
@ -164,7 +167,7 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
|
|||
return Response("Ritual not found", status=HTTPStatus.NOT_FOUND)
|
||||
except this_node.UnauthorizedRequest as e:
|
||||
return Response(str(e), status=HTTPStatus.UNAUTHORIZED)
|
||||
except EvalError as e:
|
||||
except ConditionEvalError as e:
|
||||
return Response(e.message, status=e.status_code)
|
||||
except this_node.DecryptionFailure as e:
|
||||
return Response(str(e), status=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
@ -245,12 +248,13 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
|
|||
capsules_to_process = list()
|
||||
for capsule, condition_lingo in packets:
|
||||
if condition_lingo:
|
||||
error = evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo,
|
||||
providers=this_node.condition_providers,
|
||||
context=context
|
||||
)
|
||||
if error:
|
||||
try:
|
||||
evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo,
|
||||
providers=this_node.condition_providers,
|
||||
context=context,
|
||||
)
|
||||
except ConditionEvalError as error:
|
||||
# TODO: This response short-circuits the entire request on falsy condition
|
||||
# even if other unrelated capsules (message kits) are present.
|
||||
return Response(error.message, status=error.status_code)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import re
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Dict, Optional, Set, Tuple
|
||||
|
||||
from marshmallow import Schema, post_dump
|
||||
from web3.providers import BaseProvider
|
||||
|
@ -21,7 +21,8 @@ from nucypher.utilities.logging import Logger
|
|||
__LOGGER = Logger("condition-eval")
|
||||
|
||||
|
||||
class EvalError(Exception):
|
||||
class ConditionEvalError(Exception):
|
||||
"""Exception when execution condition evaluation."""
|
||||
def __init__(self, message: str, status_code: int):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
|
@ -56,10 +57,10 @@ class CamelCaseSchema(Schema):
|
|||
|
||||
def evaluate_condition_lingo(
|
||||
condition_lingo: Lingo,
|
||||
providers: Optional[Dict[int, BaseProvider]] = None,
|
||||
providers: Optional[Dict[int, Set[BaseProvider]]] = None,
|
||||
context: Optional[ContextDict] = None,
|
||||
log: Logger = __LOGGER,
|
||||
) -> Optional[EvalError]:
|
||||
):
|
||||
"""
|
||||
Evaluates condition lingo with the give providers and user supplied context.
|
||||
If all conditions are satisfied this function returns None.
|
||||
|
@ -83,44 +84,46 @@ def evaluate_condition_lingo(
|
|||
result = lingo.eval(providers=providers, **context)
|
||||
if not result:
|
||||
# explicit condition failure
|
||||
error = EvalError(
|
||||
error = ConditionEvalError(
|
||||
"Decryption conditions not satisfied", HTTPStatus.FORBIDDEN
|
||||
)
|
||||
except ReturnValueEvaluationError as e:
|
||||
error = EvalError(
|
||||
error = ConditionEvalError(
|
||||
f"Unable to evaluate return value: {e}",
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
except InvalidConditionLingo as e:
|
||||
error = EvalError(
|
||||
error = ConditionEvalError(
|
||||
f"Invalid condition grammar: {e}",
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
except InvalidCondition as e:
|
||||
error = EvalError(
|
||||
error = ConditionEvalError(
|
||||
f"Incorrect value provided for condition: {e}",
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
except RequiredContextVariable as e:
|
||||
# TODO: be more specific and name the missing inputs, etc
|
||||
error = EvalError(f"Missing required inputs: {e}", HTTPStatus.BAD_REQUEST)
|
||||
error = ConditionEvalError(
|
||||
f"Missing required inputs: {e}", HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
except InvalidContextVariableData as e:
|
||||
error = EvalError(
|
||||
error = ConditionEvalError(
|
||||
f"Invalid data provided for context variable: {e}",
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
except ContextVariableVerificationFailed as e:
|
||||
error = EvalError(
|
||||
error = ConditionEvalError(
|
||||
f"Context variable data could not be verified: {e}",
|
||||
HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
except NoConnectionToChain as e:
|
||||
error = EvalError(
|
||||
error = ConditionEvalError(
|
||||
f"Node does not have a connection to chain ID {e.chain}",
|
||||
HTTPStatus.NOT_IMPLEMENTED,
|
||||
)
|
||||
except ConditionEvaluationFailed as e:
|
||||
error = EvalError(
|
||||
error = ConditionEvalError(
|
||||
f"Decryption condition not evaluated: {e}", HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
except Exception as e:
|
||||
|
@ -129,10 +132,9 @@ def evaluate_condition_lingo(
|
|||
f"Unexpected exception while evaluating "
|
||||
f"decryption condition ({e.__class__.__name__}): {e}"
|
||||
)
|
||||
error = EvalError(message, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
error = ConditionEvalError(message, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
log.warn(message)
|
||||
|
||||
if error:
|
||||
log.info(error.message) # log error message
|
||||
|
||||
return error
|
||||
raise error
|
||||
|
|
|
@ -36,6 +36,7 @@ from nucypher.policy.conditions.exceptions import (
|
|||
from nucypher.policy.conditions.lingo import ConditionLingo
|
||||
from nucypher.policy.conditions.utils import (
|
||||
CamelCaseSchema,
|
||||
ConditionEvalError,
|
||||
camel_case_to_snake,
|
||||
evaluate_condition_lingo,
|
||||
to_camelcase,
|
||||
|
@ -70,22 +71,23 @@ def test_evaluate_condition_exception_cases(
|
|||
) as mocked_from_dict:
|
||||
mocked_from_dict.return_value = condition_lingo
|
||||
|
||||
eval_error = evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo
|
||||
) # provider and context default to empty dicts
|
||||
assert eval_error
|
||||
assert eval_error.status_code == expected_status_code
|
||||
with pytest.raises(ConditionEvalError) as eval_error:
|
||||
evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo
|
||||
) # provider and context default to empty dicts
|
||||
assert eval_error.value.status_code == expected_status_code
|
||||
|
||||
|
||||
def test_evaluate_condition_invalid_lingo():
|
||||
eval_error = evaluate_condition_lingo(
|
||||
condition_lingo={
|
||||
"version": ConditionLingo.VERSION,
|
||||
"condition": {"dont_mind_me": "nothing_to_see_here"},
|
||||
}
|
||||
) # provider and context default to empty dicts
|
||||
assert "Invalid condition grammar" in eval_error.message
|
||||
assert eval_error.status_code == HTTPStatus.BAD_REQUEST
|
||||
with pytest.raises(ConditionEvalError) as eval_error:
|
||||
evaluate_condition_lingo(
|
||||
condition_lingo={
|
||||
"version": ConditionLingo.VERSION,
|
||||
"condition": {"dont_mind_me": "nothing_to_see_here"},
|
||||
}
|
||||
) # provider and context default to empty dicts
|
||||
assert "Invalid condition grammar" in eval_error.value.message
|
||||
assert eval_error.value.status_code == HTTPStatus.BAD_REQUEST
|
||||
|
||||
|
||||
def test_evaluate_condition_eval_returns_false():
|
||||
|
@ -97,13 +99,13 @@ def test_evaluate_condition_eval_returns_false():
|
|||
) as mocked_from_dict:
|
||||
mocked_from_dict.return_value = condition_lingo
|
||||
|
||||
eval_error = evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo,
|
||||
providers={1: Mock(spec=BaseProvider)}, # fake provider
|
||||
context={"key": "value"}, # fake context
|
||||
)
|
||||
assert eval_error
|
||||
assert eval_error.status_code == HTTPStatus.FORBIDDEN
|
||||
with pytest.raises(ConditionEvalError) as eval_error:
|
||||
evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo,
|
||||
providers={1: Mock(spec=BaseProvider)}, # fake provider
|
||||
context={"key": "value"}, # fake context
|
||||
)
|
||||
assert eval_error.value.status_code == HTTPStatus.FORBIDDEN
|
||||
|
||||
|
||||
def test_evaluate_condition_eval_returns_true():
|
||||
|
@ -115,7 +117,7 @@ def test_evaluate_condition_eval_returns_true():
|
|||
) as mocked_from_dict:
|
||||
mocked_from_dict.return_value = condition_lingo
|
||||
|
||||
eval_error = evaluate_condition_lingo(
|
||||
evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo,
|
||||
providers={
|
||||
1: Mock(spec=BaseProvider),
|
||||
|
@ -127,8 +129,6 @@ def test_evaluate_condition_eval_returns_true():
|
|||
}, # multiple values in fake context
|
||||
)
|
||||
|
||||
assert eval_error is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
|
|
Loading…
Reference in New Issue