mirror of https://github.com/nucypher/nucypher.git
Localize condition lingo handling/processing to one function - `evaluate_condition_lingo`.
parent
4214a9b009
commit
45a420a9db
|
@ -26,7 +26,6 @@ from nucypher.crypto.signing import InvalidSignature
|
|||
from nucypher.network.client import ThresholdAccessControlClient
|
||||
from nucypher.network.exceptions import NodeSeemsToBeDown
|
||||
from nucypher.policy.conditions.exceptions import InvalidConditionContext
|
||||
from nucypher.policy.conditions.rust_shims import _serialize_rust_lingos
|
||||
from nucypher.policy.kits import RetrievalResult
|
||||
|
||||
|
||||
|
@ -158,9 +157,20 @@ class RetrievalWorkOrder:
|
|||
return [rk.capsule for rk in self.__retrieval_kits]
|
||||
|
||||
@property
|
||||
def lingos(self) -> Conditions:
|
||||
_lingos = [rk.conditions for rk in self.__retrieval_kits]
|
||||
rust_lingos = _serialize_rust_lingos(_lingos)
|
||||
def conditions(self) -> Conditions:
|
||||
_conditions_list = [rk.conditions for rk in self.__retrieval_kits]
|
||||
rust_conditions = self._serialize_rust_conditions(_conditions_list)
|
||||
return rust_conditions
|
||||
|
||||
@staticmethod
|
||||
def _serialize_rust_conditions(conditions_list: List[Conditions]) -> Conditions:
|
||||
lingo_lists = list()
|
||||
for condition in conditions_list:
|
||||
lingo = condition
|
||||
if condition:
|
||||
lingo = json.loads((str(condition)))
|
||||
lingo_lists.append(lingo)
|
||||
rust_lingos = Conditions(json.dumps(lingo_lists))
|
||||
return rust_lingos
|
||||
|
||||
|
||||
|
@ -280,7 +290,7 @@ class PRERetrievalClient(ThresholdAccessControlClient):
|
|||
|
||||
reencryption_request = ReencryptionRequest(
|
||||
capsules=work_order.capsules,
|
||||
conditions=work_order.lingos,
|
||||
conditions=work_order.conditions,
|
||||
context=Context(request_context_string),
|
||||
hrac=treasure_map.hrac,
|
||||
encrypted_kfrag=treasure_map.destinations[work_order.ursula_address],
|
||||
|
|
|
@ -25,8 +25,6 @@ from nucypher.crypto.signing import InvalidSignature
|
|||
from nucypher.network.exceptions import NodeSeemsToBeDown
|
||||
from nucypher.network.nodes import NodeSprout
|
||||
from nucypher.network.protocols import InterfaceInfo
|
||||
from nucypher.policy.conditions.lingo import ConditionLingo
|
||||
from nucypher.policy.conditions.rust_shims import _deserialize_rust_lingos
|
||||
from nucypher.policy.conditions.utils import evaluate_condition_lingo
|
||||
from nucypher.utilities.logging import Logger
|
||||
|
||||
|
@ -156,28 +154,27 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
|
|||
f"Threshold decryption request for ritual ID #{decryption_request.ritual_id}"
|
||||
)
|
||||
|
||||
# requester-supplied condition eval context
|
||||
context = None
|
||||
if decryption_request.context:
|
||||
context = (
|
||||
json.loads(str(decryption_request.context)) or dict()
|
||||
) # nucypher_core.Context -> str -> dict
|
||||
|
||||
# Deserialize and instantiate ConditionLingo from the request data
|
||||
conditions_data = str(decryption_request.conditions) # nucypher_core.Conditions -> str
|
||||
if not conditions_data:
|
||||
# TODO is this needed - this should never happen
|
||||
condition_lingo = json.loads(
|
||||
str(decryption_request.conditions)
|
||||
) # nucypher_core.Conditions -> str -> Lingo
|
||||
if not condition_lingo:
|
||||
# TODO is this needed - this should never happen for CBD - defeats the purpose
|
||||
return Response(
|
||||
"No conditions present for ciphertext - invalid for CBD functionality",
|
||||
status=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
|
||||
# TODO what if this fails i.e. ValidationError with the schema
|
||||
lingo = ConditionLingo.from_dict(
|
||||
json.loads(conditions_data)
|
||||
) # str -> list -> ConditionLingo
|
||||
|
||||
# requester-supplied condition eval context
|
||||
context = None
|
||||
if decryption_request.context:
|
||||
context = json.loads(str(decryption_request.context)) or dict() # nucypher_core.Context -> str -> dict
|
||||
|
||||
# evaluate the conditions for this ciphertext
|
||||
error = evaluate_condition_lingo(
|
||||
lingo=lingo,
|
||||
condition_lingo=condition_lingo,
|
||||
context=context,
|
||||
providers=this_node.condition_providers,
|
||||
)
|
||||
|
@ -230,13 +227,15 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
|
|||
reenc_request = ReencryptionRequest.from_bytes(request.data)
|
||||
|
||||
# Deserialize and instantiate ConditionLingo from the request data
|
||||
lingos = _deserialize_rust_lingos(reenc_request=reenc_request)
|
||||
condition_lingo_list = json.loads(
|
||||
str(reenc_request.conditions)
|
||||
) # Conditions -> str -> List[Lingo]
|
||||
|
||||
# requester-supplied reencryption condition context
|
||||
context = json.loads(str(reenc_request.context)) or dict()
|
||||
|
||||
# zip capsules with their respective conditions
|
||||
packets = zip(reenc_request.capsules, lingos)
|
||||
packets = zip(reenc_request.capsules, condition_lingo_list)
|
||||
|
||||
# TODO: Relocate HRAC to RE.context
|
||||
hrac = reenc_request.hrac
|
||||
|
@ -292,7 +291,7 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
|
|||
for capsule, condition_lingo in packets:
|
||||
if condition_lingo:
|
||||
error = evaluate_condition_lingo(
|
||||
lingo=condition_lingo,
|
||||
condition_lingo=condition_lingo,
|
||||
providers=this_node.condition_providers,
|
||||
context=context
|
||||
)
|
||||
|
|
|
@ -1,25 +0,0 @@
|
|||
import json
|
||||
from typing import List
|
||||
|
||||
from nucypher_core import ReencryptionRequest, Conditions
|
||||
|
||||
from nucypher.policy.conditions.lingo import ConditionLingo
|
||||
|
||||
|
||||
def _serialize_rust_lingos(lingos: List[Conditions]) -> Conditions:
|
||||
lingo_lists = list()
|
||||
for lingo in lingos:
|
||||
if lingo:
|
||||
lingo = json.loads((str(lingo)))
|
||||
lingo_lists.append(lingo)
|
||||
rust_lingos = Conditions(json.dumps(lingo_lists))
|
||||
return rust_lingos
|
||||
|
||||
|
||||
def _deserialize_rust_lingos(reenc_request: ReencryptionRequest):
|
||||
"""Shim for nucypher-core lingos"""
|
||||
json_lingos = json.loads(str(reenc_request.conditions))
|
||||
lingo = [
|
||||
ConditionLingo.from_dict(lingo) if lingo else None for lingo in json_lingos
|
||||
]
|
||||
return lingo
|
|
@ -3,7 +3,7 @@ import re
|
|||
from http import HTTPStatus
|
||||
from typing import Dict, NamedTuple, Optional, Tuple, Type, Union
|
||||
|
||||
from marshmallow import Schema, post_dump
|
||||
from marshmallow import Schema, ValidationError, post_dump
|
||||
from web3.providers import BaseProvider
|
||||
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
|
@ -104,7 +104,7 @@ def validate_condition_lingo(condition: Lingo) -> None:
|
|||
|
||||
|
||||
def evaluate_condition_lingo(
|
||||
lingo: "ConditionLingo",
|
||||
condition_lingo: Lingo,
|
||||
providers: Optional[Dict[int, BaseProvider]] = None,
|
||||
context: Optional[ContextDict] = None,
|
||||
log: Logger = __LOGGER,
|
||||
|
@ -116,6 +116,9 @@ def evaluate_condition_lingo(
|
|||
# TODO: Evaluate all conditions even if one fails and report the result
|
||||
"""
|
||||
|
||||
# prevent circular import
|
||||
from nucypher.policy.conditions.lingo import ConditionLingo
|
||||
|
||||
# Setup (don't use mutable defaults)
|
||||
context = context or dict()
|
||||
providers = providers or dict()
|
||||
|
@ -123,13 +126,23 @@ def evaluate_condition_lingo(
|
|||
|
||||
# Evaluate
|
||||
try:
|
||||
log.info(f"Evaluating access conditions {lingo}")
|
||||
result = lingo.eval(providers=providers, **context)
|
||||
if not result:
|
||||
# explicit condition failure
|
||||
error = EvalError(
|
||||
"Decryption conditions not satisfied", HTTPStatus.FORBIDDEN
|
||||
)
|
||||
if condition_lingo:
|
||||
log.info(f"Evaluating access conditions {condition_lingo}")
|
||||
lingo = ConditionLingo.from_dict(condition_lingo)
|
||||
result = lingo.eval(providers=providers, **context)
|
||||
if not result:
|
||||
# explicit condition failure
|
||||
error = EvalError(
|
||||
"Decryption conditions not satisfied", HTTPStatus.FORBIDDEN
|
||||
)
|
||||
except ValidationError as e:
|
||||
# marshmallow Validation Error
|
||||
# TODO get this to always be InvalidConditionInfo/InvalidCondition
|
||||
# so that this block can be removed
|
||||
error = EvalError(
|
||||
f"Invalid condition grammar: {e}",
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
except ReturnValueEvaluationError as e:
|
||||
error = EvalError(
|
||||
f"Unable to evaluate return value: {e}",
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import List, Optional, Tuple, Type
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from marshmallow import fields
|
||||
|
@ -57,38 +57,58 @@ def test_evaluate_condition_exception_cases(
|
|||
condition_lingo = Mock()
|
||||
condition_lingo.eval.side_effect = exception_class(*exception_constructor_params)
|
||||
|
||||
eval_error = evaluate_condition_lingo(
|
||||
lingo=condition_lingo
|
||||
) # provider and context default to empty dicts
|
||||
assert eval_error
|
||||
assert eval_error.status_code == expected_status_code
|
||||
with patch(
|
||||
"nucypher.policy.conditions.lingo.ConditionLingo.from_dict"
|
||||
) as mocked_from_dict:
|
||||
mocked_from_dict.return_value = condition_lingo
|
||||
|
||||
eval_error = evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo
|
||||
) # provider and context default to empty dicts
|
||||
assert eval_error
|
||||
assert eval_error.status_code == expected_status_code
|
||||
|
||||
|
||||
def test_evaluate_condition_eval_returns_false():
|
||||
condition_lingo = Mock()
|
||||
condition_lingo.eval.return_value = False
|
||||
eval_error = evaluate_condition_lingo(
|
||||
lingo=condition_lingo,
|
||||
providers={1: Mock(spec=BaseProvider)}, # fake provider
|
||||
context={"key": "value"}, # fake context
|
||||
)
|
||||
assert eval_error
|
||||
assert eval_error.status_code == HTTPStatus.FORBIDDEN
|
||||
|
||||
with patch(
|
||||
"nucypher.policy.conditions.lingo.ConditionLingo.from_dict"
|
||||
) as mocked_from_dict:
|
||||
mocked_from_dict.return_value = condition_lingo
|
||||
|
||||
eval_error = evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo,
|
||||
providers={1: Mock(spec=BaseProvider)}, # fake provider
|
||||
context={"key": "value"}, # fake context
|
||||
)
|
||||
assert eval_error
|
||||
assert eval_error.status_code == HTTPStatus.FORBIDDEN
|
||||
|
||||
|
||||
def test_evaluate_condition_eval_returns_true():
|
||||
condition_lingo = Mock()
|
||||
condition_lingo.eval.return_value = True
|
||||
eval_error = evaluate_condition_lingo(
|
||||
lingo=condition_lingo,
|
||||
providers={
|
||||
1: Mock(spec=BaseProvider),
|
||||
2: Mock(spec=BaseProvider),
|
||||
}, # multiple fake provider
|
||||
context={"key1": "value1", "key2": "value2"}, # multiple values in fake context
|
||||
)
|
||||
|
||||
assert eval_error is None
|
||||
with patch(
|
||||
"nucypher.policy.conditions.lingo.ConditionLingo.from_dict"
|
||||
) as mocked_from_dict:
|
||||
mocked_from_dict.return_value = condition_lingo
|
||||
|
||||
eval_error = evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo,
|
||||
providers={
|
||||
1: Mock(spec=BaseProvider),
|
||||
2: Mock(spec=BaseProvider),
|
||||
}, # multiple fake provider
|
||||
context={
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
}, # multiple values in fake context
|
||||
)
|
||||
|
||||
assert eval_error is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
Loading…
Reference in New Issue