diff --git a/newsfragments/3126.dev.rst b/newsfragments/3126.dev.rst new file mode 100644 index 000000000..e69de29bb diff --git a/nucypher/characters/lawful.py b/nucypher/characters/lawful.py index 888354c68..b49367196 100644 --- a/nucypher/characters/lawful.py +++ b/nucypher/characters/lawful.py @@ -97,11 +97,12 @@ from nucypher.crypto.powers import ( TLSHostingPower, TransactingPower, ) +from nucypher.network.decryption import ThresholdDecryptionClient from nucypher.network.exceptions import NodeSeemsToBeDown from nucypher.network.middleware import RestMiddleware from nucypher.network.nodes import TEACHER_NODES, NodeSprout, Teacher from nucypher.network.protocols import parse_node_uri -from nucypher.network.retrieval import RetrievalClient +from nucypher.network.retrieval import PRERetrievalClient from nucypher.network.server import ProxyRESTServer, make_rest_app from nucypher.network.trackers import AvailabilityTracker from nucypher.policy.conditions.types import LingoList @@ -502,7 +503,7 @@ class Bob(Character): retrieval_kits = [message_kit.as_retrieval_kit() for message_kit in message_kits] # Retrieve capsule frags - client = RetrievalClient(learner=self) + client = PRERetrievalClient(learner=self) retrieval_results, _ = client.retrieve_cfrags( treasure_map=treasure_map, retrieval_kits=retrieval_kits, @@ -570,51 +571,46 @@ class Bob(Character): threshold: int, variant: FerveoVariant, context: Optional[dict] = None, - ) -> List[DecryptionShareSimple]: + ) -> Dict[ + ChecksumAddress, Union[DecryptionShareSimple, DecryptionSharePrecomputed] + ]: if variant == FerveoVariant.PRECOMPUTED: share_type = DecryptionSharePrecomputed elif variant == FerveoVariant.SIMPLE: share_type = DecryptionShareSimple - gathered_shares = list() + decryption_request_mapping = {} for ursula in cohort: conditions = Conditions(json.dumps(lingo)) if context: context = Context(json.dumps(context)) decryption_request = ThresholdDecryptionRequest( id=ritual_id, + variant=int(variant.value), 
ciphertext=bytes(ciphertext), conditions=conditions, context=context, - variant=int(variant.value), ) + decryption_request_mapping[ + to_checksum_address(ursula.checksum_address) + ] = bytes(decryption_request) - try: - response = self.network_middleware.get_decryption_share(ursula, bytes(decryption_request)) - except NodeSeemsToBeDown as e: - self.log.warn(f"Node {ursula} is unreachable. {e}") - continue - if response.status_code != 200: - self.log.warn(f"Node {ursula} returned {response.status_code}.") - continue + decryption_client = ThresholdDecryptionClient(learner=self) + successes, failures = decryption_client.gather_encrypted_decryption_shares( + encrypted_requests=decryption_request_mapping, threshold=threshold + ) - decryption_response = ThresholdDecryptionResponse.from_bytes( - response.content - ) - decryption_share = share_type.from_bytes( - decryption_response.decryption_share - ) - gathered_shares.append(decryption_share) - self.log.debug(f"Got {len(gathered_shares)}/{threshold} shares so far...") - - if variant == FerveoVariant.SIMPLE and (len(gathered_shares) == threshold): - # security threshold reached - break - - if len(gathered_shares) < threshold: + if len(successes) < threshold: raise Ursula.NotEnoughUrsulas(f"Not enough Ursulas to decrypt") self.log.debug(f"Got enough shares to decrypt.") + gathered_shares = {} + for provider_address, response_bytes in successes.items(): + decryption_response = ThresholdDecryptionResponse.from_bytes(response_bytes) + decryption_share = share_type.from_bytes( + decryption_response.decryption_share + ) + gathered_shares[provider_address] = decryption_share return gathered_shares def threshold_decrypt(self, @@ -647,8 +643,12 @@ class Bob(Character): except AttributeError: raise ValueError(f"Invalid variant: {variant}; Options are: {list(v.name.lower() for v in list(FerveoVariant))}") - threshold = (ritual.shares // 2) + 1 # TODO: #3095 get this from the ritual / put it on-chain? 
class ThresholdAccessControlClient:
    """
    Client for communicating with access control nodes on the Threshold Network.

    Stores the learner used for node discovery and a logger named after the
    concrete client class; subclasses add the actual request logic.
    """

    def __init__(self, learner: Learner):
        self._learner = learner
        self.log = Logger(self.__class__.__name__)

    def _ensure_ursula_availability(
        self, ursulas: List[ChecksumAddress], threshold: int, timeout=10
    ):
        """
        Make sure the learner knows at least `threshold` of the given nodes;
        otherwise block until enough of the unknown ones have been learned.

        :param ursulas: checksum addresses of the nodes that may be contacted
        :param threshold: minimum number of those nodes that must be known
        :param timeout: forwarded to the learner's
            `block_until_specific_nodes_are_known` if blocking is required
        """

        # OK, so we're going to need to do some network activity for this request.
        # Let's make sure we've seeded.
        if not self._learner.done_seeding:
            self._learner.learn_from_teacher_node()

        all_known_ursulas = self._learner.known_nodes.addresses()

        # `ursulas` arrives as a list. The `-` / `&` operators only work on a
        # list when the other operand is a dict view, so use the set methods,
        # which accept any iterable on the right-hand side.
        requested_ursulas = set(ursulas)

        # Push all unknown Ursulas in the queue for learning
        unknown_ursulas = requested_ursulas.difference(all_known_ursulas)

        # If we know enough of them, we can proceed.
        known_ursulas = requested_ursulas.intersection(all_known_ursulas)
        if len(known_ursulas) >= threshold:
            return

        # |                      <--- shares --->                    |
        # | <--- threshold --->     | <--- allow_missing --->        |
        # | <--- known_ursulas ---> | <--- unknown_ursulas --->      |
        allow_missing = len(ursulas) - threshold
        self._learner.block_until_specific_nodes_are_known(
            unknown_ursulas,
            timeout=timeout,
            allow_missing=allow_missing,
            learn_on_this_thread=True,
        )


class ThresholdDecryptionClient(ThresholdAccessControlClient):
    """
    Client that sends decryption requests to a set of nodes concurrently and
    collects their raw responses.

    Construction is inherited unchanged from `ThresholdAccessControlClient`
    (a single `learner` argument).
    """

    class DecryptionRequestFailed(Exception):
        """Raised when a decryption request returns a non-200 status code."""

    class DecryptionRequestFactory(BatchValueFactory):
        """Value factory that hands out the node addresses to contact."""

        def __init__(self, ursula_to_contact: List[ChecksumAddress], threshold: int):
            # TODO should we batch the ursulas to contact i.e. pass `batch_size` parameter
            super().__init__(values=ursula_to_contact, required_successes=threshold)

    def gather_encrypted_decryption_shares(
        self,
        encrypted_requests: Dict[ChecksumAddress, bytes],
        threshold: int,
        timeout: float = 10,
    ) -> Tuple[Dict[ChecksumAddress, bytes], Dict[ChecksumAddress, str]]:
        """
        Send each node its request and gather responses until `threshold`
        requests have succeeded.

        :param encrypted_requests: request bytes to send, keyed by the
            checksum address of the node that should receive them
        :param threshold: number of successful responses to wait for
        :param timeout: used both for learning about missing nodes and for
            the worker pool
        :return: `(successes, failures)` as reported by the worker pool.
            When the pool runs out of values or times out, `successes` holds
            whatever was collected so far and may contain fewer than
            `threshold` entries, so the caller must check its size.
        """
        self._ensure_ursula_availability(
            ursulas=list(encrypted_requests),
            threshold=threshold,
            timeout=timeout,
        )

        def worker(ursula_address: ChecksumAddress) -> bytes:
            """Contact one node; return the response body or raise."""
            encrypted_request = encrypted_requests[ursula_address]

            try:
                node_or_sprout = self._learner.known_nodes[ursula_address]
                node_or_sprout.mature()
                response = (
                    self._learner.network_middleware.get_encrypted_decryption_share(
                        node_or_sprout, encrypted_request
                    )
                )
            except Exception as e:
                # Log for visibility, then re-raise so the failure propagates
                # to the pool that invoked this worker.
                self.log.warn(f"Node {ursula_address} raised {e}")
                raise

            if response.status_code != 200:
                message = f"Node {ursula_address} returned {response.status_code} - {response.content}."
                self.log.warn(message)
                raise self.DecryptionRequestFailed(message)

            return response.content

        worker_pool = WorkerPool(
            worker=worker,
            value_factory=self.DecryptionRequestFactory(
                ursula_to_contact=list(encrypted_requests), threshold=threshold
            ),
            target_successes=threshold,
            timeout=timeout,
        )
        worker_pool.start()
        try:
            successes = worker_pool.block_until_target_successes()
        except (WorkerPool.OutOfValues, WorkerPool.TimedOut):
            # It's possible to raise some other exceptions here but we will use the logic below.
            successes = worker_pool.get_successes()
        finally:
            # Always stop and join the pool, whether or not the target was met.
            worker_pool.cancel()
            worker_pool.join()
        failures = worker_pool.get_failures()

        return successes, failures
""" @@ -166,49 +164,13 @@ class RetrievalWorkOrder: return rust_lingos -class RetrievalClient: +class PRERetrievalClient(ThresholdAccessControlClient): """ Capsule frag retrieval machinery shared between Bob and Porter. """ - def __init__(self, learner: Learner): - self._learner = learner - self.log = Logger(self.__class__.__name__) - - def _ensure_ursula_availability(self, treasure_map: TreasureMap, timeout=10): - """ - Make sure we know enough nodes from the treasure map to decrypt; - otherwise block and wait for them to come online. - """ - - # OK, so we're going to need to do some network activity for this retrieval. - # Let's make sure we've seeded. - if not self._learner.done_seeding: - self._learner.learn_from_teacher_node() - - ursulas_in_map = treasure_map.destinations.keys() - - # TODO (#1995): when that issue is fixed, conversion is no longer needed - ursulas_in_map = [to_checksum_address(bytes(address)) for address in ursulas_in_map] - - all_known_ursulas = self._learner.known_nodes.addresses() - - # Push all unknown Ursulas from the map in the queue for learning - unknown_ursulas = ursulas_in_map - all_known_ursulas - - # If we know enough to decrypt, we can proceed. 
class BatchValueFactory:
    """
    Callable value factory that hands out a fixed list of values in batches.

    Each call receives the number of successes achieved so far and returns
    the next batch of not-yet-issued values, or ``None`` once enough
    successes have been reported or every value has already been handed out.
    The list passed in is only sliced, never modified.
    """

    def __init__(
        self,
        values: List[Any],
        required_successes: int,
        batch_size: Optional[int] = None,
    ):
        """
        :param values: the values to hand out; must be non-empty
        :param required_successes: number of successes after which no further
            values are handed out; must be positive and no larger than
            ``len(values)``
        :param batch_size: maximum number of values returned per call;
            defaults to ``required_successes`` when not given
        :raises ValueError: if any of the constraints above is violated
        """
        if not values:
            raise ValueError("No available values provided")
        if required_successes <= 0:
            raise ValueError(
                f"Invalid number of successes required ({required_successes})"
            )

        self.values = values
        self.required_successes = required_successes
        if len(self.values) < self.required_successes:
            raise ValueError(
                f"Available values ({len(self.values)}) less than "
                f"required successes ({self.required_successes})"
            )

        # Index of the first value that has not been handed out yet.
        self._batch_start_index = 0

        if batch_size is not None and batch_size <= 0:
            raise ValueError(f"Invalid batch size specified ({batch_size})")
        self.batch_size = batch_size if batch_size is not None else required_successes

    def __call__(self, successes) -> Optional[List[Any]]:
        """
        Return the next batch of values, or ``None`` when no more are needed
        or available.

        :param successes: number of successes achieved so far
        """
        if successes >= self.required_successes:
            # no more work needed to be done
            return None

        total = len(self.values)
        if self._batch_start_index >= total:
            # no more values to process
            return None

        # Clamp the end index so the final batch is simply whatever remains.
        batch_end_index = min(self._batch_start_index + self.batch_size, total)
        batch = self.values[self._batch_start_index : batch_end_index]
        self._batch_start_index = batch_end_index
        return batch
# Unit tests for BatchValueFactory: constructor validation and the batches
# returned by successive calls, with and without an explicit batch size.

NUM_VALUES = 20


@pytest.fixture(scope="module")
def values():
    """Shared list [0, NUM_VALUES); the tests assert it is never modified."""
    return list(range(NUM_VALUES))


def test_batch_value_factory_invalid_values(values):
    # no values, and a non-positive number of required successes
    with pytest.raises(ValueError):
        BatchValueFactory(values=[], required_successes=0)

    # no values
    with pytest.raises(ValueError):
        BatchValueFactory(values=[], required_successes=1)

    # fewer values than required successes
    with pytest.raises(ValueError):
        BatchValueFactory(values=[1, 2, 3, 4], required_successes=5)

    # non-positive batch size
    with pytest.raises(ValueError):
        BatchValueFactory(values=[1, 2, 3, 4], required_successes=2, batch_size=0)


def test_batch_value_factory_all_successes_no_specified_batching(values):
    target_successes = NUM_VALUES
    value_factory = BatchValueFactory(
        values=values, required_successes=target_successes
    )

    # batch size defaults to the required successes, so one call returns everything
    value_list = value_factory(successes=0)
    assert len(value_list) == target_successes, "list returned is based on successes"
    assert len(values) == NUM_VALUES, "values remained unchanged"

    # get list again
    value_list = value_factory(successes=NUM_VALUES)  # successes achieved
    assert not value_list, "successes achieved and no more values available"

    # get list again
    value_list = value_factory(successes=0)  # successes not achieved
    assert not value_list, "no successes achieved but no more values available"


def test_batch_value_factory_no_specified_batching_no_more_values_after_target_successes(
    values,
):
    target_successes = 1
    value_factory = BatchValueFactory(
        values=values, required_successes=target_successes
    )

    # consume only part of the values, one per call
    for _ in range(0, NUM_VALUES // 3):
        value_list = value_factory(successes=0)
        assert (
            len(value_list) == target_successes
        ), "list returned is based on successes"
        assert len(values) == NUM_VALUES, "values remained unchanged"

    # values remain, but the target has been reached so nothing more is returned
    for _ in range(NUM_VALUES // 3, NUM_VALUES):
        value_list = value_factory(successes=target_successes)
        assert (
            not value_list
        ), "there are more values but no more is needed since target successes attained"


def test_batch_value_factory_no_batching_no_success_multiple_calls(values):
    target_successes = 4
    value_factory = BatchValueFactory(
        values=values, required_successes=target_successes
    )

    # 5 full lists (20 // 4) exhaust the values exactly
    for _ in range(0, NUM_VALUES // target_successes):
        value_list = value_factory(successes=0)
        assert (
            len(value_list) == target_successes
        ), "list returned is based on successes"
        assert len(values) == NUM_VALUES, "values remained unchanged"

    # list all done but get list again
    value_list = value_factory(successes=target_successes)  # successes achieved
    assert not value_list, "successes achieved"

    # list all done but get list again
    value_list = value_factory(
        successes=1
    )  # not enough successes but list is now empty
    assert not value_list, "successes not achieved, but no more values available"


def test_batch_value_factory_no_batching_no_success_multiple_calls_non_divisible_successes(
    values,
):
    target_successes = 6
    value_factory = BatchValueFactory(
        values=values, required_successes=target_successes
    )

    # should be able to get 3 full lists (20 // 6)
    for _ in range(0, NUM_VALUES // target_successes):
        value_list = value_factory(successes=0)
        assert (
            len(value_list) == target_successes
        ), "list returned is based on successes"
        assert len(values) == NUM_VALUES, "values remained unchanged"

    # last request returns the 2 remaining values (20 % 6)
    value_list = value_factory(successes=0)
    assert len(value_list) == NUM_VALUES % target_successes, "remaining list returned"

    # get list again
    value_list = value_factory(successes=target_successes)  # successes achieved
    assert not value_list, "successes achieved"

    # get list again
    value_list = value_factory(
        successes=target_successes - 1
    )  # not enough successes but list is now empty
    assert not value_list, "successes not achieved, but no more values available"


def test_batch_value_factory_batching_individual(values):
    target_successes = NUM_VALUES
    batch_size = 1
    value_factory = BatchValueFactory(
        values=values, required_successes=target_successes, batch_size=batch_size
    )

    # explicit batch size of 1: every call returns a single value
    for _ in range(0, NUM_VALUES // batch_size):
        value_list = value_factory(successes=0)
        assert len(value_list) == batch_size, "list returned is based on batch size"
        assert len(values) == NUM_VALUES, "values remained unchanged"

    # get list again
    value_list = value_factory(successes=NUM_VALUES)  # successes achieved
    assert not value_list, "successes achieved and no more values available"

    # get list again
    value_list = value_factory(successes=0)  # successes not achieved
    assert not value_list, "no successes achieved but no more values available"


def test_batch_value_factory_batching_divisible(values):
    target_successes = NUM_VALUES
    batch_size = 5
    value_factory = BatchValueFactory(
        values=values, required_successes=target_successes, batch_size=batch_size
    )

    # explicit batch size of 5: 4 full batches (20 // 5) exhaust the values
    for _ in range(0, NUM_VALUES // batch_size):
        value_list = value_factory(successes=target_successes - 1)
        assert len(value_list) == batch_size, "list returned is based on batch size"
        assert len(values) == NUM_VALUES, "values remained unchanged"

    # get list again
    value_list = value_factory(successes=NUM_VALUES)  # successes achieved
    assert not value_list, "successes achieved and no more values available"

    # get list again
    value_list = value_factory(successes=0)  # successes not achieved
    assert not value_list, "no successes achieved but no more values available"


def test_batch_value_factory_batching_non_divisible(values):
    target_successes = NUM_VALUES
    batch_size = 7
    value_factory = BatchValueFactory(
        values=values, required_successes=target_successes, batch_size=batch_size
    )

    # explicit batch size of 7: 2 full batches (20 // 7) ...
    for _ in range(0, NUM_VALUES // batch_size):
        value_list = value_factory(successes=0)
        assert len(value_list) == batch_size, "list returned is based on batch size"
        assert len(values) == NUM_VALUES, "values remained unchanged"

    # ... and one more call for the 6 remaining values (20 % 7)
    value_list = value_factory(successes=0)
    assert len(value_list) == NUM_VALUES % batch_size, "remainder of list returned"
    assert len(values) == NUM_VALUES, "values remained unchanged"

    # get list again
    value_list = value_factory(successes=target_successes)  # successes achieved
    assert not value_list, "successes achieved and no more values available"

    # get list again
    value_list = value_factory(successes=0)  # successes not achieved
    assert not value_list, "no successes achieved but no more values available"