Localize condition lingo handling/processing to one function - `evaluate_condition_lingo`.

pull/3140/head
derekpierre 2023-06-11 09:28:34 -04:00
parent 4214a9b009
commit 45a420a9db
5 changed files with 97 additions and 80 deletions

View File

@ -26,7 +26,6 @@ from nucypher.crypto.signing import InvalidSignature
from nucypher.network.client import ThresholdAccessControlClient from nucypher.network.client import ThresholdAccessControlClient
from nucypher.network.exceptions import NodeSeemsToBeDown from nucypher.network.exceptions import NodeSeemsToBeDown
from nucypher.policy.conditions.exceptions import InvalidConditionContext from nucypher.policy.conditions.exceptions import InvalidConditionContext
from nucypher.policy.conditions.rust_shims import _serialize_rust_lingos
from nucypher.policy.kits import RetrievalResult from nucypher.policy.kits import RetrievalResult
@ -158,9 +157,20 @@ class RetrievalWorkOrder:
return [rk.capsule for rk in self.__retrieval_kits] return [rk.capsule for rk in self.__retrieval_kits]
@property @property
def lingos(self) -> Conditions: def conditions(self) -> Conditions:
_lingos = [rk.conditions for rk in self.__retrieval_kits] _conditions_list = [rk.conditions for rk in self.__retrieval_kits]
rust_lingos = _serialize_rust_lingos(_lingos) rust_conditions = self._serialize_rust_conditions(_conditions_list)
return rust_conditions
@staticmethod
def _serialize_rust_conditions(conditions_list: List[Conditions]) -> Conditions:
lingo_lists = list()
for condition in conditions_list:
lingo = condition
if condition:
lingo = json.loads((str(condition)))
lingo_lists.append(lingo)
rust_lingos = Conditions(json.dumps(lingo_lists))
return rust_lingos return rust_lingos
@ -280,7 +290,7 @@ class PRERetrievalClient(ThresholdAccessControlClient):
reencryption_request = ReencryptionRequest( reencryption_request = ReencryptionRequest(
capsules=work_order.capsules, capsules=work_order.capsules,
conditions=work_order.lingos, conditions=work_order.conditions,
context=Context(request_context_string), context=Context(request_context_string),
hrac=treasure_map.hrac, hrac=treasure_map.hrac,
encrypted_kfrag=treasure_map.destinations[work_order.ursula_address], encrypted_kfrag=treasure_map.destinations[work_order.ursula_address],

View File

@ -25,8 +25,6 @@ from nucypher.crypto.signing import InvalidSignature
from nucypher.network.exceptions import NodeSeemsToBeDown from nucypher.network.exceptions import NodeSeemsToBeDown
from nucypher.network.nodes import NodeSprout from nucypher.network.nodes import NodeSprout
from nucypher.network.protocols import InterfaceInfo from nucypher.network.protocols import InterfaceInfo
from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.conditions.rust_shims import _deserialize_rust_lingos
from nucypher.policy.conditions.utils import evaluate_condition_lingo from nucypher.policy.conditions.utils import evaluate_condition_lingo
from nucypher.utilities.logging import Logger from nucypher.utilities.logging import Logger
@ -156,28 +154,27 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
f"Threshold decryption request for ritual ID #{decryption_request.ritual_id}" f"Threshold decryption request for ritual ID #{decryption_request.ritual_id}"
) )
# requester-supplied condition eval context
context = None
if decryption_request.context:
context = (
json.loads(str(decryption_request.context)) or dict()
) # nucypher_core.Context -> str -> dict
# Deserialize and instantiate ConditionLingo from the request data # Deserialize and instantiate ConditionLingo from the request data
conditions_data = str(decryption_request.conditions) # nucypher_core.Conditions -> str condition_lingo = json.loads(
if not conditions_data: str(decryption_request.conditions)
# TODO is this needed - this should never happen ) # nucypher_core.Conditions -> str -> Lingo
if not condition_lingo:
# TODO is this needed - this should never happen for CBD - defeats the purpose
return Response( return Response(
"No conditions present for ciphertext - invalid for CBD functionality", "No conditions present for ciphertext - invalid for CBD functionality",
status=HTTPStatus.FORBIDDEN, status=HTTPStatus.FORBIDDEN,
) )
# TODO what if this fails i.e. ValidationError with the schema
lingo = ConditionLingo.from_dict(
json.loads(conditions_data)
) # str -> list -> ConditionLingo
# requester-supplied condition eval context
context = None
if decryption_request.context:
context = json.loads(str(decryption_request.context)) or dict() # nucypher_core.Context -> str -> dict
# evaluate the conditions for this ciphertext # evaluate the conditions for this ciphertext
error = evaluate_condition_lingo( error = evaluate_condition_lingo(
lingo=lingo, condition_lingo=condition_lingo,
context=context, context=context,
providers=this_node.condition_providers, providers=this_node.condition_providers,
) )
@ -230,13 +227,15 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
reenc_request = ReencryptionRequest.from_bytes(request.data) reenc_request = ReencryptionRequest.from_bytes(request.data)
# Deserialize and instantiate ConditionLingo from the request data # Deserialize and instantiate ConditionLingo from the request data
lingos = _deserialize_rust_lingos(reenc_request=reenc_request) condition_lingo_list = json.loads(
str(reenc_request.conditions)
) # Conditions -> str -> List[Lingo]
# requester-supplied reencryption condition context # requester-supplied reencryption condition context
context = json.loads(str(reenc_request.context)) or dict() context = json.loads(str(reenc_request.context)) or dict()
# zip capsules with their respective conditions # zip capsules with their respective conditions
packets = zip(reenc_request.capsules, lingos) packets = zip(reenc_request.capsules, condition_lingo_list)
# TODO: Relocate HRAC to RE.context # TODO: Relocate HRAC to RE.context
hrac = reenc_request.hrac hrac = reenc_request.hrac
@ -292,7 +291,7 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
for capsule, condition_lingo in packets: for capsule, condition_lingo in packets:
if condition_lingo: if condition_lingo:
error = evaluate_condition_lingo( error = evaluate_condition_lingo(
lingo=condition_lingo, condition_lingo=condition_lingo,
providers=this_node.condition_providers, providers=this_node.condition_providers,
context=context context=context
) )

View File

@ -1,25 +0,0 @@
import json
from typing import List
from nucypher_core import ReencryptionRequest, Conditions
from nucypher.policy.conditions.lingo import ConditionLingo
def _serialize_rust_lingos(lingos: List[Conditions]) -> Conditions:
lingo_lists = list()
for lingo in lingos:
if lingo:
lingo = json.loads((str(lingo)))
lingo_lists.append(lingo)
rust_lingos = Conditions(json.dumps(lingo_lists))
return rust_lingos
def _deserialize_rust_lingos(reenc_request: ReencryptionRequest):
"""Shim for nucypher-core lingos"""
json_lingos = json.loads(str(reenc_request.conditions))
lingo = [
ConditionLingo.from_dict(lingo) if lingo else None for lingo in json_lingos
]
return lingo

View File

@ -3,7 +3,7 @@ import re
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, NamedTuple, Optional, Tuple, Type, Union from typing import Dict, NamedTuple, Optional, Tuple, Type, Union
from marshmallow import Schema, post_dump from marshmallow import Schema, ValidationError, post_dump
from web3.providers import BaseProvider from web3.providers import BaseProvider
from nucypher.policy.conditions.exceptions import ( from nucypher.policy.conditions.exceptions import (
@ -104,7 +104,7 @@ def validate_condition_lingo(condition: Lingo) -> None:
def evaluate_condition_lingo( def evaluate_condition_lingo(
lingo: "ConditionLingo", condition_lingo: Lingo,
providers: Optional[Dict[int, BaseProvider]] = None, providers: Optional[Dict[int, BaseProvider]] = None,
context: Optional[ContextDict] = None, context: Optional[ContextDict] = None,
log: Logger = __LOGGER, log: Logger = __LOGGER,
@ -116,6 +116,9 @@ def evaluate_condition_lingo(
# TODO: Evaluate all conditions even if one fails and report the result # TODO: Evaluate all conditions even if one fails and report the result
""" """
# prevent circular import
from nucypher.policy.conditions.lingo import ConditionLingo
# Setup (don't use mutable defaults) # Setup (don't use mutable defaults)
context = context or dict() context = context or dict()
providers = providers or dict() providers = providers or dict()
@ -123,13 +126,23 @@ def evaluate_condition_lingo(
# Evaluate # Evaluate
try: try:
log.info(f"Evaluating access conditions {lingo}") if condition_lingo:
log.info(f"Evaluating access conditions {condition_lingo}")
lingo = ConditionLingo.from_dict(condition_lingo)
result = lingo.eval(providers=providers, **context) result = lingo.eval(providers=providers, **context)
if not result: if not result:
# explicit condition failure # explicit condition failure
error = EvalError( error = EvalError(
"Decryption conditions not satisfied", HTTPStatus.FORBIDDEN "Decryption conditions not satisfied", HTTPStatus.FORBIDDEN
) )
except ValidationError as e:
# marshmallow Validation Error
# TODO get this to always be InvalidConditionInfo/InvalidCondition
# so that this block can be removed
error = EvalError(
f"Invalid condition grammar: {e}",
HTTPStatus.BAD_REQUEST,
)
except ReturnValueEvaluationError as e: except ReturnValueEvaluationError as e:
error = EvalError( error = EvalError(
f"Unable to evaluate return value: {e}", f"Unable to evaluate return value: {e}",

View File

@ -17,7 +17,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
from typing import List, Optional, Tuple, Type from typing import List, Optional, Tuple, Type
from unittest.mock import Mock from unittest.mock import Mock, patch
import pytest import pytest
from marshmallow import fields from marshmallow import fields
@ -57,8 +57,13 @@ def test_evaluate_condition_exception_cases(
condition_lingo = Mock() condition_lingo = Mock()
condition_lingo.eval.side_effect = exception_class(*exception_constructor_params) condition_lingo.eval.side_effect = exception_class(*exception_constructor_params)
with patch(
"nucypher.policy.conditions.lingo.ConditionLingo.from_dict"
) as mocked_from_dict:
mocked_from_dict.return_value = condition_lingo
eval_error = evaluate_condition_lingo( eval_error = evaluate_condition_lingo(
lingo=condition_lingo condition_lingo=condition_lingo
) # provider and context default to empty dicts ) # provider and context default to empty dicts
assert eval_error assert eval_error
assert eval_error.status_code == expected_status_code assert eval_error.status_code == expected_status_code
@ -67,8 +72,14 @@ def test_evaluate_condition_exception_cases(
def test_evaluate_condition_eval_returns_false(): def test_evaluate_condition_eval_returns_false():
condition_lingo = Mock() condition_lingo = Mock()
condition_lingo.eval.return_value = False condition_lingo.eval.return_value = False
with patch(
"nucypher.policy.conditions.lingo.ConditionLingo.from_dict"
) as mocked_from_dict:
mocked_from_dict.return_value = condition_lingo
eval_error = evaluate_condition_lingo( eval_error = evaluate_condition_lingo(
lingo=condition_lingo, condition_lingo=condition_lingo,
providers={1: Mock(spec=BaseProvider)}, # fake provider providers={1: Mock(spec=BaseProvider)}, # fake provider
context={"key": "value"}, # fake context context={"key": "value"}, # fake context
) )
@ -79,13 +90,22 @@ def test_evaluate_condition_eval_returns_false():
def test_evaluate_condition_eval_returns_true(): def test_evaluate_condition_eval_returns_true():
condition_lingo = Mock() condition_lingo = Mock()
condition_lingo.eval.return_value = True condition_lingo.eval.return_value = True
with patch(
"nucypher.policy.conditions.lingo.ConditionLingo.from_dict"
) as mocked_from_dict:
mocked_from_dict.return_value = condition_lingo
eval_error = evaluate_condition_lingo( eval_error = evaluate_condition_lingo(
lingo=condition_lingo, condition_lingo=condition_lingo,
providers={ providers={
1: Mock(spec=BaseProvider), 1: Mock(spec=BaseProvider),
2: Mock(spec=BaseProvider), 2: Mock(spec=BaseProvider),
}, # multiple fake provider }, # multiple fake provider
context={"key1": "value1", "key2": "value2"}, # multiple values in fake context context={
"key1": "value1",
"key2": "value2",
}, # multiple values in fake context
) )
assert eval_error is None assert eval_error is None