Merge pull request #3126 from derekpierre/bob-porter-conjoined

Ensure that Bob and Porter can use common code for making Threshold Decryption Requests
pull/3133/head
Derek Pierre 2023-05-23 10:07:54 -04:00 committed by GitHub
commit 8b205ad443
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 430 additions and 96 deletions

View File

View File

@ -97,11 +97,12 @@ from nucypher.crypto.powers import (
TLSHostingPower,
TransactingPower,
)
from nucypher.network.decryption import ThresholdDecryptionClient
from nucypher.network.exceptions import NodeSeemsToBeDown
from nucypher.network.middleware import RestMiddleware
from nucypher.network.nodes import TEACHER_NODES, NodeSprout, Teacher
from nucypher.network.protocols import parse_node_uri
from nucypher.network.retrieval import RetrievalClient
from nucypher.network.retrieval import PRERetrievalClient
from nucypher.network.server import ProxyRESTServer, make_rest_app
from nucypher.network.trackers import AvailabilityTracker
from nucypher.policy.conditions.types import LingoList
@ -502,7 +503,7 @@ class Bob(Character):
retrieval_kits = [message_kit.as_retrieval_kit() for message_kit in message_kits]
# Retrieve capsule frags
client = RetrievalClient(learner=self)
client = PRERetrievalClient(learner=self)
retrieval_results, _ = client.retrieve_cfrags(
treasure_map=treasure_map,
retrieval_kits=retrieval_kits,
@ -570,51 +571,46 @@ class Bob(Character):
threshold: int,
variant: FerveoVariant,
context: Optional[dict] = None,
) -> List[DecryptionShareSimple]:
) -> Dict[
ChecksumAddress, Union[DecryptionShareSimple, DecryptionSharePrecomputed]
]:
if variant == FerveoVariant.PRECOMPUTED:
share_type = DecryptionSharePrecomputed
elif variant == FerveoVariant.SIMPLE:
share_type = DecryptionShareSimple
gathered_shares = list()
decryption_request_mapping = {}
for ursula in cohort:
conditions = Conditions(json.dumps(lingo))
if context:
context = Context(json.dumps(context))
decryption_request = ThresholdDecryptionRequest(
id=ritual_id,
variant=int(variant.value),
ciphertext=bytes(ciphertext),
conditions=conditions,
context=context,
variant=int(variant.value),
)
decryption_request_mapping[
to_checksum_address(ursula.checksum_address)
] = bytes(decryption_request)
try:
response = self.network_middleware.get_decryption_share(ursula, bytes(decryption_request))
except NodeSeemsToBeDown as e:
self.log.warn(f"Node {ursula} is unreachable. {e}")
continue
if response.status_code != 200:
self.log.warn(f"Node {ursula} returned {response.status_code}.")
continue
decryption_client = ThresholdDecryptionClient(learner=self)
successes, failures = decryption_client.gather_encrypted_decryption_shares(
encrypted_requests=decryption_request_mapping, threshold=threshold
)
decryption_response = ThresholdDecryptionResponse.from_bytes(
response.content
)
decryption_share = share_type.from_bytes(
decryption_response.decryption_share
)
gathered_shares.append(decryption_share)
self.log.debug(f"Got {len(gathered_shares)}/{threshold} shares so far...")
if variant == FerveoVariant.SIMPLE and (len(gathered_shares) == threshold):
# security threshold reached
break
if len(gathered_shares) < threshold:
if len(successes) < threshold:
raise Ursula.NotEnoughUrsulas(f"Not enough Ursulas to decrypt")
self.log.debug(f"Got enough shares to decrypt.")
gathered_shares = {}
for provider_address, response_bytes in successes.items():
decryption_response = ThresholdDecryptionResponse.from_bytes(response_bytes)
decryption_share = share_type.from_bytes(
decryption_response.decryption_share
)
gathered_shares[provider_address] = decryption_share
return gathered_shares
def threshold_decrypt(self,
@ -647,8 +643,12 @@ class Bob(Character):
except AttributeError:
raise ValueError(f"Invalid variant: {variant}; Options are: {list(v.name.lower() for v in list(FerveoVariant))}")
threshold = (ritual.shares // 2) + 1 # TODO: #3095 get this from the ritual / put it on-chain?
shares = self.gather_decryption_shares(
threshold = (
(ritual.shares // 2) + 1
if variant == FerveoVariant.SIMPLE
else ritual.shares
) # TODO: #3095 get this from the ritual / put it on-chain?
decryption_shares = self.gather_decryption_shares(
ritual_id=ritual_id,
cohort=ursulas,
ciphertext=ciphertext,
@ -662,10 +662,15 @@ class Bob(Character):
# TODO: Bob can call.verify here instead of aggregating the shares.
# if the DKG parameters are not provided, we need to
# aggregate the transcripts and derive them.
# TODO we don't need all ursulas, only threshold of them
# ursulas = [u for u in ursulas if u.checksum_address in decryption_shares]
params = self.__derive_dkg_parameters(ritual_id, ursulas, ritual, threshold)
# TODO: compare the results with the on-chain records (Coordinator).
return self.__decrypt(shares, ciphertext, conditions, params, variant)
return self.__decrypt(
list(decryption_shares.values()), ciphertext, conditions, params, variant
)
@staticmethod
def __decrypt(

View File

@ -1 +0,0 @@

View File

@ -0,0 +1,50 @@
from typing import List
from eth_typing import ChecksumAddress
from nucypher.network.nodes import Learner
from nucypher.utilities.logging import Logger
class ThresholdAccessControlClient:
"""
Client for communicating with access control nodes on the Threshold Network.
"""
def __init__(self, learner: Learner):
self._learner = learner
self.log = Logger(self.__class__.__name__)
def _ensure_ursula_availability(
self, ursulas: List[ChecksumAddress], threshold: int, timeout=10
):
"""
Make sure we know enough nodes from the treasure map to decrypt;
otherwise block and wait for them to come online.
"""
# OK, so we're going to need to do some network activity for this retrieval.
# Let's make sure we've seeded.
if not self._learner.done_seeding:
self._learner.learn_from_teacher_node()
all_known_ursulas = self._learner.known_nodes.addresses()
# Push all unknown Ursulas from the map in the queue for learning
unknown_ursulas = ursulas - all_known_ursulas
# If we know enough to decrypt, we can proceed.
known_ursulas = ursulas & all_known_ursulas
if len(known_ursulas) >= threshold:
return
# | <--- shares ---> |
# | <--- threshold ---> | <--- allow_missing ---> |
# | <--- known_ursulas ---> | <--- unknown_ursulas ---> |
allow_missing = len(ursulas) - threshold
self._learner.block_until_specific_nodes_are_known(
unknown_ursulas,
timeout=timeout,
allow_missing=allow_missing,
learn_on_this_thread=True,
)

View File

@ -0,0 +1,74 @@
from typing import Dict, List, Tuple
from eth_typing import ChecksumAddress
from nucypher.network.client import ThresholdAccessControlClient
from nucypher.utilities.concurrency import BatchValueFactory, WorkerPool
class ThresholdDecryptionClient(ThresholdAccessControlClient):
class DecryptionRequestFailed(Exception):
"""Raised when a decryption request returns a non-zero status code."""
class DecryptionRequestFactory(BatchValueFactory):
def __init__(self, ursula_to_contact: List[ChecksumAddress], threshold: int):
# TODO should we batch the ursulas to contact i.e. pass `batch_size` parameter
super().__init__(values=ursula_to_contact, required_successes=threshold)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def gather_encrypted_decryption_shares(
self,
encrypted_requests: Dict[ChecksumAddress, bytes],
threshold: int,
timeout: float = 10,
) -> Tuple[Dict[ChecksumAddress, bytes], Dict[ChecksumAddress, str]]:
self._ensure_ursula_availability(
ursulas=list(encrypted_requests.keys()),
threshold=threshold,
timeout=timeout,
)
def worker(ursula_address: ChecksumAddress) -> bytes:
encrypted_request = encrypted_requests[ursula_address]
try:
node_or_sprout = self._learner.known_nodes[ursula_address]
node_or_sprout.mature()
response = (
self._learner.network_middleware.get_encrypted_decryption_share(
node_or_sprout, encrypted_request
)
)
except Exception as e:
self.log.warn(f"Node {ursula_address} raised {e}")
raise
else:
if response.status_code != 200:
message = f"Node {ursula_address} returned {response.status_code} - {response.content}."
self.log.warn(message)
raise self.DecryptionRequestFailed(message)
return response.content
worker_pool = WorkerPool(
worker=worker,
value_factory=self.DecryptionRequestFactory(
ursula_to_contact=list(encrypted_requests.keys()), threshold=threshold
),
target_successes=threshold,
timeout=timeout,
)
worker_pool.start()
try:
successes = worker_pool.block_until_target_successes()
except (WorkerPool.OutOfValues, WorkerPool.TimedOut):
# It's possible to raise some other exceptions here but we will use the logic below.
successes = worker_pool.get_successes()
finally:
worker_pool.cancel()
worker_pool.join()
failures = worker_pool.get_failures()
return successes, failures

View File

@ -262,7 +262,9 @@ class RestMiddleware:
)
return response
def get_decryption_share(self, ursula: 'Ursula', decryption_request_bytes: bytes):
def get_encrypted_decryption_share(
self, ursula: "Ursula", decryption_request_bytes: bytes
):
response = self.client.post(
node_or_sprout=ursula,
path=f"decrypt",

View File

@ -2,7 +2,6 @@
import json
import random
from collections import defaultdict
from json import JSONDecodeError
from typing import Dict, List, Sequence, Tuple
from eth_typing.evm import ChecksumAddress
@ -22,11 +21,10 @@ from nucypher_core.umbral import (
VerificationError,
VerifiedCapsuleFrag,
)
from twisted.logger import Logger
from nucypher.crypto.signing import InvalidSignature
from nucypher.network.client import ThresholdAccessControlClient
from nucypher.network.exceptions import NodeSeemsToBeDown
from nucypher.network.nodes import Learner
from nucypher.policy.conditions.exceptions import InvalidConditionContext
from nucypher.policy.conditions.rust_shims import _serialize_rust_lingos
from nucypher.policy.kits import RetrievalResult
@ -39,7 +37,7 @@ class RetrievalError:
class RetrievalPlan:
"""
An emphemeral object providing a service of selecting Ursulas for reencryption requests
An ephemeral object providing a service of selecting Ursulas for re-encryption requests
during retrieval.
"""
@ -166,49 +164,13 @@ class RetrievalWorkOrder:
return rust_lingos
class RetrievalClient:
class PRERetrievalClient(ThresholdAccessControlClient):
"""
Capsule frag retrieval machinery shared between Bob and Porter.
"""
def __init__(self, learner: Learner):
self._learner = learner
self.log = Logger(self.__class__.__name__)
def _ensure_ursula_availability(self, treasure_map: TreasureMap, timeout=10):
"""
Make sure we know enough nodes from the treasure map to decrypt;
otherwise block and wait for them to come online.
"""
# OK, so we're going to need to do some network activity for this retrieval.
# Let's make sure we've seeded.
if not self._learner.done_seeding:
self._learner.learn_from_teacher_node()
ursulas_in_map = treasure_map.destinations.keys()
# TODO (#1995): when that issue is fixed, conversion is no longer needed
ursulas_in_map = [to_checksum_address(bytes(address)) for address in ursulas_in_map]
all_known_ursulas = self._learner.known_nodes.addresses()
# Push all unknown Ursulas from the map in the queue for learning
unknown_ursulas = ursulas_in_map - all_known_ursulas
# If we know enough to decrypt, we can proceed.
known_ursulas = ursulas_in_map & all_known_ursulas
if len(known_ursulas) >= treasure_map.threshold:
return
# | <--- shares ---> |
# | <--- threshold ---> | <--- allow_missing ---> |
# | <--- known_ursulas ---> | <--- unknown_ursulas ---> |
allow_missing = len(treasure_map.destinations) - treasure_map.threshold
self._learner.block_until_specific_nodes_are_known(unknown_ursulas,
timeout=timeout,
allow_missing=allow_missing,
learn_on_this_thread=True)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _request_reencryption(self,
ursula: 'Ursula',
@ -284,7 +246,16 @@ class RetrievalClient:
bob_verifying_key: PublicKey,
**context) -> Tuple[List[RetrievalResult], List[RetrievalError]]:
self._ensure_ursula_availability(treasure_map)
ursulas_in_map = treasure_map.destinations.keys()
# TODO (#1995): when that issue is fixed, conversion is no longer needed
ursulas_in_map = [
to_checksum_address(bytes(address)) for address in ursulas_in_map
]
self._ensure_ursula_availability(
ursulas=ursulas_in_map, threshold=treasure_map.threshold
)
retrieval_plan = RetrievalPlan(treasure_map=treasure_map, retrieval_kits=retrieval_kits)

View File

@ -330,3 +330,48 @@ class WorkerPool:
break
self._result_queue.put(PRODUCER_STOPPED)
class BatchValueFactory:
def __init__(
self, values: List[Any], required_successes: int, batch_size: int = None
):
if not values:
raise ValueError(f"No available values provided")
if required_successes <= 0:
raise ValueError(
f"Invalid number of successes required ({required_successes})"
)
self.values = values
self.required_successes = required_successes
if len(self.values) < self.required_successes:
raise ValueError(
f"Available values ({len(self.values)} less than required successes {self.required_successes}"
)
self._batch_start_index = 0
if batch_size is not None and batch_size <= 0:
raise ValueError(f"Invalid batch size specified ({batch_size})")
self.batch_size = batch_size if batch_size else required_successes
def __call__(self, successes) -> Optional[List[Any]]:
if successes >= self.required_successes:
# no more work needed to be done
return None
if self._batch_start_index == len(self.values):
# no more values to process
return None
batch_end_index = self._batch_start_index + self.batch_size
if batch_end_index <= len(self.values):
batch = self.values[self._batch_start_index : batch_end_index]
self._batch_start_index = batch_end_index
return batch
else:
# return all remaining values
batch = self.values[self._batch_start_index :]
self._batch_start_index = len(self.values)
return batch

View File

@ -1,9 +1,10 @@
import pytest
import random
import time
from typing import Iterable, Tuple
from nucypher.utilities.concurrency import WorkerPool
import pytest
from nucypher.utilities.concurrency import BatchValueFactory, WorkerPool
class AllAtOnceFactory:
@ -213,27 +214,18 @@ def test_join(join_worker_pool):
assert t_end - t_start < 3
class BatchFactory:
class TestBatchValueFactory(BatchValueFactory):
def __init__(self, values):
self.values = values
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.batch_sizes = []
def __call__(self, successes):
if successes == 10:
return None
batch_size = 10 - successes
if len(self.values) >= batch_size:
batch = self.values[:batch_size]
self.batch_sizes.append(len(batch))
self.values = self.values[batch_size:]
return batch
elif len(self.values) > 0:
self.batch_sizes.append(len(self.values))
return self.values
self.values = None
else:
return None
result = super().__call__(successes)
if result:
self.batch_sizes.append(len(result))
return result
def test_batched_value_generation(join_worker_pool):
@ -248,7 +240,7 @@ def test_batched_value_generation(join_worker_pool):
],
seed=123)
factory = BatchFactory(list(outcomes))
factory = TestBatchValueFactory(values=list(outcomes), required_successes=10)
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=0.5)
join_worker_pool(pool)

View File

@ -0,0 +1,196 @@
import pytest
from nucypher.utilities.concurrency import BatchValueFactory
NUM_VALUES = 20
@pytest.fixture(scope="module")
def values():
values = []
for i in range(0, NUM_VALUES):
values.append(i)
return values
def test_batch_value_factory_invalid_values(values):
with pytest.raises(ValueError):
BatchValueFactory(values=[], required_successes=0)
with pytest.raises(ValueError):
BatchValueFactory(values=[], required_successes=1)
with pytest.raises(ValueError):
BatchValueFactory(values=[1, 2, 3, 4], required_successes=5)
with pytest.raises(ValueError):
BatchValueFactory(values=[1, 2, 3, 4], required_successes=2, batch_size=0)
def test_batch_value_factory_all_successes_no_specified_batching(values):
target_successes = NUM_VALUES
value_factory = BatchValueFactory(
values=values, required_successes=target_successes
)
# number of successes returned since no batching provided
value_list = value_factory(successes=0)
assert len(value_list) == target_successes, "list returned is based on successes"
assert len(values) == NUM_VALUES, "values remained unchanged"
# get list again
value_list = value_factory(successes=NUM_VALUES) # successes achieved
assert not value_list, "successes achieved and no more values available"
# get list again
value_list = value_factory(successes=0) # successes not achieved
assert not value_list, "no successes achieved but no more values available"
def test_batch_value_factory_no_specified_batching_no_more_values_after_target_successes(
values,
):
target_successes = 1
value_factory = BatchValueFactory(
values=values, required_successes=target_successes
)
for i in range(0, NUM_VALUES // 3):
value_list = value_factory(successes=0)
assert (
len(value_list) == target_successes
), "list returned is based on successes"
assert len(values) == NUM_VALUES, "values remained unchanged"
for i in range(NUM_VALUES // 3, NUM_VALUES):
value_list = value_factory(successes=target_successes)
assert (
not value_list
), "there are more values but no more is needed since target successes attained"
def test_batch_value_factory_no_batching_no_success_multiple_calls(values):
target_successes = 4
value_factory = BatchValueFactory(
values=values, required_successes=target_successes
)
for i in range(0, NUM_VALUES // target_successes):
value_list = value_factory(successes=0)
assert (
len(value_list) == target_successes
), "list returned is based on successes"
assert len(values) == NUM_VALUES, "values remained unchanged"
# list all done but get list again
value_list = value_factory(successes=target_successes) # successes achieved
assert not value_list, "successes achieved"
# list all done but get list again
value_list = value_factory(
successes=1
) # not enough successes but list is now empty
assert not value_list, "successes not achieved, but no more values available"
def test_batch_value_factory_no_batching_no_success_multiple_calls_non_divisible_successes(
values,
):
target_successes = 6
value_factory = BatchValueFactory(
values=values, required_successes=target_successes
)
# should be able to get 4 lists
for i in range(0, NUM_VALUES // target_successes):
value_list = value_factory(successes=0)
assert (
len(value_list) == target_successes
), "list returned is based on successes"
assert len(values) == NUM_VALUES, "values remained unchanged"
# last request
value_list = value_factory(successes=0)
assert len(value_list) == NUM_VALUES % target_successes, "remaining list returned"
# get list again
value_list = value_factory(successes=target_successes) # successes achieved
assert not value_list, "successes achieved"
# get list again
value_list = value_factory(
successes=target_successes - 1
) # not enough successes but list is now empty
assert not value_list, "successes not achieved, but no more values available"
def test_batch_value_factory_batching_individual(values):
target_successes = NUM_VALUES
batch_size = 1
value_factory = BatchValueFactory(
values=values, required_successes=target_successes, batch_size=batch_size
)
# number of successes returned since no batching provided
for i in range(0, NUM_VALUES // batch_size):
value_list = value_factory(successes=0)
assert len(value_list) == batch_size, "list returned is based on batch size"
assert len(values) == NUM_VALUES, "values remained unchanged"
# get list again
value_list = value_factory(successes=NUM_VALUES) # successes achieved
assert not value_list, "successes achieved and no more values available"
# get list again
value_list = value_factory(successes=0) # successes not achieved
assert not value_list, "no successes achieved but no more values available"
def test_batch_value_factory_batching_divisible(values):
target_successes = NUM_VALUES
batch_size = 5
value_factory = BatchValueFactory(
values=values, required_successes=target_successes, batch_size=batch_size
)
# number of successes returned since no batching provided (3x here)
for i in range(0, NUM_VALUES // batch_size):
value_list = value_factory(successes=target_successes - 1)
assert len(value_list) == batch_size, "list returned is based on batch size"
assert len(values) == NUM_VALUES, "values remained unchanged"
# get list again
value_list = value_factory(successes=NUM_VALUES) # successes achieved
assert not value_list, "successes achieved and no more values available"
# get list again
value_list = value_factory(successes=0) # successes not achieved
assert not value_list, "no successes achieved but no more values available"
def test_batch_value_factory_batching_non_divisible(values):
target_successes = NUM_VALUES
batch_size = 7
value_factory = BatchValueFactory(
values=values, required_successes=target_successes, batch_size=batch_size
)
# number of successes returned since no batching provided
for i in range(0, NUM_VALUES // batch_size):
value_list = value_factory(successes=0)
assert len(value_list) == batch_size, "list returned is based on batch size"
assert len(values) == NUM_VALUES, "values remained unchanged"
# one more
value_list = value_factory(successes=0)
assert len(value_list) == NUM_VALUES % batch_size, "remainder of list returned"
assert len(values) == NUM_VALUES, "values remained unchanged"
# get list again
value_list = value_factory(successes=target_successes) # successes achieved
assert not value_list, "successes achieved and no more values available"
# get list again
value_list = value_factory(successes=0) # successes not achieved
assert not value_list, "no successes achieved but no more values available"