Include batch_size as a parameter of ThresholdDecryptionRequestFactory

pull/3393/head
David Núñez 2024-01-17 15:35:12 +01:00
parent 7e8caefabb
commit d0c25cf8d4
No known key found for this signature in database
GPG Key ID: 53A9D83EF4C6332A
2 changed files with 15 additions and 6 deletions

View File

@ -19,9 +19,17 @@ class ThresholdDecryptionClient(ThresholdAccessControlClient):
"""Raised when a decryption request returns a non-zero status code."""
class ThresholdDecryptionRequestFactory(BatchValueFactory):
def __init__(self, ursulas_to_contact: List[ChecksumAddress], threshold: int):
# TODO should we batch the ursulas to contact i.e. pass `batch_size` parameter
super().__init__(values=ursulas_to_contact, required_successes=threshold)
def __init__(
self,
ursulas_to_contact: List[ChecksumAddress],
threshold: int,
batch_size: int,
):
super().__init__(
values=ursulas_to_contact,
required_successes=threshold,
batch_size=batch_size,
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -69,11 +77,12 @@ class ThresholdDecryptionClient(ThresholdAccessControlClient):
self.log.warn(message)
raise self.ThresholdDecryptionRequestFailed(message)
batch_size = min(int(threshold * 1.25), len(encrypted_requests))
worker_pool = WorkerPool(
worker=worker,
value_factory=self.ThresholdDecryptionRequestFactory(
ursulas_to_contact=list(encrypted_requests.keys()), threshold=batch_size
ursulas_to_contact=list(encrypted_requests.keys()),
batch_size=int(threshold * 1.25),
threshold=threshold,
),
target_successes=threshold,
threadpool_size=len(

View File

@ -354,7 +354,7 @@ class BatchValueFactory:
if batch_size is not None and batch_size <= 0:
raise ValueError(f"Invalid batch size specified ({batch_size})")
self.batch_size = batch_size if batch_size else required_successes
self.batch_size = batch_size or required_successes
def __call__(self, successes) -> Optional[List[Any]]:
if successes >= self.required_successes: