mirror of https://github.com/nucypher/nucypher.git
Merge pull request #3393 from cygnusv/nojoin
Optimize use of decryption request WorkerPoolpull/3294/head
commit
2ae6ce85b0
|
@ -0,0 +1 @@
|
|||
Optimize use of decryption request WorkerPool.
|
|
@ -186,7 +186,7 @@ class RegistrySourceManager:
|
|||
def __init__(
|
||||
self,
|
||||
domain: TACoDomain,
|
||||
sources: Optional[RegistrySource] = None,
|
||||
sources: Optional[List[RegistrySource]] = None,
|
||||
only_primary: bool = False,
|
||||
):
|
||||
if only_primary and sources:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from http import HTTPStatus
|
||||
from random import shuffle
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from eth_typing import ChecksumAddress
|
||||
|
@ -12,15 +13,24 @@ from nucypher.utilities.concurrency import BatchValueFactory, WorkerPool
|
|||
|
||||
|
||||
class ThresholdDecryptionClient(ThresholdAccessControlClient):
|
||||
DEFAULT_DECRYPTION_TIMEOUT = 15
|
||||
DEFAULT_DECRYPTION_TIMEOUT = 30
|
||||
DEFAULT_STAGGER_TIMEOUT = 3
|
||||
|
||||
class ThresholdDecryptionRequestFailed(Exception):
|
||||
"""Raised when a decryption request returns a non-zero status code."""
|
||||
|
||||
class ThresholdDecryptionRequestFactory(BatchValueFactory):
|
||||
def __init__(self, ursula_to_contact: List[ChecksumAddress], threshold: int):
|
||||
# TODO should we batch the ursulas to contact i.e. pass `batch_size` parameter
|
||||
super().__init__(values=ursula_to_contact, required_successes=threshold)
|
||||
def __init__(
|
||||
self,
|
||||
ursulas_to_contact: List[ChecksumAddress],
|
||||
threshold: int,
|
||||
batch_size: int,
|
||||
):
|
||||
super().__init__(
|
||||
values=ursulas_to_contact,
|
||||
required_successes=threshold,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
@ -30,6 +40,7 @@ class ThresholdDecryptionClient(ThresholdAccessControlClient):
|
|||
encrypted_requests: Dict[ChecksumAddress, EncryptedThresholdDecryptionRequest],
|
||||
threshold: int,
|
||||
timeout: int = DEFAULT_DECRYPTION_TIMEOUT,
|
||||
stagger_timeout: int = DEFAULT_STAGGER_TIMEOUT,
|
||||
) -> Tuple[
|
||||
Dict[ChecksumAddress, EncryptedThresholdDecryptionResponse],
|
||||
Dict[ChecksumAddress, str],
|
||||
|
@ -60,23 +71,30 @@ class ThresholdDecryptionClient(ThresholdAccessControlClient):
|
|||
response.content
|
||||
)
|
||||
except Exception as e:
|
||||
self.log.warn(f"Node {ursula_address} raised {e}")
|
||||
raise
|
||||
message = f"Node {ursula_address} raised {e}"
|
||||
self.log.warn(message)
|
||||
raise self.ThresholdDecryptionRequestFailed(message)
|
||||
|
||||
message = f"Node {ursula_address} returned {response.status_code} - {response.content}."
|
||||
self.log.warn(message)
|
||||
raise self.ThresholdDecryptionRequestFailed(message)
|
||||
|
||||
# TODO: Find a better request order, perhaps based on latency data obtained from discovery loop - #3395
|
||||
requests = list(encrypted_requests)
|
||||
shuffle(requests)
|
||||
# Discussion about WorkerPool parameters:
|
||||
# "https://github.com/nucypher/nucypher/pull/3393#discussion_r1456307991"
|
||||
worker_pool = WorkerPool(
|
||||
worker=worker,
|
||||
value_factory=self.ThresholdDecryptionRequestFactory(
|
||||
ursula_to_contact=list(encrypted_requests.keys()), threshold=threshold
|
||||
ursulas_to_contact=requests,
|
||||
batch_size=int(threshold * 1.25),
|
||||
threshold=threshold,
|
||||
),
|
||||
target_successes=threshold,
|
||||
threadpool_size=len(
|
||||
encrypted_requests
|
||||
), # TODO should we cap this (say 40?)
|
||||
threadpool_size=int(threshold * 1.5), # TODO should we cap this (say 40?)
|
||||
timeout=timeout,
|
||||
stagger_timeout=stagger_timeout,
|
||||
)
|
||||
worker_pool.start()
|
||||
try:
|
||||
|
@ -86,7 +104,7 @@ class ThresholdDecryptionClient(ThresholdAccessControlClient):
|
|||
successes = worker_pool.get_successes()
|
||||
finally:
|
||||
worker_pool.cancel()
|
||||
worker_pool.join()
|
||||
|
||||
failures = worker_pool.get_failures()
|
||||
|
||||
return successes, failures
|
||||
|
|
|
@ -662,7 +662,9 @@ class Learner:
|
|||
self.log.info(f"Learned about enough nodes after {rounds_undertaken} rounds.")
|
||||
return True
|
||||
if not self._learning_task.running:
|
||||
raise RuntimeError("Learning loop is not running. Start it with start_learning().")
|
||||
raise RuntimeError(
|
||||
"Learning loop is not running. Start it with start_learning_loop()."
|
||||
)
|
||||
elif not reactor.running and not learn_on_this_thread:
|
||||
raise RuntimeError(
|
||||
f"The reactor isn't running, but you're trying to use it for discovery. You need to start the Reactor in order to use {self} this way.")
|
||||
|
|
|
@ -354,7 +354,7 @@ class BatchValueFactory:
|
|||
|
||||
if batch_size is not None and batch_size <= 0:
|
||||
raise ValueError(f"Invalid batch size specified ({batch_size})")
|
||||
self.batch_size = batch_size if batch_size else required_successes
|
||||
self.batch_size = batch_size or required_successes
|
||||
|
||||
def __call__(self, successes) -> Optional[List[Any]]:
|
||||
if successes >= self.required_successes:
|
||||
|
|
Loading…
Reference in New Issue