From 40e2c5c0ea7aa94ff4e3cca4019eb4ad88fa62bb Mon Sep 17 00:00:00 2001 From: derekpierre Date: Tue, 30 Aug 2022 10:49:55 -0400 Subject: [PATCH] Initial work to have Porter provide 'context' for condition-based re-encryption when applicable. --- .../control/specifications/fields/base.py | 36 +++++++++- .../utilities/porter/control/interfaces.py | 24 +++---- .../control/specifications/porter_schema.py | 13 ++++ nucypher/utilities/porter/porter.py | 28 +++++--- .../porter/test_porter_specifications.py | 50 +++++++++++--- tests/unit/test_control.py | 52 ++++++++++++++- tests/utils/policy.py | 65 ++++++++++++++----- 7 files changed, 217 insertions(+), 51 deletions(-) diff --git a/nucypher/control/specifications/fields/base.py b/nucypher/control/specifications/fields/base.py index 83f698ffe..e17b10f10 100644 --- a/nucypher/control/specifications/fields/base.py +++ b/nucypher/control/specifications/fields/base.py @@ -14,7 +14,8 @@ You should have received a copy of the GNU Affero General Public License along with nucypher. If not, see . """ -from base64 import b64encode, b64decode +import json +from base64 import b64decode, b64encode import click from marshmallow import fields @@ -71,11 +72,40 @@ class PositiveInteger(Integer): class Base64BytesRepresentation(BaseField, fields.Field): """Serializes/Deserializes any object's byte representation to/from bae64.""" def _serialize(self, value, attr, obj, **kwargs): - value_bytes = value if isinstance(value, bytes) else bytes(value) - return b64encode(value_bytes).decode() + try: + value_bytes = value if isinstance(value, bytes) else bytes(value) + return b64encode(value_bytes).decode() + except Exception as e: + raise InvalidInputData( + f"Provided object type, {type(value)}, is not serializable: {e}" + ) def _deserialize(self, value, attr, data, **kwargs): try: return b64decode(value) except ValueError as e: raise InvalidInputData(f"Could not parse {self.name}: {e}") + + +class Base64JSON(Base64BytesRepresentation): + """Serializes/Deserializes JSON objects as base64 byte representation.""" + + def _serialize(self, value, attr, obj, **kwargs): + try: + value_json = json.dumps(value) + except Exception as e: + raise InvalidInputData( + f"Provided object type, {type(value)}, is not JSON serializable: {e}" + ) + else: + json_base64_bytes = super()._serialize( + value_json.encode(), attr, obj, **kwargs + ) + return json_base64_bytes + + def _deserialize(self, value, attr, data, **kwargs): + json_bytes = super()._deserialize(value, attr, data, **kwargs) + try: + return json.loads(json_bytes) + except Exception as e: + raise InvalidInputData(f"Invalid JSON bytes: {e}") diff --git a/nucypher/utilities/porter/control/interfaces.py b/nucypher/utilities/porter/control/interfaces.py index daeee0fcf..ae9208425 100644 --- a/nucypher/utilities/porter/control/interfaces.py +++ b/nucypher/utilities/porter/control/interfaces.py @@ -14,11 +14,10 @@ You should have received a copy of the GNU Affero General Public License along with nucypher. If not, see . """ -from typing import List, Optional +from typing import Dict, List, Optional from eth_typing import ChecksumAddress - -from nucypher_core import TreasureMap, RetrievalKit +from nucypher_core import RetrievalKit, TreasureMap from nucypher_core.umbral import PublicKey from nucypher.control.interfaces import ControlInterface, attach_schema @@ -59,12 +58,15 @@ class PorterInterface(ControlInterface): alice_verifying_key: PublicKey, bob_encrypting_key: PublicKey, bob_verifying_key: PublicKey, - ) -> dict: - retrieval_results = self.implementer.retrieve_cfrags(treasure_map=treasure_map, - retrieval_kits=retrieval_kits, - alice_verifying_key=alice_verifying_key, - bob_encrypting_key=bob_encrypting_key, - bob_verifying_key=bob_verifying_key) - results = retrieval_results # list of RetrievalResult objects - response_data = {'retrieval_results': results} + context: Optional[Dict] = None) -> dict: + retrieval_results = self.implementer.retrieve_cfrags( + treasure_map=treasure_map, + retrieval_kits=retrieval_kits, + alice_verifying_key=alice_verifying_key, + bob_encrypting_key=bob_encrypting_key, + bob_verifying_key=bob_verifying_key, + context=context, + ) + results = retrieval_results # list of RetrievalResult objects + response_data = {"retrieval_results": results} return response_data diff --git a/nucypher/utilities/porter/control/specifications/porter_schema.py b/nucypher/utilities/porter/control/specifications/porter_schema.py index 65fcfb400..6d28ebb6d 100644 --- a/nucypher/utilities/porter/control/specifications/porter_schema.py +++ b/nucypher/utilities/porter/control/specifications/porter_schema.py @@ -159,5 +159,18 @@ class BobRetrieveCFrags(BaseSchema): type=click.STRING, required=True)) + # optional + context = base_fields.Base64JSON( + required=False, + load_only=True, + click=click.option( + "--context", + "-ctx", + help="Context data for retrieval conditions", + type=click.STRING, + required=False, + ), + ) + # output retrieval_results = marshmallow_fields.List(marshmallow_fields.Nested(fields.RetrievalResultSchema), dump_only=True) diff --git a/nucypher/utilities/porter/porter.py b/nucypher/utilities/porter/porter.py index 8b1c9602b..f89b85e0b 100644 --- a/nucypher/utilities/porter/porter.py +++ b/nucypher/utilities/porter/porter.py @@ -17,18 +17,21 @@ from pathlib import Path -from typing import List, NamedTuple, Optional, Sequence +from typing import Dict, List, NamedTuple, Optional, Sequence from constant_sorrow.constants import NO_BLOCKCHAIN_CONNECTION, NO_CONTROL_PROTOCOL from eth_typing import ChecksumAddress from eth_utils import to_checksum_address -from flask import request, Response -from nucypher_core import TreasureMap, RetrievalKit +from flask import Response, request +from nucypher_core import RetrievalKit, TreasureMap from nucypher_core.umbral import PublicKey from nucypher.blockchain.eth.agents import ContractAgency, PREApplicationAgent from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory -from nucypher.blockchain.eth.registry import BaseContractRegistry, InMemoryContractRegistry +from nucypher.blockchain.eth.registry import ( + BaseContractRegistry, + InMemoryContractRegistry, +) from nucypher.characters.lawful import Ursula from nucypher.control.controllers import JSONRPCController, WebController from nucypher.crypto.powers import DecryptingPower @@ -36,9 +39,9 @@ from nucypher.network.nodes import Learner from nucypher.network.retrieval import RetrievalClient from nucypher.policy.kits import RetrievalResult from nucypher.policy.reservoir import ( - make_federated_staker_reservoir, + PrefetchStrategy, make_decentralized_staking_provider_reservoir, - PrefetchStrategy + make_federated_staker_reservoir, ) from nucypher.utilities.concurrency import WorkerPool from nucypher.utilities.logging import Logger @@ -164,10 +167,17 @@ the Pipe for PRE Application network operations alice_verifying_key: PublicKey, bob_encrypting_key: PublicKey, bob_verifying_key: PublicKey, - ) -> List[RetrievalResult]: + context: Optional[Dict] = None) -> List[RetrievalResult]: client = RetrievalClient(self) - return client.retrieve_cfrags(treasure_map, retrieval_kits, - alice_verifying_key, bob_encrypting_key, bob_verifying_key) + context = context or dict() # must not be None + return client.retrieve_cfrags( + treasure_map, + retrieval_kits, + alice_verifying_key, + bob_encrypting_key, + bob_verifying_key, + **context, + ) def _make_reservoir(self, quantity: int, diff --git a/tests/integration/porter/test_porter_specifications.py b/tests/integration/porter/test_porter_specifications.py index a39075426..d70e4b263 100644 --- a/tests/integration/porter/test_porter_specifications.py +++ b/tests/integration/porter/test_porter_specifications.py @@ -15,17 +15,23 @@ along with nucypher. If not, see . """ import random +from base64 import b64encode import pytest - from nucypher_core.umbral import SecretKey from nucypher.characters.control.specifications.fields import Key -from nucypher.control.specifications.exceptions import InvalidArgumentCombo, InvalidInputData -from nucypher.utilities.porter.control.specifications.fields import UrsulaInfoSchema, RetrievalResultSchema +from nucypher.control.specifications.exceptions import ( + InvalidArgumentCombo, + InvalidInputData, +) +from nucypher.utilities.porter.control.specifications.fields import ( + RetrievalResultSchema, + UrsulaInfoSchema, +) from nucypher.utilities.porter.control.specifications.porter_schema import ( AliceGetUrsulas, - BobRetrieveCFrags + BobRetrieveCFrags, ) from nucypher.utilities.porter.porter import Porter from tests.utils.policy import retrieval_request_setup @@ -170,13 +176,34 @@ def test_bob_retrieve_cfrags(federated_porter, with pytest.raises(InvalidInputData): bob_retrieve_cfrags_schema.load({}) - # Setup + # Setup - no context retrieval_args, _ = retrieval_request_setup(enacted_federated_policy, federated_bob, federated_alice, encode_for_rest=True) bob_retrieve_cfrags_schema.load(retrieval_args) + # simple schema load w/ optional context + context = { + "domain": {"name": "tdec", "version": 1, "chainId": 1, "salt": "blahblahblah"}, + "message": { + "address": "0x03e75d7dd38cce2e20ffee35ec914c57780a8e29", + "conditions": b64encode( + "random condition for reencryption".encode() + ).decode(), + "blockNumber": 15440685, + "blockHash": "0x2220da8b777767df526acffd5375ebb340fc98e53c1040b25ad1a8119829e3bd", + }, + } + retrieval_args, _ = retrieval_request_setup( + enacted_federated_policy, + federated_bob, + federated_alice, + encode_for_rest=True, + context=context, + ) + bob_retrieve_cfrags_schema.load(retrieval_args) + # missing required argument updated_data = dict(retrieval_args) key_to_remove = random.choice(list(updated_data.keys())) @@ -186,12 +213,15 @@ def test_bob_retrieve_cfrags(federated_porter, bob_retrieve_cfrags_schema.load(updated_data) # - # Output i.e. dump + # Actual retrieval output # - non_encoded_retrieval_args, _ = retrieval_request_setup(enacted_federated_policy, - federated_bob, - federated_alice, - encode_for_rest=False) + non_encoded_retrieval_args, _ = retrieval_request_setup( + enacted_federated_policy, + federated_bob, + federated_alice, + encode_for_rest=False, + context=context, + ) retrieval_results = federated_porter.retrieve_cfrags(**non_encoded_retrieval_args) expected_retrieval_results_json = [] retrieval_result_schema = RetrievalResultSchema() diff --git a/tests/unit/test_control.py b/tests/unit/test_control.py index 298bbbeac..e5e89f460 100644 --- a/tests/unit/test_control.py +++ b/tests/unit/test_control.py @@ -14,12 +14,19 @@ GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with nucypher. If not, see . """ +import json from base64 import b64encode import pytest from nucypher.control.specifications.exceptions import InvalidInputData -from nucypher.control.specifications.fields import PositiveInteger, StringList, String, Base64BytesRepresentation +from nucypher.control.specifications.fields import ( + Base64BytesRepresentation, + Base64JSON, + PositiveInteger, + String, + StringList, +) def test_positive_integer_field(): @@ -58,6 +65,49 @@ def test_base64_representation_field(): deserialized = field._deserialize(value=serialized, attr=None, data=None) assert deserialized == data + with pytest.raises(InvalidInputData): + # attempt to serialize a non-serializable object + field._serialize(value=Exception("non-serializable"), attr=None, obj=None) + with pytest.raises(InvalidInputData): # attempt to deserialize none base64 data field._deserialize(value=b"raw bytes with non base64 chars ?&^%", attr=None, data=None) + + +def test_base64_json_field(): + # test data + dict_data = { + "domain": {"name": "tdec", "version": 1, "chainId": 1, "salt": "blahblahblah"}, + "message": { + "address": "0x03e75d7dd38cce2e20ffee35ec914c57780a8e29", + "conditions": b64encode( + "random condition for reencryption".encode() + ).decode(), + "blockNumber": 15440685, + "blockHash": "0x2220da8b777767df526acffd5375ebb340fc98e53c1040b25ad1a8119829e3bd", + }, + } + list_data = [12.5, 1.2, 4.3] + str_data = "Everything in the universe has a rhythm, everything dances." # -- Maya Angelou + num_data = 1234567890 + bool_data = True + + # test serialization/deserialization of data + test_data = [dict_data, list_data, str_data, num_data, bool_data] + field = Base64JSON() + for d in test_data: + serialized = field._serialize(value=d, attr=None, obj=None) + assert serialized == b64encode(json.dumps(d).encode()).decode() + + deserialized = field._deserialize(value=serialized, attr=None, data=None) + assert deserialized == d + + with pytest.raises(InvalidInputData): + # attempt to serialize non-json serializable object + field._serialize(value=Exception("non-serializable"), attr=None, obj=None) + + with pytest.raises(InvalidInputData): + # attempt to deserialize invalid data + field._deserialize( + value=b"raw bytes with non base64 chars ?&^%", attr=None, data=None + ) diff --git a/tests/utils/policy.py b/tests/utils/policy.py index 2ac1f0b4b..f51d974e6 100644 --- a/tests/utils/policy.py +++ b/tests/utils/policy.py @@ -18,14 +18,17 @@ along with nucypher. If not, see . import os import random import string -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple from nucypher_core import MessageKit, RetrievalKit from nucypher.characters.control.specifications.fields import Key, TreasureMap from nucypher.characters.lawful import Enrico +from nucypher.control.specifications.fields import Base64JSON from nucypher.crypto.powers import DecryptingPower -from nucypher.utilities.porter.control.specifications.fields import RetrievalKit as RetrievalKitField +from nucypher.utilities.porter.control.specifications.fields import ( + RetrievalKit as RetrievalKitField, +) def generate_random_label() -> bytes: @@ -41,9 +44,15 @@ def generate_random_label() -> bytes: return bytes(random_label, encoding='utf-8') -def retrieval_request_setup(enacted_policy, bob, alice, original_message: bytes = None, encode_for_rest: bool = False) -> Tuple[Dict, MessageKit]: - treasure_map = bob._decrypt_treasure_map(enacted_policy.treasure_map, - enacted_policy.publisher_verifying_key) +def retrieval_request_setup(enacted_policy, + bob, + alice, + original_message: Optional[bytes] = None, + context: Optional[Dict] = None, + encode_for_rest: bool = False) -> Tuple[Dict, MessageKit]: + treasure_map = bob._decrypt_treasure_map( + enacted_policy.treasure_map, enacted_policy.publisher_verifying_key + ) # We pick up our story with Bob already having followed the treasure map above, ie: bob.start_learning_loop() @@ -56,18 +65,40 @@ def retrieval_request_setup(enacted_policy, bob, alice, original_message: bytes encode_bytes = (lambda field, obj: field()._serialize(value=obj, attr=None, obj=None)) if encode_for_rest else (lambda field, obj: obj) - return (dict(treasure_map=encode_bytes(TreasureMap, treasure_map), - retrieval_kits=[encode_bytes(RetrievalKitField, RetrievalKit.from_message_kit(message_kit))], - alice_verifying_key=encode_bytes(Key, alice.stamp.as_umbral_pubkey()), - bob_encrypting_key=encode_bytes(Key, bob.public_keys(DecryptingPower)), - bob_verifying_key=encode_bytes(Key, bob.stamp.as_umbral_pubkey())), - message_kit) + retrieval_params = dict( + treasure_map=encode_bytes(TreasureMap, treasure_map), + retrieval_kits=[ + encode_bytes(RetrievalKitField, RetrievalKit.from_message_kit(message_kit)) + ], + alice_verifying_key=encode_bytes(Key, alice.stamp.as_umbral_pubkey()), + bob_encrypting_key=encode_bytes(Key, bob.public_keys(DecryptingPower)), + bob_verifying_key=encode_bytes(Key, bob.stamp.as_umbral_pubkey()), + ) + # context is optional + if context: + retrieval_params["context"] = encode_bytes(Base64JSON, context) + + return retrieval_params, message_kit def retrieval_params_decode_from_rest(retrieval_params: Dict) -> Dict: - decode_bytes = (lambda field, data: field()._deserialize(value=data, attr=None, data=None)) - return dict(treasure_map=decode_bytes(TreasureMap, retrieval_params['treasure_map']), - retrieval_kits=[decode_bytes(RetrievalKitField, kit) for kit in retrieval_params['retrieval_kits']], - alice_verifying_key=decode_bytes(Key, retrieval_params['alice_verifying_key']), - bob_encrypting_key=decode_bytes(Key, retrieval_params['bob_encrypting_key']), - bob_verifying_key=decode_bytes(Key, retrieval_params['bob_verifying_key'])) + decode_bytes = lambda field, data: field()._deserialize( + value=data, attr=None, data=None + ) + decoded_params = dict( + treasure_map=decode_bytes(TreasureMap, retrieval_params["treasure_map"]), + retrieval_kits=[ + decode_bytes(RetrievalKitField, kit) + for kit in retrieval_params["retrieval_kits"] + ], + alice_verifying_key=decode_bytes(Key, retrieval_params["alice_verifying_key"]), + bob_encrypting_key=decode_bytes(Key, retrieval_params["bob_encrypting_key"]), + bob_verifying_key=decode_bytes(Key, retrieval_params["bob_verifying_key"]), + ) + # context is optional + if "context" in retrieval_params: + decoded_params["context"] = decode_bytes( + Base64JSON, retrieval_params["context"] + ) + + return decoded_params