Initial work to have Porter provide 'context' for condition-based re-encryption when applicable.

pull/2960/head
derekpierre 2022-08-30 10:49:55 -04:00 committed by Kieran Prasch
parent cd11a414bc
commit 40e2c5c0ea
7 changed files with 217 additions and 51 deletions

View File

@ -14,7 +14,8 @@
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
from base64 import b64encode, b64decode
import json
from base64 import b64decode, b64encode
import click
from marshmallow import fields
@ -71,11 +72,40 @@ class PositiveInteger(Integer):
class Base64BytesRepresentation(BaseField, fields.Field):
"""Serializes/Deserializes any object's byte representation to/from bae64."""
def _serialize(self, value, attr, obj, **kwargs):
value_bytes = value if isinstance(value, bytes) else bytes(value)
return b64encode(value_bytes).decode()
try:
value_bytes = value if isinstance(value, bytes) else bytes(value)
return b64encode(value_bytes).decode()
except Exception as e:
raise InvalidInputData(
f"Provided object type, {type(value)}, is not serializable: {e}"
)
def _deserialize(self, value, attr, data, **kwargs):
try:
return b64decode(value)
except ValueError as e:
raise InvalidInputData(f"Could not parse {self.name}: {e}")
class Base64JSON(Base64BytesRepresentation):
"""Serializes/Deserializes JSON objects as base64 byte representation."""
def _serialize(self, value, attr, obj, **kwargs):
try:
value_json = json.dumps(value)
except Exception as e:
raise InvalidInputData(
f"Provided object type, {type(value)}, is not JSON serializable: {e}"
)
else:
json_base64_bytes = super()._serialize(
value_json.encode(), attr, obj, **kwargs
)
return json_base64_bytes
def _deserialize(self, value, attr, data, **kwargs):
json_bytes = super()._deserialize(value, attr, data, **kwargs)
try:
return json.loads(json_bytes)
except Exception as e:
raise InvalidInputData(f"Invalid JSON bytes: {e}")

View File

@ -14,11 +14,10 @@
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
from typing import List, Optional
from typing import Dict, List, Optional
from eth_typing import ChecksumAddress
from nucypher_core import TreasureMap, RetrievalKit
from nucypher_core import RetrievalKit, TreasureMap
from nucypher_core.umbral import PublicKey
from nucypher.control.interfaces import ControlInterface, attach_schema
@ -59,12 +58,15 @@ class PorterInterface(ControlInterface):
alice_verifying_key: PublicKey,
bob_encrypting_key: PublicKey,
bob_verifying_key: PublicKey,
) -> dict:
retrieval_results = self.implementer.retrieve_cfrags(treasure_map=treasure_map,
retrieval_kits=retrieval_kits,
alice_verifying_key=alice_verifying_key,
bob_encrypting_key=bob_encrypting_key,
bob_verifying_key=bob_verifying_key)
results = retrieval_results # list of RetrievalResult objects
response_data = {'retrieval_results': results}
context: Optional[Dict] = None) -> dict:
retrieval_results = self.implementer.retrieve_cfrags(
treasure_map=treasure_map,
retrieval_kits=retrieval_kits,
alice_verifying_key=alice_verifying_key,
bob_encrypting_key=bob_encrypting_key,
bob_verifying_key=bob_verifying_key,
context=context,
)
results = retrieval_results # list of RetrievalResult objects
response_data = {"retrieval_results": results}
return response_data

View File

@ -159,5 +159,18 @@ class BobRetrieveCFrags(BaseSchema):
type=click.STRING,
required=True))
# optional
context = base_fields.Base64JSON(
required=False,
load_only=True,
click=click.option(
"--context",
"-ctx",
help="Context data for retrieval conditions",
type=click.STRING,
required=False,
),
)
# output
retrieval_results = marshmallow_fields.List(marshmallow_fields.Nested(fields.RetrievalResultSchema), dump_only=True)

View File

@ -17,18 +17,21 @@
from pathlib import Path
from typing import List, NamedTuple, Optional, Sequence
from typing import Dict, List, NamedTuple, Optional, Sequence
from constant_sorrow.constants import NO_BLOCKCHAIN_CONNECTION, NO_CONTROL_PROTOCOL
from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address
from flask import request, Response
from nucypher_core import TreasureMap, RetrievalKit
from flask import Response, request
from nucypher_core import RetrievalKit, TreasureMap
from nucypher_core.umbral import PublicKey
from nucypher.blockchain.eth.agents import ContractAgency, PREApplicationAgent
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
from nucypher.blockchain.eth.registry import BaseContractRegistry, InMemoryContractRegistry
from nucypher.blockchain.eth.registry import (
BaseContractRegistry,
InMemoryContractRegistry,
)
from nucypher.characters.lawful import Ursula
from nucypher.control.controllers import JSONRPCController, WebController
from nucypher.crypto.powers import DecryptingPower
@ -36,9 +39,9 @@ from nucypher.network.nodes import Learner
from nucypher.network.retrieval import RetrievalClient
from nucypher.policy.kits import RetrievalResult
from nucypher.policy.reservoir import (
make_federated_staker_reservoir,
PrefetchStrategy,
make_decentralized_staking_provider_reservoir,
PrefetchStrategy
make_federated_staker_reservoir,
)
from nucypher.utilities.concurrency import WorkerPool
from nucypher.utilities.logging import Logger
@ -164,10 +167,17 @@ the Pipe for PRE Application network operations
alice_verifying_key: PublicKey,
bob_encrypting_key: PublicKey,
bob_verifying_key: PublicKey,
) -> List[RetrievalResult]:
context: Optional[Dict] = None) -> List[RetrievalResult]:
client = RetrievalClient(self)
return client.retrieve_cfrags(treasure_map, retrieval_kits,
alice_verifying_key, bob_encrypting_key, bob_verifying_key)
context = context or dict() # must not be None
return client.retrieve_cfrags(
treasure_map,
retrieval_kits,
alice_verifying_key,
bob_encrypting_key,
bob_verifying_key,
**context,
)
def _make_reservoir(self,
quantity: int,

View File

@ -15,17 +15,23 @@
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
import random
from base64 import b64encode
import pytest
from nucypher_core.umbral import SecretKey
from nucypher.characters.control.specifications.fields import Key
from nucypher.control.specifications.exceptions import InvalidArgumentCombo, InvalidInputData
from nucypher.utilities.porter.control.specifications.fields import UrsulaInfoSchema, RetrievalResultSchema
from nucypher.control.specifications.exceptions import (
InvalidArgumentCombo,
InvalidInputData,
)
from nucypher.utilities.porter.control.specifications.fields import (
RetrievalResultSchema,
UrsulaInfoSchema,
)
from nucypher.utilities.porter.control.specifications.porter_schema import (
AliceGetUrsulas,
BobRetrieveCFrags
BobRetrieveCFrags,
)
from nucypher.utilities.porter.porter import Porter
from tests.utils.policy import retrieval_request_setup
@ -170,13 +176,34 @@ def test_bob_retrieve_cfrags(federated_porter,
with pytest.raises(InvalidInputData):
bob_retrieve_cfrags_schema.load({})
# Setup
# Setup - no context
retrieval_args, _ = retrieval_request_setup(enacted_federated_policy,
federated_bob,
federated_alice,
encode_for_rest=True)
bob_retrieve_cfrags_schema.load(retrieval_args)
# simple schema load w/ optional context
context = {
"domain": {"name": "tdec", "version": 1, "chainId": 1, "salt": "blahblahblah"},
"message": {
"address": "0x03e75d7dd38cce2e20ffee35ec914c57780a8e29",
"conditions": b64encode(
"random condition for reencryption".encode()
).decode(),
"blockNumber": 15440685,
"blockHash": "0x2220da8b777767df526acffd5375ebb340fc98e53c1040b25ad1a8119829e3bd",
},
}
retrieval_args, _ = retrieval_request_setup(
enacted_federated_policy,
federated_bob,
federated_alice,
encode_for_rest=True,
context=context,
)
bob_retrieve_cfrags_schema.load(retrieval_args)
# missing required argument
updated_data = dict(retrieval_args)
key_to_remove = random.choice(list(updated_data.keys()))
@ -186,12 +213,15 @@ def test_bob_retrieve_cfrags(federated_porter,
bob_retrieve_cfrags_schema.load(updated_data)
#
# Output i.e. dump
# Actual retrieval output
#
non_encoded_retrieval_args, _ = retrieval_request_setup(enacted_federated_policy,
federated_bob,
federated_alice,
encode_for_rest=False)
non_encoded_retrieval_args, _ = retrieval_request_setup(
enacted_federated_policy,
federated_bob,
federated_alice,
encode_for_rest=False,
context=context,
)
retrieval_results = federated_porter.retrieve_cfrags(**non_encoded_retrieval_args)
expected_retrieval_results_json = []
retrieval_result_schema = RetrievalResultSchema()

View File

@ -14,12 +14,19 @@ GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
import json
from base64 import b64encode
import pytest
from nucypher.control.specifications.exceptions import InvalidInputData
from nucypher.control.specifications.fields import PositiveInteger, StringList, String, Base64BytesRepresentation
from nucypher.control.specifications.fields import (
Base64BytesRepresentation,
Base64JSON,
PositiveInteger,
String,
StringList,
)
def test_positive_integer_field():
@ -58,6 +65,49 @@ def test_base64_representation_field():
deserialized = field._deserialize(value=serialized, attr=None, data=None)
assert deserialized == data
with pytest.raises(InvalidInputData):
# attempt to serialize a non-serializable object
field._serialize(value=Exception("non-serializable"), attr=None, obj=None)
with pytest.raises(InvalidInputData):
# attempt to deserialize none base64 data
field._deserialize(value=b"raw bytes with non base64 chars ?&^%", attr=None, data=None)
def test_base64_json_field():
# test data
dict_data = {
"domain": {"name": "tdec", "version": 1, "chainId": 1, "salt": "blahblahblah"},
"message": {
"address": "0x03e75d7dd38cce2e20ffee35ec914c57780a8e29",
"conditions": b64encode(
"random condition for reencryption".encode()
).decode(),
"blockNumber": 15440685,
"blockHash": "0x2220da8b777767df526acffd5375ebb340fc98e53c1040b25ad1a8119829e3bd",
},
}
list_data = [12.5, 1.2, 4.3]
str_data = "Everything in the universe has a rhythm, everything dances." # -- Maya Angelou
num_data = 1234567890
bool_data = True
# test serialization/deserialization of data
test_data = [dict_data, list_data, str_data, num_data, bool_data]
field = Base64JSON()
for d in test_data:
serialized = field._serialize(value=d, attr=None, obj=None)
assert serialized == b64encode(json.dumps(d).encode()).decode()
deserialized = field._deserialize(value=serialized, attr=None, data=None)
assert deserialized == d
with pytest.raises(InvalidInputData):
# attempt to serialize non-json serializable object
field._serialize(value=Exception("non-serializable"), attr=None, obj=None)
with pytest.raises(InvalidInputData):
# attempt to deserialize invalid data
field._deserialize(
value=b"raw bytes with non base64 chars ?&^%", attr=None, data=None
)

View File

@ -18,14 +18,17 @@ along with nucypher. If not, see <https://www.gnu.org/licenses/>.
import os
import random
import string
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple
from nucypher_core import MessageKit, RetrievalKit
from nucypher.characters.control.specifications.fields import Key, TreasureMap
from nucypher.characters.lawful import Enrico
from nucypher.control.specifications.fields import Base64JSON
from nucypher.crypto.powers import DecryptingPower
from nucypher.utilities.porter.control.specifications.fields import RetrievalKit as RetrievalKitField
from nucypher.utilities.porter.control.specifications.fields import (
RetrievalKit as RetrievalKitField,
)
def generate_random_label() -> bytes:
@ -41,9 +44,15 @@ def generate_random_label() -> bytes:
return bytes(random_label, encoding='utf-8')
def retrieval_request_setup(enacted_policy, bob, alice, original_message: bytes = None, encode_for_rest: bool = False) -> Tuple[Dict, MessageKit]:
treasure_map = bob._decrypt_treasure_map(enacted_policy.treasure_map,
enacted_policy.publisher_verifying_key)
def retrieval_request_setup(enacted_policy,
bob,
alice,
original_message: Optional[bytes] = None,
context: Optional[Dict] = None,
encode_for_rest: bool = False) -> Tuple[Dict, MessageKit]:
treasure_map = bob._decrypt_treasure_map(
enacted_policy.treasure_map, enacted_policy.publisher_verifying_key
)
# We pick up our story with Bob already having followed the treasure map above, ie:
bob.start_learning_loop()
@ -56,18 +65,40 @@ def retrieval_request_setup(enacted_policy, bob, alice, original_message: bytes
encode_bytes = (lambda field, obj: field()._serialize(value=obj, attr=None, obj=None)) if encode_for_rest else (lambda field, obj: obj)
return (dict(treasure_map=encode_bytes(TreasureMap, treasure_map),
retrieval_kits=[encode_bytes(RetrievalKitField, RetrievalKit.from_message_kit(message_kit))],
alice_verifying_key=encode_bytes(Key, alice.stamp.as_umbral_pubkey()),
bob_encrypting_key=encode_bytes(Key, bob.public_keys(DecryptingPower)),
bob_verifying_key=encode_bytes(Key, bob.stamp.as_umbral_pubkey())),
message_kit)
retrieval_params = dict(
treasure_map=encode_bytes(TreasureMap, treasure_map),
retrieval_kits=[
encode_bytes(RetrievalKitField, RetrievalKit.from_message_kit(message_kit))
],
alice_verifying_key=encode_bytes(Key, alice.stamp.as_umbral_pubkey()),
bob_encrypting_key=encode_bytes(Key, bob.public_keys(DecryptingPower)),
bob_verifying_key=encode_bytes(Key, bob.stamp.as_umbral_pubkey()),
)
# context is optional
if context:
retrieval_params["context"] = encode_bytes(Base64JSON, context)
return retrieval_params, message_kit
def retrieval_params_decode_from_rest(retrieval_params: Dict) -> Dict:
decode_bytes = (lambda field, data: field()._deserialize(value=data, attr=None, data=None))
return dict(treasure_map=decode_bytes(TreasureMap, retrieval_params['treasure_map']),
retrieval_kits=[decode_bytes(RetrievalKitField, kit) for kit in retrieval_params['retrieval_kits']],
alice_verifying_key=decode_bytes(Key, retrieval_params['alice_verifying_key']),
bob_encrypting_key=decode_bytes(Key, retrieval_params['bob_encrypting_key']),
bob_verifying_key=decode_bytes(Key, retrieval_params['bob_verifying_key']))
decode_bytes = lambda field, data: field()._deserialize(
value=data, attr=None, data=None
)
decoded_params = dict(
treasure_map=decode_bytes(TreasureMap, retrieval_params["treasure_map"]),
retrieval_kits=[
decode_bytes(RetrievalKitField, kit)
for kit in retrieval_params["retrieval_kits"]
],
alice_verifying_key=decode_bytes(Key, retrieval_params["alice_verifying_key"]),
bob_encrypting_key=decode_bytes(Key, retrieval_params["bob_encrypting_key"]),
bob_verifying_key=decode_bytes(Key, retrieval_params["bob_verifying_key"]),
)
# context is optional
if "context" in retrieval_params:
decoded_params["context"] = decode_bytes(
Base64JSON, retrieval_params["context"]
)
return decoded_params