RFC: Do not mutate user-supplied condition context.

pull/2986/head
Kieran Prasch 2022-11-09 13:19:40 +00:00
parent 19d3f16635
commit cb434e1a03
5 changed files with 22 additions and 18 deletions

View File

@ -167,10 +167,6 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
# zip capsules with their respective conditions # zip capsules with their respective conditions
packets = zip(reenc_request.capsules, lingo) packets = zip(reenc_request.capsules, lingo)
# Populate default request context for decentralized nodes
if not this_node.federated_only:
context.update({'providers': this_node.condition_providers})
# TODO: Detect if we are dealing with PRE or tDec here # TODO: Detect if we are dealing with PRE or tDec here
# TODO: This is for PRE only, relocate HRAC to RE.context # TODO: This is for PRE only, relocate HRAC to RE.context
hrac = reenc_request.hrac hrac = reenc_request.hrac
@ -217,15 +213,11 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
return Response(message, status=HTTPStatus.BAD_REQUEST) return Response(message, status=HTTPStatus.BAD_REQUEST)
# Enforce Reencryption Conditions # Enforce Reencryption Conditions
providers = this_node.condition_providers if not this_node.federated_only else dict()
capsules_to_process = list() capsules_to_process = list()
for capsule, lingo in packets: for capsule, lingo in packets:
# raises an exception or continues # raises an exception or continues
evaluate_conditions_for_ursula( evaluate_conditions_for_ursula(lingo=lingo, providers=providers, context=context)
lingo=lingo,
context=context,
log=log,
ursula=this_node
)
capsules_to_process.append((lingo, capsule)) capsules_to_process.append((lingo, capsule))
# Strip away conditions that have already been evaluated # Strip away conditions that have already been evaluated

View File

@ -1,9 +1,10 @@
import json import json
from http import HTTPStatus from http import HTTPStatus
from typing import Union, Type, Dict from typing import Union, Type, Dict, Optional
from flask import Response from flask import Response
from marshmallow import Schema, post_dump from marshmallow import Schema, post_dump
from web3.providers import BaseProvider
from nucypher.policy.conditions.base import ReencryptionCondition from nucypher.policy.conditions.base import ReencryptionCondition
from nucypher.policy.conditions.context import ( from nucypher.policy.conditions.context import (
@ -11,8 +12,10 @@ from nucypher.policy.conditions.context import (
InvalidContextVariableData, InvalidContextVariableData,
RequiredContextVariable, RequiredContextVariable,
) )
from nucypher.utilities.logging import Logger
_ETH = 'eth_' _ETH = 'eth_'
__LOGGER = Logger('condition-eval')
def to_camelcase(s): def to_camelcase(s):
@ -78,13 +81,21 @@ def _deserialize_condition_lingo(data: Union[str, Dict[str, str]]) -> Union['Ope
return instance return instance
def evaluate_conditions_for_ursula(lingo, context, log, ursula): def evaluate_conditions_for_ursula(lingo: 'ConditionLingo',
providers: Optional[Dict[str, BaseProvider]] = None,
context: Optional[Dict[Union[str, int], Union[str, int]]] = None,
log: Logger = __LOGGER,
) -> Response:
# avoid using a mutable defaults and support federated mode
context = context or dict()
providers = providers or dict()
if lingo is not None: if lingo is not None:
# TODO: Enforce policy expiration as a condition
# TODO: Evaluate all conditions even if one fails and report the result # TODO: Evaluate all conditions even if one fails and report the result
try: try:
log.info(f'Evaluating access conditions {lingo.id}') log.info(f'Evaluating access conditions {lingo.id}')
_results = lingo.eval(**context) _results = lingo.eval(providers=providers, **context)
except ReencryptionCondition.InvalidCondition as e: except ReencryptionCondition.InvalidCondition as e:
message = f"Incorrect value provided for condition: {e}" message = f"Incorrect value provided for condition: {e}"
error = (message, HTTPStatus.BAD_REQUEST) error = (message, HTTPStatus.BAD_REQUEST)

View File

@ -93,9 +93,11 @@ def _resolve_any_context_variables(
def _validate_chain(chain: int): def _validate_chain(chain: int):
if not isinstance(chain, int): if not isinstance(chain, int):
raise ValueError(f'"chain" must be a the integer of a chain ID (got "{chain}").') raise ValueError(f'"The chain" field of c a condition must be the '
f'integer of a chain ID (got "{chain}").')
if chain not in _CONDITION_CHAINS: if chain not in _CONDITION_CHAINS:
raise RPCCondition.InvalidCondition(f'chain ID {chain} is not a permitted blockchain for condition evaluation.') raise RPCCondition.InvalidCondition(f'chain ID {chain} is not a permitted '
f'blockchain for condition evaluation.')
class RPCCondition(ReencryptionCondition): class RPCCondition(ReencryptionCondition):

View File

@ -21,7 +21,7 @@ import base64
import json import json
import operator as pyoperator import operator as pyoperator
from hashlib import md5 from hashlib import md5
from typing import Any, Dict, List, Union, Iterator, Optional from typing import Any, Dict, List, Union, Iterator
from marshmallow import fields, post_load from marshmallow import fields, post_load

View File

@ -412,7 +412,6 @@ def test_single_retrieve_with_onchain_conditions(enacted_blockchain_policy, bloc
"value": "10000000000000" "value": "10000000000000"
} }
} }
] ]
messages, message_kits = _make_message_kits(enacted_blockchain_policy.public_key, conditions) messages, message_kits = _make_message_kits(enacted_blockchain_policy.public_key, conditions)
policy_info_kwargs = dict( policy_info_kwargs = dict(