diff --git a/nucypher/control/specifications/fields/base.py b/nucypher/control/specifications/fields/base.py
index 83f698ffe..e17b10f10 100644
--- a/nucypher/control/specifications/fields/base.py
+++ b/nucypher/control/specifications/fields/base.py
@@ -14,7 +14,8 @@
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see .
"""
-from base64 import b64encode, b64decode
+import json
+from base64 import b64decode, b64encode
import click
from marshmallow import fields
@@ -71,11 +72,40 @@ class PositiveInteger(Integer):
class Base64BytesRepresentation(BaseField, fields.Field):
"""Serializes/Deserializes any object's byte representation to/from bae64."""
def _serialize(self, value, attr, obj, **kwargs):
- value_bytes = value if isinstance(value, bytes) else bytes(value)
- return b64encode(value_bytes).decode()
+ try:
+ value_bytes = value if isinstance(value, bytes) else bytes(value)
+ return b64encode(value_bytes).decode()
+ except Exception as e:
+ raise InvalidInputData(
+ f"Provided object type, {type(value)}, is not serializable: {e}"
+ )
def _deserialize(self, value, attr, data, **kwargs):
try:
return b64decode(value)
except ValueError as e:
raise InvalidInputData(f"Could not parse {self.name}: {e}")
+
+
+class Base64JSON(Base64BytesRepresentation):
+ """Serializes/Deserializes JSON objects as base64 byte representation."""
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ try:
+ value_json = json.dumps(value)
+ except Exception as e:
+ raise InvalidInputData(
+ f"Provided object type, {type(value)}, is not JSON serializable: {e}"
+ )
+ else:
+ json_base64_bytes = super()._serialize(
+ value_json.encode(), attr, obj, **kwargs
+ )
+ return json_base64_bytes
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ json_bytes = super()._deserialize(value, attr, data, **kwargs)
+ try:
+ return json.loads(json_bytes)
+ except Exception as e:
+ raise InvalidInputData(f"Invalid JSON bytes: {e}")
diff --git a/nucypher/utilities/porter/control/interfaces.py b/nucypher/utilities/porter/control/interfaces.py
index daeee0fcf..ae9208425 100644
--- a/nucypher/utilities/porter/control/interfaces.py
+++ b/nucypher/utilities/porter/control/interfaces.py
@@ -14,11 +14,10 @@
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see .
"""
-from typing import List, Optional
+from typing import Dict, List, Optional
from eth_typing import ChecksumAddress
-
-from nucypher_core import TreasureMap, RetrievalKit
+from nucypher_core import RetrievalKit, TreasureMap
from nucypher_core.umbral import PublicKey
from nucypher.control.interfaces import ControlInterface, attach_schema
@@ -59,12 +58,15 @@ class PorterInterface(ControlInterface):
alice_verifying_key: PublicKey,
bob_encrypting_key: PublicKey,
bob_verifying_key: PublicKey,
- ) -> dict:
- retrieval_results = self.implementer.retrieve_cfrags(treasure_map=treasure_map,
- retrieval_kits=retrieval_kits,
- alice_verifying_key=alice_verifying_key,
- bob_encrypting_key=bob_encrypting_key,
- bob_verifying_key=bob_verifying_key)
- results = retrieval_results # list of RetrievalResult objects
- response_data = {'retrieval_results': results}
+ context: Optional[Dict] = None) -> dict:
+ retrieval_results = self.implementer.retrieve_cfrags(
+ treasure_map=treasure_map,
+ retrieval_kits=retrieval_kits,
+ alice_verifying_key=alice_verifying_key,
+ bob_encrypting_key=bob_encrypting_key,
+ bob_verifying_key=bob_verifying_key,
+ context=context,
+ )
+ results = retrieval_results # list of RetrievalResult objects
+ response_data = {"retrieval_results": results}
return response_data
diff --git a/nucypher/utilities/porter/control/specifications/porter_schema.py b/nucypher/utilities/porter/control/specifications/porter_schema.py
index 65fcfb400..6d28ebb6d 100644
--- a/nucypher/utilities/porter/control/specifications/porter_schema.py
+++ b/nucypher/utilities/porter/control/specifications/porter_schema.py
@@ -159,5 +159,18 @@ class BobRetrieveCFrags(BaseSchema):
type=click.STRING,
required=True))
+ # optional
+ context = base_fields.Base64JSON(
+ required=False,
+ load_only=True,
+ click=click.option(
+ "--context",
+ "-ctx",
+ help="Context data for retrieval conditions",
+ type=click.STRING,
+ required=False,
+ ),
+ )
+
# output
retrieval_results = marshmallow_fields.List(marshmallow_fields.Nested(fields.RetrievalResultSchema), dump_only=True)
diff --git a/nucypher/utilities/porter/porter.py b/nucypher/utilities/porter/porter.py
index 8b1c9602b..f89b85e0b 100644
--- a/nucypher/utilities/porter/porter.py
+++ b/nucypher/utilities/porter/porter.py
@@ -17,18 +17,21 @@
from pathlib import Path
-from typing import List, NamedTuple, Optional, Sequence
+from typing import Dict, List, NamedTuple, Optional, Sequence
from constant_sorrow.constants import NO_BLOCKCHAIN_CONNECTION, NO_CONTROL_PROTOCOL
from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address
-from flask import request, Response
-from nucypher_core import TreasureMap, RetrievalKit
+from flask import Response, request
+from nucypher_core import RetrievalKit, TreasureMap
from nucypher_core.umbral import PublicKey
from nucypher.blockchain.eth.agents import ContractAgency, PREApplicationAgent
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
-from nucypher.blockchain.eth.registry import BaseContractRegistry, InMemoryContractRegistry
+from nucypher.blockchain.eth.registry import (
+ BaseContractRegistry,
+ InMemoryContractRegistry,
+)
from nucypher.characters.lawful import Ursula
from nucypher.control.controllers import JSONRPCController, WebController
from nucypher.crypto.powers import DecryptingPower
@@ -36,9 +39,9 @@ from nucypher.network.nodes import Learner
from nucypher.network.retrieval import RetrievalClient
from nucypher.policy.kits import RetrievalResult
from nucypher.policy.reservoir import (
- make_federated_staker_reservoir,
+ PrefetchStrategy,
make_decentralized_staking_provider_reservoir,
- PrefetchStrategy
+ make_federated_staker_reservoir,
)
from nucypher.utilities.concurrency import WorkerPool
from nucypher.utilities.logging import Logger
@@ -164,10 +167,17 @@ the Pipe for PRE Application network operations
alice_verifying_key: PublicKey,
bob_encrypting_key: PublicKey,
bob_verifying_key: PublicKey,
- ) -> List[RetrievalResult]:
+ context: Optional[Dict] = None) -> List[RetrievalResult]:
client = RetrievalClient(self)
- return client.retrieve_cfrags(treasure_map, retrieval_kits,
- alice_verifying_key, bob_encrypting_key, bob_verifying_key)
+ context = context or dict() # must not be None
+ return client.retrieve_cfrags(
+ treasure_map,
+ retrieval_kits,
+ alice_verifying_key,
+ bob_encrypting_key,
+ bob_verifying_key,
+ **context,
+ )
def _make_reservoir(self,
quantity: int,
diff --git a/tests/integration/porter/test_porter_specifications.py b/tests/integration/porter/test_porter_specifications.py
index a39075426..d70e4b263 100644
--- a/tests/integration/porter/test_porter_specifications.py
+++ b/tests/integration/porter/test_porter_specifications.py
@@ -15,17 +15,23 @@
along with nucypher. If not, see .
"""
import random
+from base64 import b64encode
import pytest
-
from nucypher_core.umbral import SecretKey
from nucypher.characters.control.specifications.fields import Key
-from nucypher.control.specifications.exceptions import InvalidArgumentCombo, InvalidInputData
-from nucypher.utilities.porter.control.specifications.fields import UrsulaInfoSchema, RetrievalResultSchema
+from nucypher.control.specifications.exceptions import (
+ InvalidArgumentCombo,
+ InvalidInputData,
+)
+from nucypher.utilities.porter.control.specifications.fields import (
+ RetrievalResultSchema,
+ UrsulaInfoSchema,
+)
from nucypher.utilities.porter.control.specifications.porter_schema import (
AliceGetUrsulas,
- BobRetrieveCFrags
+ BobRetrieveCFrags,
)
from nucypher.utilities.porter.porter import Porter
from tests.utils.policy import retrieval_request_setup
@@ -170,13 +176,34 @@ def test_bob_retrieve_cfrags(federated_porter,
with pytest.raises(InvalidInputData):
bob_retrieve_cfrags_schema.load({})
- # Setup
+ # Setup - no context
retrieval_args, _ = retrieval_request_setup(enacted_federated_policy,
federated_bob,
federated_alice,
encode_for_rest=True)
bob_retrieve_cfrags_schema.load(retrieval_args)
+ # simple schema load w/ optional context
+ context = {
+ "domain": {"name": "tdec", "version": 1, "chainId": 1, "salt": "blahblahblah"},
+ "message": {
+ "address": "0x03e75d7dd38cce2e20ffee35ec914c57780a8e29",
+ "conditions": b64encode(
+ "random condition for reencryption".encode()
+ ).decode(),
+ "blockNumber": 15440685,
+ "blockHash": "0x2220da8b777767df526acffd5375ebb340fc98e53c1040b25ad1a8119829e3bd",
+ },
+ }
+ retrieval_args, _ = retrieval_request_setup(
+ enacted_federated_policy,
+ federated_bob,
+ federated_alice,
+ encode_for_rest=True,
+ context=context,
+ )
+ bob_retrieve_cfrags_schema.load(retrieval_args)
+
# missing required argument
updated_data = dict(retrieval_args)
key_to_remove = random.choice(list(updated_data.keys()))
@@ -186,12 +213,15 @@ def test_bob_retrieve_cfrags(federated_porter,
bob_retrieve_cfrags_schema.load(updated_data)
#
- # Output i.e. dump
+ # Actual retrieval output
#
- non_encoded_retrieval_args, _ = retrieval_request_setup(enacted_federated_policy,
- federated_bob,
- federated_alice,
- encode_for_rest=False)
+ non_encoded_retrieval_args, _ = retrieval_request_setup(
+ enacted_federated_policy,
+ federated_bob,
+ federated_alice,
+ encode_for_rest=False,
+ context=context,
+ )
retrieval_results = federated_porter.retrieve_cfrags(**non_encoded_retrieval_args)
expected_retrieval_results_json = []
retrieval_result_schema = RetrievalResultSchema()
diff --git a/tests/unit/test_control.py b/tests/unit/test_control.py
index 298bbbeac..e5e89f460 100644
--- a/tests/unit/test_control.py
+++ b/tests/unit/test_control.py
@@ -14,12 +14,19 @@ GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see .
"""
+import json
from base64 import b64encode
import pytest
from nucypher.control.specifications.exceptions import InvalidInputData
-from nucypher.control.specifications.fields import PositiveInteger, StringList, String, Base64BytesRepresentation
+from nucypher.control.specifications.fields import (
+ Base64BytesRepresentation,
+ Base64JSON,
+ PositiveInteger,
+ String,
+ StringList,
+)
def test_positive_integer_field():
@@ -58,6 +65,49 @@ def test_base64_representation_field():
deserialized = field._deserialize(value=serialized, attr=None, data=None)
assert deserialized == data
+ with pytest.raises(InvalidInputData):
+ # attempt to serialize a non-serializable object
+ field._serialize(value=Exception("non-serializable"), attr=None, obj=None)
+
with pytest.raises(InvalidInputData):
# attempt to deserialize none base64 data
field._deserialize(value=b"raw bytes with non base64 chars ?&^%", attr=None, data=None)
+
+
+def test_base64_json_field():
+ # test data
+ dict_data = {
+ "domain": {"name": "tdec", "version": 1, "chainId": 1, "salt": "blahblahblah"},
+ "message": {
+ "address": "0x03e75d7dd38cce2e20ffee35ec914c57780a8e29",
+ "conditions": b64encode(
+ "random condition for reencryption".encode()
+ ).decode(),
+ "blockNumber": 15440685,
+ "blockHash": "0x2220da8b777767df526acffd5375ebb340fc98e53c1040b25ad1a8119829e3bd",
+ },
+ }
+ list_data = [12.5, 1.2, 4.3]
+ str_data = "Everything in the universe has a rhythm, everything dances." # -- Maya Angelou
+ num_data = 1234567890
+ bool_data = True
+
+ # test serialization/deserialization of data
+ test_data = [dict_data, list_data, str_data, num_data, bool_data]
+ field = Base64JSON()
+ for d in test_data:
+ serialized = field._serialize(value=d, attr=None, obj=None)
+ assert serialized == b64encode(json.dumps(d).encode()).decode()
+
+ deserialized = field._deserialize(value=serialized, attr=None, data=None)
+ assert deserialized == d
+
+ with pytest.raises(InvalidInputData):
+ # attempt to serialize non-json serializable object
+ field._serialize(value=Exception("non-serializable"), attr=None, obj=None)
+
+ with pytest.raises(InvalidInputData):
+ # attempt to deserialize invalid data
+ field._deserialize(
+ value=b"raw bytes with non base64 chars ?&^%", attr=None, data=None
+ )
diff --git a/tests/utils/policy.py b/tests/utils/policy.py
index 2ac1f0b4b..f51d974e6 100644
--- a/tests/utils/policy.py
+++ b/tests/utils/policy.py
@@ -18,14 +18,17 @@ along with nucypher. If not, see .
import os
import random
import string
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple
from nucypher_core import MessageKit, RetrievalKit
from nucypher.characters.control.specifications.fields import Key, TreasureMap
from nucypher.characters.lawful import Enrico
+from nucypher.control.specifications.fields import Base64JSON
from nucypher.crypto.powers import DecryptingPower
-from nucypher.utilities.porter.control.specifications.fields import RetrievalKit as RetrievalKitField
+from nucypher.utilities.porter.control.specifications.fields import (
+ RetrievalKit as RetrievalKitField,
+)
def generate_random_label() -> bytes:
@@ -41,9 +44,15 @@ def generate_random_label() -> bytes:
return bytes(random_label, encoding='utf-8')
-def retrieval_request_setup(enacted_policy, bob, alice, original_message: bytes = None, encode_for_rest: bool = False) -> Tuple[Dict, MessageKit]:
- treasure_map = bob._decrypt_treasure_map(enacted_policy.treasure_map,
- enacted_policy.publisher_verifying_key)
+def retrieval_request_setup(enacted_policy,
+ bob,
+ alice,
+ original_message: Optional[bytes] = None,
+ context: Optional[Dict] = None,
+ encode_for_rest: bool = False) -> Tuple[Dict, MessageKit]:
+ treasure_map = bob._decrypt_treasure_map(
+ enacted_policy.treasure_map, enacted_policy.publisher_verifying_key
+ )
# We pick up our story with Bob already having followed the treasure map above, ie:
bob.start_learning_loop()
@@ -56,18 +65,40 @@ def retrieval_request_setup(enacted_policy, bob, alice, original_message: bytes
encode_bytes = (lambda field, obj: field()._serialize(value=obj, attr=None, obj=None)) if encode_for_rest else (lambda field, obj: obj)
- return (dict(treasure_map=encode_bytes(TreasureMap, treasure_map),
- retrieval_kits=[encode_bytes(RetrievalKitField, RetrievalKit.from_message_kit(message_kit))],
- alice_verifying_key=encode_bytes(Key, alice.stamp.as_umbral_pubkey()),
- bob_encrypting_key=encode_bytes(Key, bob.public_keys(DecryptingPower)),
- bob_verifying_key=encode_bytes(Key, bob.stamp.as_umbral_pubkey())),
- message_kit)
+ retrieval_params = dict(
+ treasure_map=encode_bytes(TreasureMap, treasure_map),
+ retrieval_kits=[
+ encode_bytes(RetrievalKitField, RetrievalKit.from_message_kit(message_kit))
+ ],
+ alice_verifying_key=encode_bytes(Key, alice.stamp.as_umbral_pubkey()),
+ bob_encrypting_key=encode_bytes(Key, bob.public_keys(DecryptingPower)),
+ bob_verifying_key=encode_bytes(Key, bob.stamp.as_umbral_pubkey()),
+ )
+ # context is optional
+ if context:
+ retrieval_params["context"] = encode_bytes(Base64JSON, context)
+
+ return retrieval_params, message_kit
def retrieval_params_decode_from_rest(retrieval_params: Dict) -> Dict:
- decode_bytes = (lambda field, data: field()._deserialize(value=data, attr=None, data=None))
- return dict(treasure_map=decode_bytes(TreasureMap, retrieval_params['treasure_map']),
- retrieval_kits=[decode_bytes(RetrievalKitField, kit) for kit in retrieval_params['retrieval_kits']],
- alice_verifying_key=decode_bytes(Key, retrieval_params['alice_verifying_key']),
- bob_encrypting_key=decode_bytes(Key, retrieval_params['bob_encrypting_key']),
- bob_verifying_key=decode_bytes(Key, retrieval_params['bob_verifying_key']))
+ decode_bytes = lambda field, data: field()._deserialize(
+ value=data, attr=None, data=None
+ )
+ decoded_params = dict(
+ treasure_map=decode_bytes(TreasureMap, retrieval_params["treasure_map"]),
+ retrieval_kits=[
+ decode_bytes(RetrievalKitField, kit)
+ for kit in retrieval_params["retrieval_kits"]
+ ],
+ alice_verifying_key=decode_bytes(Key, retrieval_params["alice_verifying_key"]),
+ bob_encrypting_key=decode_bytes(Key, retrieval_params["bob_encrypting_key"]),
+ bob_verifying_key=decode_bytes(Key, retrieval_params["bob_verifying_key"]),
+ )
+ # context is optional
+ if "context" in retrieval_params:
+ decoded_params["context"] = decode_bytes(
+ Base64JSON, retrieval_params["context"]
+ )
+
+ return decoded_params