Merge pull request #3393 from cygnusv/nojoin

Optimize use of decryption request WorkerPool
pull/3294/head
KPrasch 2024-01-19 15:01:36 +01:00 committed by GitHub
commit 2ae6ce85b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 35 additions and 14 deletions

View File

@ -0,0 +1 @@
Optimize use of decryption request WorkerPool.

View File

@ -186,7 +186,7 @@ class RegistrySourceManager:
def __init__(
self,
domain: TACoDomain,
sources: Optional[RegistrySource] = None,
sources: Optional[List[RegistrySource]] = None,
only_primary: bool = False,
):
if only_primary and sources:

View File

@ -1,4 +1,5 @@
from http import HTTPStatus
from random import shuffle
from typing import Dict, List, Tuple
from eth_typing import ChecksumAddress
@ -12,15 +13,24 @@ from nucypher.utilities.concurrency import BatchValueFactory, WorkerPool
class ThresholdDecryptionClient(ThresholdAccessControlClient):
DEFAULT_DECRYPTION_TIMEOUT = 15
DEFAULT_DECRYPTION_TIMEOUT = 30
DEFAULT_STAGGER_TIMEOUT = 3
class ThresholdDecryptionRequestFailed(Exception):
"""Raised when a decryption request returns a non-zero status code."""
class ThresholdDecryptionRequestFactory(BatchValueFactory):
def __init__(self, ursula_to_contact: List[ChecksumAddress], threshold: int):
# TODO should we batch the ursulas to contact i.e. pass `batch_size` parameter
super().__init__(values=ursula_to_contact, required_successes=threshold)
def __init__(
self,
ursulas_to_contact: List[ChecksumAddress],
threshold: int,
batch_size: int,
):
super().__init__(
values=ursulas_to_contact,
required_successes=threshold,
batch_size=batch_size,
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -30,6 +40,7 @@ class ThresholdDecryptionClient(ThresholdAccessControlClient):
encrypted_requests: Dict[ChecksumAddress, EncryptedThresholdDecryptionRequest],
threshold: int,
timeout: int = DEFAULT_DECRYPTION_TIMEOUT,
stagger_timeout: int = DEFAULT_STAGGER_TIMEOUT,
) -> Tuple[
Dict[ChecksumAddress, EncryptedThresholdDecryptionResponse],
Dict[ChecksumAddress, str],
@ -60,23 +71,30 @@ class ThresholdDecryptionClient(ThresholdAccessControlClient):
response.content
)
except Exception as e:
self.log.warn(f"Node {ursula_address} raised {e}")
raise
message = f"Node {ursula_address} raised {e}"
self.log.warn(message)
raise self.ThresholdDecryptionRequestFailed(message)
message = f"Node {ursula_address} returned {response.status_code} - {response.content}."
self.log.warn(message)
raise self.ThresholdDecryptionRequestFailed(message)
# TODO: Find a better request order, perhaps based on latency data obtained from discovery loop - #3395
requests = list(encrypted_requests)
shuffle(requests)
# Discussion about WorkerPool parameters:
# "https://github.com/nucypher/nucypher/pull/3393#discussion_r1456307991"
worker_pool = WorkerPool(
worker=worker,
value_factory=self.ThresholdDecryptionRequestFactory(
ursula_to_contact=list(encrypted_requests.keys()), threshold=threshold
ursulas_to_contact=requests,
batch_size=int(threshold * 1.25),
threshold=threshold,
),
target_successes=threshold,
threadpool_size=len(
encrypted_requests
), # TODO should we cap this (say 40?)
threadpool_size=int(threshold * 1.5), # TODO should we cap this (say 40?)
timeout=timeout,
stagger_timeout=stagger_timeout,
)
worker_pool.start()
try:
@ -86,7 +104,7 @@ class ThresholdDecryptionClient(ThresholdAccessControlClient):
successes = worker_pool.get_successes()
finally:
worker_pool.cancel()
worker_pool.join()
failures = worker_pool.get_failures()
return successes, failures

View File

@ -662,7 +662,9 @@ class Learner:
self.log.info(f"Learned about enough nodes after {rounds_undertaken} rounds.")
return True
if not self._learning_task.running:
raise RuntimeError("Learning loop is not running. Start it with start_learning().")
raise RuntimeError(
"Learning loop is not running. Start it with start_learning_loop()."
)
elif not reactor.running and not learn_on_this_thread:
raise RuntimeError(
f"The reactor isn't running, but you're trying to use it for discovery. You need to start the Reactor in order to use {self} this way.")

View File

@ -354,7 +354,7 @@ class BatchValueFactory:
if batch_size is not None and batch_size <= 0:
raise ValueError(f"Invalid batch size specified ({batch_size})")
self.batch_size = batch_size if batch_size else required_successes
self.batch_size = batch_size or required_successes
def __call__(self, successes) -> Optional[List[Any]]:
if successes >= self.required_successes: