Localize condition lingo handling/processing to one function - `evaluate_condition_lingo`.

pull/3140/head
derekpierre 2023-06-11 09:28:34 -04:00
parent 4214a9b009
commit 45a420a9db
5 changed files with 97 additions and 80 deletions

View File

@ -26,7 +26,6 @@ from nucypher.crypto.signing import InvalidSignature
from nucypher.network.client import ThresholdAccessControlClient
from nucypher.network.exceptions import NodeSeemsToBeDown
from nucypher.policy.conditions.exceptions import InvalidConditionContext
from nucypher.policy.conditions.rust_shims import _serialize_rust_lingos
from nucypher.policy.kits import RetrievalResult
@ -158,9 +157,20 @@ class RetrievalWorkOrder:
return [rk.capsule for rk in self.__retrieval_kits]
@property
def lingos(self) -> Conditions:
_lingos = [rk.conditions for rk in self.__retrieval_kits]
rust_lingos = _serialize_rust_lingos(_lingos)
def conditions(self) -> Conditions:
_conditions_list = [rk.conditions for rk in self.__retrieval_kits]
rust_conditions = self._serialize_rust_conditions(_conditions_list)
return rust_conditions
@staticmethod
def _serialize_rust_conditions(conditions_list: List[Conditions]) -> Conditions:
lingo_lists = list()
for condition in conditions_list:
lingo = condition
if condition:
lingo = json.loads((str(condition)))
lingo_lists.append(lingo)
rust_lingos = Conditions(json.dumps(lingo_lists))
return rust_lingos
@ -280,7 +290,7 @@ class PRERetrievalClient(ThresholdAccessControlClient):
reencryption_request = ReencryptionRequest(
capsules=work_order.capsules,
conditions=work_order.lingos,
conditions=work_order.conditions,
context=Context(request_context_string),
hrac=treasure_map.hrac,
encrypted_kfrag=treasure_map.destinations[work_order.ursula_address],

View File

@ -25,8 +25,6 @@ from nucypher.crypto.signing import InvalidSignature
from nucypher.network.exceptions import NodeSeemsToBeDown
from nucypher.network.nodes import NodeSprout
from nucypher.network.protocols import InterfaceInfo
from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.conditions.rust_shims import _deserialize_rust_lingos
from nucypher.policy.conditions.utils import evaluate_condition_lingo
from nucypher.utilities.logging import Logger
@ -156,28 +154,27 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
f"Threshold decryption request for ritual ID #{decryption_request.ritual_id}"
)
# requester-supplied condition eval context
context = None
if decryption_request.context:
context = (
json.loads(str(decryption_request.context)) or dict()
) # nucypher_core.Context -> str -> dict
# Deserialize and instantiate ConditionLingo from the request data
conditions_data = str(decryption_request.conditions) # nucypher_core.Conditions -> str
if not conditions_data:
# TODO is this needed - this should never happen
condition_lingo = json.loads(
str(decryption_request.conditions)
) # nucypher_core.Conditions -> str -> Lingo
if not condition_lingo:
# TODO is this needed - this should never happen for CBD - defeats the purpose
return Response(
"No conditions present for ciphertext - invalid for CBD functionality",
status=HTTPStatus.FORBIDDEN,
)
# TODO what if this fails i.e. ValidationError with the schema
lingo = ConditionLingo.from_dict(
json.loads(conditions_data)
) # str -> list -> ConditionLingo
# requester-supplied condition eval context
context = None
if decryption_request.context:
context = json.loads(str(decryption_request.context)) or dict() # nucypher_core.Context -> str -> dict
# evaluate the conditions for this ciphertext
error = evaluate_condition_lingo(
lingo=lingo,
condition_lingo=condition_lingo,
context=context,
providers=this_node.condition_providers,
)
@ -230,13 +227,15 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
reenc_request = ReencryptionRequest.from_bytes(request.data)
# Deserialize and instantiate ConditionLingo from the request data
lingos = _deserialize_rust_lingos(reenc_request=reenc_request)
condition_lingo_list = json.loads(
str(reenc_request.conditions)
) # Conditions -> str -> List[Lingo]
# requester-supplied reencryption condition context
context = json.loads(str(reenc_request.context)) or dict()
# zip capsules with their respective conditions
packets = zip(reenc_request.capsules, lingos)
packets = zip(reenc_request.capsules, condition_lingo_list)
# TODO: Relocate HRAC to RE.context
hrac = reenc_request.hrac
@ -292,7 +291,7 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
for capsule, condition_lingo in packets:
if condition_lingo:
error = evaluate_condition_lingo(
lingo=condition_lingo,
condition_lingo=condition_lingo,
providers=this_node.condition_providers,
context=context
)

View File

@ -1,25 +0,0 @@
import json
from typing import List
from nucypher_core import ReencryptionRequest, Conditions
from nucypher.policy.conditions.lingo import ConditionLingo
def _serialize_rust_lingos(lingos: List[Conditions]) -> Conditions:
lingo_lists = list()
for lingo in lingos:
if lingo:
lingo = json.loads((str(lingo)))
lingo_lists.append(lingo)
rust_lingos = Conditions(json.dumps(lingo_lists))
return rust_lingos
def _deserialize_rust_lingos(reenc_request: ReencryptionRequest):
"""Shim for nucypher-core lingos"""
json_lingos = json.loads(str(reenc_request.conditions))
lingo = [
ConditionLingo.from_dict(lingo) if lingo else None for lingo in json_lingos
]
return lingo

View File

@ -3,7 +3,7 @@ import re
from http import HTTPStatus
from typing import Dict, NamedTuple, Optional, Tuple, Type, Union
from marshmallow import Schema, post_dump
from marshmallow import Schema, ValidationError, post_dump
from web3.providers import BaseProvider
from nucypher.policy.conditions.exceptions import (
@ -104,7 +104,7 @@ def validate_condition_lingo(condition: Lingo) -> None:
def evaluate_condition_lingo(
lingo: "ConditionLingo",
condition_lingo: Lingo,
providers: Optional[Dict[int, BaseProvider]] = None,
context: Optional[ContextDict] = None,
log: Logger = __LOGGER,
@ -116,6 +116,9 @@ def evaluate_condition_lingo(
# TODO: Evaluate all conditions even if one fails and report the result
"""
# prevent circular import
from nucypher.policy.conditions.lingo import ConditionLingo
# Setup (don't use mutable defaults)
context = context or dict()
providers = providers or dict()
@ -123,13 +126,23 @@ def evaluate_condition_lingo(
# Evaluate
try:
log.info(f"Evaluating access conditions {lingo}")
result = lingo.eval(providers=providers, **context)
if not result:
# explicit condition failure
error = EvalError(
"Decryption conditions not satisfied", HTTPStatus.FORBIDDEN
)
if condition_lingo:
log.info(f"Evaluating access conditions {condition_lingo}")
lingo = ConditionLingo.from_dict(condition_lingo)
result = lingo.eval(providers=providers, **context)
if not result:
# explicit condition failure
error = EvalError(
"Decryption conditions not satisfied", HTTPStatus.FORBIDDEN
)
except ValidationError as e:
# marshmallow Validation Error
# TODO get this to always be InvalidConditionInfo/InvalidCondition
# so that this block can be removed
error = EvalError(
f"Invalid condition grammar: {e}",
HTTPStatus.BAD_REQUEST,
)
except ReturnValueEvaluationError as e:
error = EvalError(
f"Unable to evaluate return value: {e}",

View File

@ -17,7 +17,7 @@
from dataclasses import dataclass
from http import HTTPStatus
from typing import List, Optional, Tuple, Type
from unittest.mock import Mock
from unittest.mock import Mock, patch
import pytest
from marshmallow import fields
@ -57,38 +57,58 @@ def test_evaluate_condition_exception_cases(
condition_lingo = Mock()
condition_lingo.eval.side_effect = exception_class(*exception_constructor_params)
eval_error = evaluate_condition_lingo(
lingo=condition_lingo
) # provider and context default to empty dicts
assert eval_error
assert eval_error.status_code == expected_status_code
with patch(
"nucypher.policy.conditions.lingo.ConditionLingo.from_dict"
) as mocked_from_dict:
mocked_from_dict.return_value = condition_lingo
eval_error = evaluate_condition_lingo(
condition_lingo=condition_lingo
) # provider and context default to empty dicts
assert eval_error
assert eval_error.status_code == expected_status_code
def test_evaluate_condition_eval_returns_false():
condition_lingo = Mock()
condition_lingo.eval.return_value = False
eval_error = evaluate_condition_lingo(
lingo=condition_lingo,
providers={1: Mock(spec=BaseProvider)}, # fake provider
context={"key": "value"}, # fake context
)
assert eval_error
assert eval_error.status_code == HTTPStatus.FORBIDDEN
with patch(
"nucypher.policy.conditions.lingo.ConditionLingo.from_dict"
) as mocked_from_dict:
mocked_from_dict.return_value = condition_lingo
eval_error = evaluate_condition_lingo(
condition_lingo=condition_lingo,
providers={1: Mock(spec=BaseProvider)}, # fake provider
context={"key": "value"}, # fake context
)
assert eval_error
assert eval_error.status_code == HTTPStatus.FORBIDDEN
def test_evaluate_condition_eval_returns_true():
condition_lingo = Mock()
condition_lingo.eval.return_value = True
eval_error = evaluate_condition_lingo(
lingo=condition_lingo,
providers={
1: Mock(spec=BaseProvider),
2: Mock(spec=BaseProvider),
}, # multiple fake provider
context={"key1": "value1", "key2": "value2"}, # multiple values in fake context
)
assert eval_error is None
with patch(
"nucypher.policy.conditions.lingo.ConditionLingo.from_dict"
) as mocked_from_dict:
mocked_from_dict.return_value = condition_lingo
eval_error = evaluate_condition_lingo(
condition_lingo=condition_lingo,
providers={
1: Mock(spec=BaseProvider),
2: Mock(spec=BaseProvider),
}, # multiple fake provider
context={
"key1": "value1",
"key2": "value2",
}, # multiple values in fake context
)
assert eval_error is None
@pytest.mark.parametrize(