mirror of https://github.com/nucypher/nucypher.git
Include batch_size as a parameter of ThresholdDecryptionRequestFactory
parent
7e8caefabb
commit
d0c25cf8d4
|
@ -19,9 +19,17 @@ class ThresholdDecryptionClient(ThresholdAccessControlClient):
|
|||
"""Raised when a decryption request returns a non-zero status code."""
|
||||
|
||||
class ThresholdDecryptionRequestFactory(BatchValueFactory):
|
||||
def __init__(self, ursulas_to_contact: List[ChecksumAddress], threshold: int):
|
||||
# TODO should we batch the ursulas to contact i.e. pass `batch_size` parameter
|
||||
super().__init__(values=ursulas_to_contact, required_successes=threshold)
|
||||
def __init__(
|
||||
self,
|
||||
ursulas_to_contact: List[ChecksumAddress],
|
||||
threshold: int,
|
||||
batch_size: int,
|
||||
):
|
||||
super().__init__(
|
||||
values=ursulas_to_contact,
|
||||
required_successes=threshold,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
@ -69,11 +77,12 @@ class ThresholdDecryptionClient(ThresholdAccessControlClient):
|
|||
self.log.warn(message)
|
||||
raise self.ThresholdDecryptionRequestFailed(message)
|
||||
|
||||
batch_size = min(int(threshold * 1.25), len(encrypted_requests))
|
||||
worker_pool = WorkerPool(
|
||||
worker=worker,
|
||||
value_factory=self.ThresholdDecryptionRequestFactory(
|
||||
ursulas_to_contact=list(encrypted_requests.keys()), threshold=batch_size
|
||||
ursulas_to_contact=list(encrypted_requests.keys()),
|
||||
batch_size=int(threshold * 1.25),
|
||||
threshold=threshold,
|
||||
),
|
||||
target_successes=threshold,
|
||||
threadpool_size=len(
|
||||
|
|
|
@ -354,7 +354,7 @@ class BatchValueFactory:
|
|||
|
||||
if batch_size is not None and batch_size <= 0:
|
||||
raise ValueError(f"Invalid batch size specified ({batch_size})")
|
||||
self.batch_size = batch_size if batch_size else required_successes
|
||||
self.batch_size = batch_size or required_successes
|
||||
|
||||
def __call__(self, successes) -> Optional[List[Any]]:
|
||||
if successes >= self.required_successes:
|
||||
|
|
Loading…
Reference in New Issue