mirror of https://github.com/nucypher/nucypher.git
Merge pull request #3126 from derekpierre/bob-porter-conjoined
Ensure that Bob and Porter can use common code for making Threshold Decryption Requestspull/3133/head
commit
8b205ad443
|
@ -97,11 +97,12 @@ from nucypher.crypto.powers import (
|
||||||
TLSHostingPower,
|
TLSHostingPower,
|
||||||
TransactingPower,
|
TransactingPower,
|
||||||
)
|
)
|
||||||
|
from nucypher.network.decryption import ThresholdDecryptionClient
|
||||||
from nucypher.network.exceptions import NodeSeemsToBeDown
|
from nucypher.network.exceptions import NodeSeemsToBeDown
|
||||||
from nucypher.network.middleware import RestMiddleware
|
from nucypher.network.middleware import RestMiddleware
|
||||||
from nucypher.network.nodes import TEACHER_NODES, NodeSprout, Teacher
|
from nucypher.network.nodes import TEACHER_NODES, NodeSprout, Teacher
|
||||||
from nucypher.network.protocols import parse_node_uri
|
from nucypher.network.protocols import parse_node_uri
|
||||||
from nucypher.network.retrieval import RetrievalClient
|
from nucypher.network.retrieval import PRERetrievalClient
|
||||||
from nucypher.network.server import ProxyRESTServer, make_rest_app
|
from nucypher.network.server import ProxyRESTServer, make_rest_app
|
||||||
from nucypher.network.trackers import AvailabilityTracker
|
from nucypher.network.trackers import AvailabilityTracker
|
||||||
from nucypher.policy.conditions.types import LingoList
|
from nucypher.policy.conditions.types import LingoList
|
||||||
|
@ -502,7 +503,7 @@ class Bob(Character):
|
||||||
retrieval_kits = [message_kit.as_retrieval_kit() for message_kit in message_kits]
|
retrieval_kits = [message_kit.as_retrieval_kit() for message_kit in message_kits]
|
||||||
|
|
||||||
# Retrieve capsule frags
|
# Retrieve capsule frags
|
||||||
client = RetrievalClient(learner=self)
|
client = PRERetrievalClient(learner=self)
|
||||||
retrieval_results, _ = client.retrieve_cfrags(
|
retrieval_results, _ = client.retrieve_cfrags(
|
||||||
treasure_map=treasure_map,
|
treasure_map=treasure_map,
|
||||||
retrieval_kits=retrieval_kits,
|
retrieval_kits=retrieval_kits,
|
||||||
|
@ -570,51 +571,46 @@ class Bob(Character):
|
||||||
threshold: int,
|
threshold: int,
|
||||||
variant: FerveoVariant,
|
variant: FerveoVariant,
|
||||||
context: Optional[dict] = None,
|
context: Optional[dict] = None,
|
||||||
) -> List[DecryptionShareSimple]:
|
) -> Dict[
|
||||||
|
ChecksumAddress, Union[DecryptionShareSimple, DecryptionSharePrecomputed]
|
||||||
|
]:
|
||||||
if variant == FerveoVariant.PRECOMPUTED:
|
if variant == FerveoVariant.PRECOMPUTED:
|
||||||
share_type = DecryptionSharePrecomputed
|
share_type = DecryptionSharePrecomputed
|
||||||
elif variant == FerveoVariant.SIMPLE:
|
elif variant == FerveoVariant.SIMPLE:
|
||||||
share_type = DecryptionShareSimple
|
share_type = DecryptionShareSimple
|
||||||
|
|
||||||
gathered_shares = list()
|
decryption_request_mapping = {}
|
||||||
for ursula in cohort:
|
for ursula in cohort:
|
||||||
conditions = Conditions(json.dumps(lingo))
|
conditions = Conditions(json.dumps(lingo))
|
||||||
if context:
|
if context:
|
||||||
context = Context(json.dumps(context))
|
context = Context(json.dumps(context))
|
||||||
decryption_request = ThresholdDecryptionRequest(
|
decryption_request = ThresholdDecryptionRequest(
|
||||||
id=ritual_id,
|
id=ritual_id,
|
||||||
|
variant=int(variant.value),
|
||||||
ciphertext=bytes(ciphertext),
|
ciphertext=bytes(ciphertext),
|
||||||
conditions=conditions,
|
conditions=conditions,
|
||||||
context=context,
|
context=context,
|
||||||
variant=int(variant.value),
|
|
||||||
)
|
)
|
||||||
|
decryption_request_mapping[
|
||||||
|
to_checksum_address(ursula.checksum_address)
|
||||||
|
] = bytes(decryption_request)
|
||||||
|
|
||||||
try:
|
decryption_client = ThresholdDecryptionClient(learner=self)
|
||||||
response = self.network_middleware.get_decryption_share(ursula, bytes(decryption_request))
|
successes, failures = decryption_client.gather_encrypted_decryption_shares(
|
||||||
except NodeSeemsToBeDown as e:
|
encrypted_requests=decryption_request_mapping, threshold=threshold
|
||||||
self.log.warn(f"Node {ursula} is unreachable. {e}")
|
)
|
||||||
continue
|
|
||||||
if response.status_code != 200:
|
|
||||||
self.log.warn(f"Node {ursula} returned {response.status_code}.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
decryption_response = ThresholdDecryptionResponse.from_bytes(
|
if len(successes) < threshold:
|
||||||
response.content
|
|
||||||
)
|
|
||||||
decryption_share = share_type.from_bytes(
|
|
||||||
decryption_response.decryption_share
|
|
||||||
)
|
|
||||||
gathered_shares.append(decryption_share)
|
|
||||||
self.log.debug(f"Got {len(gathered_shares)}/{threshold} shares so far...")
|
|
||||||
|
|
||||||
if variant == FerveoVariant.SIMPLE and (len(gathered_shares) == threshold):
|
|
||||||
# security threshold reached
|
|
||||||
break
|
|
||||||
|
|
||||||
if len(gathered_shares) < threshold:
|
|
||||||
raise Ursula.NotEnoughUrsulas(f"Not enough Ursulas to decrypt")
|
raise Ursula.NotEnoughUrsulas(f"Not enough Ursulas to decrypt")
|
||||||
self.log.debug(f"Got enough shares to decrypt.")
|
self.log.debug(f"Got enough shares to decrypt.")
|
||||||
|
|
||||||
|
gathered_shares = {}
|
||||||
|
for provider_address, response_bytes in successes.items():
|
||||||
|
decryption_response = ThresholdDecryptionResponse.from_bytes(response_bytes)
|
||||||
|
decryption_share = share_type.from_bytes(
|
||||||
|
decryption_response.decryption_share
|
||||||
|
)
|
||||||
|
gathered_shares[provider_address] = decryption_share
|
||||||
return gathered_shares
|
return gathered_shares
|
||||||
|
|
||||||
def threshold_decrypt(self,
|
def threshold_decrypt(self,
|
||||||
|
@ -647,8 +643,12 @@ class Bob(Character):
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise ValueError(f"Invalid variant: {variant}; Options are: {list(v.name.lower() for v in list(FerveoVariant))}")
|
raise ValueError(f"Invalid variant: {variant}; Options are: {list(v.name.lower() for v in list(FerveoVariant))}")
|
||||||
|
|
||||||
threshold = (ritual.shares // 2) + 1 # TODO: #3095 get this from the ritual / put it on-chain?
|
threshold = (
|
||||||
shares = self.gather_decryption_shares(
|
(ritual.shares // 2) + 1
|
||||||
|
if variant == FerveoVariant.SIMPLE
|
||||||
|
else ritual.shares
|
||||||
|
) # TODO: #3095 get this from the ritual / put it on-chain?
|
||||||
|
decryption_shares = self.gather_decryption_shares(
|
||||||
ritual_id=ritual_id,
|
ritual_id=ritual_id,
|
||||||
cohort=ursulas,
|
cohort=ursulas,
|
||||||
ciphertext=ciphertext,
|
ciphertext=ciphertext,
|
||||||
|
@ -662,10 +662,15 @@ class Bob(Character):
|
||||||
# TODO: Bob can call.verify here instead of aggregating the shares.
|
# TODO: Bob can call.verify here instead of aggregating the shares.
|
||||||
# if the DKG parameters are not provided, we need to
|
# if the DKG parameters are not provided, we need to
|
||||||
# aggregate the transcripts and derive them.
|
# aggregate the transcripts and derive them.
|
||||||
|
|
||||||
|
# TODO we don't need all ursulas, only threshold of them
|
||||||
|
# ursulas = [u for u in ursulas if u.checksum_address in decryption_shares]
|
||||||
params = self.__derive_dkg_parameters(ritual_id, ursulas, ritual, threshold)
|
params = self.__derive_dkg_parameters(ritual_id, ursulas, ritual, threshold)
|
||||||
# TODO: compare the results with the on-chain records (Coordinator).
|
# TODO: compare the results with the on-chain records (Coordinator).
|
||||||
|
|
||||||
return self.__decrypt(shares, ciphertext, conditions, params, variant)
|
return self.__decrypt(
|
||||||
|
list(decryption_shares.values()), ciphertext, conditions, params, variant
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __decrypt(
|
def __decrypt(
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from eth_typing import ChecksumAddress
|
||||||
|
|
||||||
|
from nucypher.network.nodes import Learner
|
||||||
|
from nucypher.utilities.logging import Logger
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdAccessControlClient:
|
||||||
|
"""
|
||||||
|
Client for communicating with access control nodes on the Threshold Network.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, learner: Learner):
|
||||||
|
self._learner = learner
|
||||||
|
self.log = Logger(self.__class__.__name__)
|
||||||
|
|
||||||
|
def _ensure_ursula_availability(
|
||||||
|
self, ursulas: List[ChecksumAddress], threshold: int, timeout=10
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Make sure we know enough nodes from the treasure map to decrypt;
|
||||||
|
otherwise block and wait for them to come online.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# OK, so we're going to need to do some network activity for this retrieval.
|
||||||
|
# Let's make sure we've seeded.
|
||||||
|
if not self._learner.done_seeding:
|
||||||
|
self._learner.learn_from_teacher_node()
|
||||||
|
|
||||||
|
all_known_ursulas = self._learner.known_nodes.addresses()
|
||||||
|
|
||||||
|
# Push all unknown Ursulas from the map in the queue for learning
|
||||||
|
unknown_ursulas = ursulas - all_known_ursulas
|
||||||
|
|
||||||
|
# If we know enough to decrypt, we can proceed.
|
||||||
|
known_ursulas = ursulas & all_known_ursulas
|
||||||
|
if len(known_ursulas) >= threshold:
|
||||||
|
return
|
||||||
|
|
||||||
|
# | <--- shares ---> |
|
||||||
|
# | <--- threshold ---> | <--- allow_missing ---> |
|
||||||
|
# | <--- known_ursulas ---> | <--- unknown_ursulas ---> |
|
||||||
|
allow_missing = len(ursulas) - threshold
|
||||||
|
self._learner.block_until_specific_nodes_are_known(
|
||||||
|
unknown_ursulas,
|
||||||
|
timeout=timeout,
|
||||||
|
allow_missing=allow_missing,
|
||||||
|
learn_on_this_thread=True,
|
||||||
|
)
|
|
@ -0,0 +1,74 @@
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
from eth_typing import ChecksumAddress
|
||||||
|
|
||||||
|
from nucypher.network.client import ThresholdAccessControlClient
|
||||||
|
from nucypher.utilities.concurrency import BatchValueFactory, WorkerPool
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdDecryptionClient(ThresholdAccessControlClient):
|
||||||
|
class DecryptionRequestFailed(Exception):
|
||||||
|
"""Raised when a decryption request returns a non-zero status code."""
|
||||||
|
|
||||||
|
class DecryptionRequestFactory(BatchValueFactory):
|
||||||
|
def __init__(self, ursula_to_contact: List[ChecksumAddress], threshold: int):
|
||||||
|
# TODO should we batch the ursulas to contact i.e. pass `batch_size` parameter
|
||||||
|
super().__init__(values=ursula_to_contact, required_successes=threshold)
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def gather_encrypted_decryption_shares(
|
||||||
|
self,
|
||||||
|
encrypted_requests: Dict[ChecksumAddress, bytes],
|
||||||
|
threshold: int,
|
||||||
|
timeout: float = 10,
|
||||||
|
) -> Tuple[Dict[ChecksumAddress, bytes], Dict[ChecksumAddress, str]]:
|
||||||
|
self._ensure_ursula_availability(
|
||||||
|
ursulas=list(encrypted_requests.keys()),
|
||||||
|
threshold=threshold,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
def worker(ursula_address: ChecksumAddress) -> bytes:
|
||||||
|
encrypted_request = encrypted_requests[ursula_address]
|
||||||
|
|
||||||
|
try:
|
||||||
|
node_or_sprout = self._learner.known_nodes[ursula_address]
|
||||||
|
node_or_sprout.mature()
|
||||||
|
response = (
|
||||||
|
self._learner.network_middleware.get_encrypted_decryption_share(
|
||||||
|
node_or_sprout, encrypted_request
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.log.warn(f"Node {ursula_address} raised {e}")
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
if response.status_code != 200:
|
||||||
|
message = f"Node {ursula_address} returned {response.status_code} - {response.content}."
|
||||||
|
self.log.warn(message)
|
||||||
|
raise self.DecryptionRequestFailed(message)
|
||||||
|
|
||||||
|
return response.content
|
||||||
|
|
||||||
|
worker_pool = WorkerPool(
|
||||||
|
worker=worker,
|
||||||
|
value_factory=self.DecryptionRequestFactory(
|
||||||
|
ursula_to_contact=list(encrypted_requests.keys()), threshold=threshold
|
||||||
|
),
|
||||||
|
target_successes=threshold,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
worker_pool.start()
|
||||||
|
try:
|
||||||
|
successes = worker_pool.block_until_target_successes()
|
||||||
|
except (WorkerPool.OutOfValues, WorkerPool.TimedOut):
|
||||||
|
# It's possible to raise some other exceptions here but we will use the logic below.
|
||||||
|
successes = worker_pool.get_successes()
|
||||||
|
finally:
|
||||||
|
worker_pool.cancel()
|
||||||
|
worker_pool.join()
|
||||||
|
failures = worker_pool.get_failures()
|
||||||
|
|
||||||
|
return successes, failures
|
|
@ -262,7 +262,9 @@ class RestMiddleware:
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def get_decryption_share(self, ursula: 'Ursula', decryption_request_bytes: bytes):
|
def get_encrypted_decryption_share(
|
||||||
|
self, ursula: "Ursula", decryption_request_bytes: bytes
|
||||||
|
):
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
node_or_sprout=ursula,
|
node_or_sprout=ursula,
|
||||||
path=f"decrypt",
|
path=f"decrypt",
|
||||||
|
|
|
@ -2,7 +2,6 @@
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from json import JSONDecodeError
|
|
||||||
from typing import Dict, List, Sequence, Tuple
|
from typing import Dict, List, Sequence, Tuple
|
||||||
|
|
||||||
from eth_typing.evm import ChecksumAddress
|
from eth_typing.evm import ChecksumAddress
|
||||||
|
@ -22,11 +21,10 @@ from nucypher_core.umbral import (
|
||||||
VerificationError,
|
VerificationError,
|
||||||
VerifiedCapsuleFrag,
|
VerifiedCapsuleFrag,
|
||||||
)
|
)
|
||||||
from twisted.logger import Logger
|
|
||||||
|
|
||||||
from nucypher.crypto.signing import InvalidSignature
|
from nucypher.crypto.signing import InvalidSignature
|
||||||
|
from nucypher.network.client import ThresholdAccessControlClient
|
||||||
from nucypher.network.exceptions import NodeSeemsToBeDown
|
from nucypher.network.exceptions import NodeSeemsToBeDown
|
||||||
from nucypher.network.nodes import Learner
|
|
||||||
from nucypher.policy.conditions.exceptions import InvalidConditionContext
|
from nucypher.policy.conditions.exceptions import InvalidConditionContext
|
||||||
from nucypher.policy.conditions.rust_shims import _serialize_rust_lingos
|
from nucypher.policy.conditions.rust_shims import _serialize_rust_lingos
|
||||||
from nucypher.policy.kits import RetrievalResult
|
from nucypher.policy.kits import RetrievalResult
|
||||||
|
@ -39,7 +37,7 @@ class RetrievalError:
|
||||||
|
|
||||||
class RetrievalPlan:
|
class RetrievalPlan:
|
||||||
"""
|
"""
|
||||||
An emphemeral object providing a service of selecting Ursulas for reencryption requests
|
An ephemeral object providing a service of selecting Ursulas for re-encryption requests
|
||||||
during retrieval.
|
during retrieval.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -166,49 +164,13 @@ class RetrievalWorkOrder:
|
||||||
return rust_lingos
|
return rust_lingos
|
||||||
|
|
||||||
|
|
||||||
class RetrievalClient:
|
class PRERetrievalClient(ThresholdAccessControlClient):
|
||||||
"""
|
"""
|
||||||
Capsule frag retrieval machinery shared between Bob and Porter.
|
Capsule frag retrieval machinery shared between Bob and Porter.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, learner: Learner):
|
def __init__(self, *args, **kwargs):
|
||||||
self._learner = learner
|
super().__init__(*args, **kwargs)
|
||||||
self.log = Logger(self.__class__.__name__)
|
|
||||||
|
|
||||||
def _ensure_ursula_availability(self, treasure_map: TreasureMap, timeout=10):
|
|
||||||
"""
|
|
||||||
Make sure we know enough nodes from the treasure map to decrypt;
|
|
||||||
otherwise block and wait for them to come online.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# OK, so we're going to need to do some network activity for this retrieval.
|
|
||||||
# Let's make sure we've seeded.
|
|
||||||
if not self._learner.done_seeding:
|
|
||||||
self._learner.learn_from_teacher_node()
|
|
||||||
|
|
||||||
ursulas_in_map = treasure_map.destinations.keys()
|
|
||||||
|
|
||||||
# TODO (#1995): when that issue is fixed, conversion is no longer needed
|
|
||||||
ursulas_in_map = [to_checksum_address(bytes(address)) for address in ursulas_in_map]
|
|
||||||
|
|
||||||
all_known_ursulas = self._learner.known_nodes.addresses()
|
|
||||||
|
|
||||||
# Push all unknown Ursulas from the map in the queue for learning
|
|
||||||
unknown_ursulas = ursulas_in_map - all_known_ursulas
|
|
||||||
|
|
||||||
# If we know enough to decrypt, we can proceed.
|
|
||||||
known_ursulas = ursulas_in_map & all_known_ursulas
|
|
||||||
if len(known_ursulas) >= treasure_map.threshold:
|
|
||||||
return
|
|
||||||
|
|
||||||
# | <--- shares ---> |
|
|
||||||
# | <--- threshold ---> | <--- allow_missing ---> |
|
|
||||||
# | <--- known_ursulas ---> | <--- unknown_ursulas ---> |
|
|
||||||
allow_missing = len(treasure_map.destinations) - treasure_map.threshold
|
|
||||||
self._learner.block_until_specific_nodes_are_known(unknown_ursulas,
|
|
||||||
timeout=timeout,
|
|
||||||
allow_missing=allow_missing,
|
|
||||||
learn_on_this_thread=True)
|
|
||||||
|
|
||||||
def _request_reencryption(self,
|
def _request_reencryption(self,
|
||||||
ursula: 'Ursula',
|
ursula: 'Ursula',
|
||||||
|
@ -284,7 +246,16 @@ class RetrievalClient:
|
||||||
bob_verifying_key: PublicKey,
|
bob_verifying_key: PublicKey,
|
||||||
**context) -> Tuple[List[RetrievalResult], List[RetrievalError]]:
|
**context) -> Tuple[List[RetrievalResult], List[RetrievalError]]:
|
||||||
|
|
||||||
self._ensure_ursula_availability(treasure_map)
|
ursulas_in_map = treasure_map.destinations.keys()
|
||||||
|
|
||||||
|
# TODO (#1995): when that issue is fixed, conversion is no longer needed
|
||||||
|
ursulas_in_map = [
|
||||||
|
to_checksum_address(bytes(address)) for address in ursulas_in_map
|
||||||
|
]
|
||||||
|
|
||||||
|
self._ensure_ursula_availability(
|
||||||
|
ursulas=ursulas_in_map, threshold=treasure_map.threshold
|
||||||
|
)
|
||||||
|
|
||||||
retrieval_plan = RetrievalPlan(treasure_map=treasure_map, retrieval_kits=retrieval_kits)
|
retrieval_plan = RetrievalPlan(treasure_map=treasure_map, retrieval_kits=retrieval_kits)
|
||||||
|
|
||||||
|
|
|
@ -330,3 +330,48 @@ class WorkerPool:
|
||||||
break
|
break
|
||||||
|
|
||||||
self._result_queue.put(PRODUCER_STOPPED)
|
self._result_queue.put(PRODUCER_STOPPED)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchValueFactory:
|
||||||
|
def __init__(
|
||||||
|
self, values: List[Any], required_successes: int, batch_size: int = None
|
||||||
|
):
|
||||||
|
if not values:
|
||||||
|
raise ValueError(f"No available values provided")
|
||||||
|
if required_successes <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid number of successes required ({required_successes})"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.values = values
|
||||||
|
self.required_successes = required_successes
|
||||||
|
if len(self.values) < self.required_successes:
|
||||||
|
raise ValueError(
|
||||||
|
f"Available values ({len(self.values)} less than required successes {self.required_successes}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._batch_start_index = 0
|
||||||
|
|
||||||
|
if batch_size is not None and batch_size <= 0:
|
||||||
|
raise ValueError(f"Invalid batch size specified ({batch_size})")
|
||||||
|
self.batch_size = batch_size if batch_size else required_successes
|
||||||
|
|
||||||
|
def __call__(self, successes) -> Optional[List[Any]]:
|
||||||
|
if successes >= self.required_successes:
|
||||||
|
# no more work needed to be done
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self._batch_start_index == len(self.values):
|
||||||
|
# no more values to process
|
||||||
|
return None
|
||||||
|
|
||||||
|
batch_end_index = self._batch_start_index + self.batch_size
|
||||||
|
if batch_end_index <= len(self.values):
|
||||||
|
batch = self.values[self._batch_start_index : batch_end_index]
|
||||||
|
self._batch_start_index = batch_end_index
|
||||||
|
return batch
|
||||||
|
else:
|
||||||
|
# return all remaining values
|
||||||
|
batch = self.values[self._batch_start_index :]
|
||||||
|
self._batch_start_index = len(self.values)
|
||||||
|
return batch
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import pytest
|
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import Iterable, Tuple
|
from typing import Iterable, Tuple
|
||||||
|
|
||||||
from nucypher.utilities.concurrency import WorkerPool
|
import pytest
|
||||||
|
|
||||||
|
from nucypher.utilities.concurrency import BatchValueFactory, WorkerPool
|
||||||
|
|
||||||
|
|
||||||
class AllAtOnceFactory:
|
class AllAtOnceFactory:
|
||||||
|
@ -213,27 +214,18 @@ def test_join(join_worker_pool):
|
||||||
assert t_end - t_start < 3
|
assert t_end - t_start < 3
|
||||||
|
|
||||||
|
|
||||||
class BatchFactory:
|
class TestBatchValueFactory(BatchValueFactory):
|
||||||
|
|
||||||
def __init__(self, values):
|
def __init__(self, *args, **kwargs):
|
||||||
self.values = values
|
super().__init__(*args, **kwargs)
|
||||||
self.batch_sizes = []
|
self.batch_sizes = []
|
||||||
|
|
||||||
def __call__(self, successes):
|
def __call__(self, successes):
|
||||||
if successes == 10:
|
result = super().__call__(successes)
|
||||||
return None
|
if result:
|
||||||
batch_size = 10 - successes
|
self.batch_sizes.append(len(result))
|
||||||
if len(self.values) >= batch_size:
|
|
||||||
batch = self.values[:batch_size]
|
return result
|
||||||
self.batch_sizes.append(len(batch))
|
|
||||||
self.values = self.values[batch_size:]
|
|
||||||
return batch
|
|
||||||
elif len(self.values) > 0:
|
|
||||||
self.batch_sizes.append(len(self.values))
|
|
||||||
return self.values
|
|
||||||
self.values = None
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def test_batched_value_generation(join_worker_pool):
|
def test_batched_value_generation(join_worker_pool):
|
||||||
|
@ -248,7 +240,7 @@ def test_batched_value_generation(join_worker_pool):
|
||||||
],
|
],
|
||||||
seed=123)
|
seed=123)
|
||||||
|
|
||||||
factory = BatchFactory(list(outcomes))
|
factory = TestBatchValueFactory(values=list(outcomes), required_successes=10)
|
||||||
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=0.5)
|
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=0.5)
|
||||||
join_worker_pool(pool)
|
join_worker_pool(pool)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,196 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nucypher.utilities.concurrency import BatchValueFactory
|
||||||
|
|
||||||
|
NUM_VALUES = 20
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def values():
|
||||||
|
values = []
|
||||||
|
for i in range(0, NUM_VALUES):
|
||||||
|
values.append(i)
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_value_factory_invalid_values(values):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
BatchValueFactory(values=[], required_successes=0)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
BatchValueFactory(values=[], required_successes=1)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
BatchValueFactory(values=[1, 2, 3, 4], required_successes=5)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
BatchValueFactory(values=[1, 2, 3, 4], required_successes=2, batch_size=0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_value_factory_all_successes_no_specified_batching(values):
|
||||||
|
target_successes = NUM_VALUES
|
||||||
|
value_factory = BatchValueFactory(
|
||||||
|
values=values, required_successes=target_successes
|
||||||
|
)
|
||||||
|
|
||||||
|
# number of successes returned since no batching provided
|
||||||
|
value_list = value_factory(successes=0)
|
||||||
|
assert len(value_list) == target_successes, "list returned is based on successes"
|
||||||
|
assert len(values) == NUM_VALUES, "values remained unchanged"
|
||||||
|
|
||||||
|
# get list again
|
||||||
|
value_list = value_factory(successes=NUM_VALUES) # successes achieved
|
||||||
|
assert not value_list, "successes achieved and no more values available"
|
||||||
|
|
||||||
|
# get list again
|
||||||
|
value_list = value_factory(successes=0) # successes not achieved
|
||||||
|
assert not value_list, "no successes achieved but no more values available"
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_value_factory_no_specified_batching_no_more_values_after_target_successes(
|
||||||
|
values,
|
||||||
|
):
|
||||||
|
target_successes = 1
|
||||||
|
value_factory = BatchValueFactory(
|
||||||
|
values=values, required_successes=target_successes
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(0, NUM_VALUES // 3):
|
||||||
|
value_list = value_factory(successes=0)
|
||||||
|
assert (
|
||||||
|
len(value_list) == target_successes
|
||||||
|
), "list returned is based on successes"
|
||||||
|
assert len(values) == NUM_VALUES, "values remained unchanged"
|
||||||
|
|
||||||
|
for i in range(NUM_VALUES // 3, NUM_VALUES):
|
||||||
|
value_list = value_factory(successes=target_successes)
|
||||||
|
assert (
|
||||||
|
not value_list
|
||||||
|
), "there are more values but no more is needed since target successes attained"
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_value_factory_no_batching_no_success_multiple_calls(values):
|
||||||
|
target_successes = 4
|
||||||
|
value_factory = BatchValueFactory(
|
||||||
|
values=values, required_successes=target_successes
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(0, NUM_VALUES // target_successes):
|
||||||
|
value_list = value_factory(successes=0)
|
||||||
|
assert (
|
||||||
|
len(value_list) == target_successes
|
||||||
|
), "list returned is based on successes"
|
||||||
|
assert len(values) == NUM_VALUES, "values remained unchanged"
|
||||||
|
|
||||||
|
# list all done but get list again
|
||||||
|
value_list = value_factory(successes=target_successes) # successes achieved
|
||||||
|
assert not value_list, "successes achieved"
|
||||||
|
|
||||||
|
# list all done but get list again
|
||||||
|
value_list = value_factory(
|
||||||
|
successes=1
|
||||||
|
) # not enough successes but list is now empty
|
||||||
|
assert not value_list, "successes not achieved, but no more values available"
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_value_factory_no_batching_no_success_multiple_calls_non_divisible_successes(
|
||||||
|
values,
|
||||||
|
):
|
||||||
|
target_successes = 6
|
||||||
|
value_factory = BatchValueFactory(
|
||||||
|
values=values, required_successes=target_successes
|
||||||
|
)
|
||||||
|
|
||||||
|
# should be able to get 4 lists
|
||||||
|
for i in range(0, NUM_VALUES // target_successes):
|
||||||
|
value_list = value_factory(successes=0)
|
||||||
|
assert (
|
||||||
|
len(value_list) == target_successes
|
||||||
|
), "list returned is based on successes"
|
||||||
|
assert len(values) == NUM_VALUES, "values remained unchanged"
|
||||||
|
|
||||||
|
# last request
|
||||||
|
value_list = value_factory(successes=0)
|
||||||
|
assert len(value_list) == NUM_VALUES % target_successes, "remaining list returned"
|
||||||
|
|
||||||
|
# get list again
|
||||||
|
value_list = value_factory(successes=target_successes) # successes achieved
|
||||||
|
assert not value_list, "successes achieved"
|
||||||
|
|
||||||
|
# get list again
|
||||||
|
value_list = value_factory(
|
||||||
|
successes=target_successes - 1
|
||||||
|
) # not enough successes but list is now empty
|
||||||
|
assert not value_list, "successes not achieved, but no more values available"
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_value_factory_batching_individual(values):
|
||||||
|
target_successes = NUM_VALUES
|
||||||
|
batch_size = 1
|
||||||
|
value_factory = BatchValueFactory(
|
||||||
|
values=values, required_successes=target_successes, batch_size=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# number of successes returned since no batching provided
|
||||||
|
for i in range(0, NUM_VALUES // batch_size):
|
||||||
|
value_list = value_factory(successes=0)
|
||||||
|
assert len(value_list) == batch_size, "list returned is based on batch size"
|
||||||
|
assert len(values) == NUM_VALUES, "values remained unchanged"
|
||||||
|
|
||||||
|
# get list again
|
||||||
|
value_list = value_factory(successes=NUM_VALUES) # successes achieved
|
||||||
|
assert not value_list, "successes achieved and no more values available"
|
||||||
|
|
||||||
|
# get list again
|
||||||
|
value_list = value_factory(successes=0) # successes not achieved
|
||||||
|
assert not value_list, "no successes achieved but no more values available"
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_value_factory_batching_divisible(values):
|
||||||
|
target_successes = NUM_VALUES
|
||||||
|
batch_size = 5
|
||||||
|
value_factory = BatchValueFactory(
|
||||||
|
values=values, required_successes=target_successes, batch_size=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# number of successes returned since no batching provided (3x here)
|
||||||
|
for i in range(0, NUM_VALUES // batch_size):
|
||||||
|
value_list = value_factory(successes=target_successes - 1)
|
||||||
|
assert len(value_list) == batch_size, "list returned is based on batch size"
|
||||||
|
assert len(values) == NUM_VALUES, "values remained unchanged"
|
||||||
|
|
||||||
|
# get list again
|
||||||
|
value_list = value_factory(successes=NUM_VALUES) # successes achieved
|
||||||
|
assert not value_list, "successes achieved and no more values available"
|
||||||
|
|
||||||
|
# get list again
|
||||||
|
value_list = value_factory(successes=0) # successes not achieved
|
||||||
|
assert not value_list, "no successes achieved but no more values available"
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_value_factory_batching_non_divisible(values):
|
||||||
|
target_successes = NUM_VALUES
|
||||||
|
batch_size = 7
|
||||||
|
value_factory = BatchValueFactory(
|
||||||
|
values=values, required_successes=target_successes, batch_size=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# number of successes returned since no batching provided
|
||||||
|
for i in range(0, NUM_VALUES // batch_size):
|
||||||
|
value_list = value_factory(successes=0)
|
||||||
|
assert len(value_list) == batch_size, "list returned is based on batch size"
|
||||||
|
assert len(values) == NUM_VALUES, "values remained unchanged"
|
||||||
|
|
||||||
|
# one more
|
||||||
|
value_list = value_factory(successes=0)
|
||||||
|
assert len(value_list) == NUM_VALUES % batch_size, "remainder of list returned"
|
||||||
|
assert len(values) == NUM_VALUES, "values remained unchanged"
|
||||||
|
|
||||||
|
# get list again
|
||||||
|
value_list = value_factory(successes=target_successes) # successes achieved
|
||||||
|
assert not value_list, "successes achieved and no more values available"
|
||||||
|
|
||||||
|
# get list again
|
||||||
|
value_list = value_factory(successes=0) # successes not achieved
|
||||||
|
assert not value_list, "no successes achieved but no more values available"
|
Loading…
Reference in New Issue