From d0c25cf8d4d3e9c754212a0ae90a5b51fe6ea5a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20N=C3=BA=C3=B1ez?= Date: Wed, 17 Jan 2024 15:35:12 +0100 Subject: [PATCH] Include batch_size as a parameter of ThresholdDecryptionRequestFactory --- nucypher/network/decryption.py | 19 ++++++++++++++----- nucypher/utilities/concurrency.py | 2 +- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/nucypher/network/decryption.py b/nucypher/network/decryption.py index f8470b4d5..2ff401612 100644 --- a/nucypher/network/decryption.py +++ b/nucypher/network/decryption.py @@ -19,9 +19,17 @@ class ThresholdDecryptionClient(ThresholdAccessControlClient): """Raised when a decryption request returns a non-zero status code.""" class ThresholdDecryptionRequestFactory(BatchValueFactory): - def __init__(self, ursulas_to_contact: List[ChecksumAddress], threshold: int): - # TODO should we batch the ursulas to contact i.e. pass `batch_size` parameter - super().__init__(values=ursulas_to_contact, required_successes=threshold) + def __init__( + self, + ursulas_to_contact: List[ChecksumAddress], + threshold: int, + batch_size: int, + ): + super().__init__( + values=ursulas_to_contact, + required_successes=threshold, + batch_size=batch_size, + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -69,11 +77,12 @@ class ThresholdDecryptionClient(ThresholdAccessControlClient): self.log.warn(message) raise self.ThresholdDecryptionRequestFailed(message) - batch_size = min(int(threshold * 1.25), len(encrypted_requests)) worker_pool = WorkerPool( worker=worker, value_factory=self.ThresholdDecryptionRequestFactory( - ursulas_to_contact=list(encrypted_requests.keys()), threshold=batch_size + ursulas_to_contact=list(encrypted_requests.keys()), + batch_size=int(threshold * 1.25), + threshold=threshold, ), target_successes=threshold, threadpool_size=len( diff --git a/nucypher/utilities/concurrency.py b/nucypher/utilities/concurrency.py index 4032a9819..edab187d3 100644 --- a/nucypher/utilities/concurrency.py +++ b/nucypher/utilities/concurrency.py @@ -354,7 +354,7 @@ class BatchValueFactory: if batch_size is not None and batch_size <= 0: raise ValueError(f"Invalid batch size specified ({batch_size})") - self.batch_size = batch_size if batch_size else required_successes + self.batch_size = batch_size or required_successes def __call__(self, successes) -> Optional[List[Any]]: if successes >= self.required_successes: