mirror of https://github.com/nucypher/nucypher.git

Generalize NodeEngagementMutex

parent a2b99daa1d
commit bbc4390f68

@@ -102,7 +102,7 @@ policy = ALICE.grant(BOB,
                      expiration=policy_end_datetime)
 
 assert policy.public_key == policy_pubkey
-policy.publishing_mutex.block_until_complete()
+policy.treasure_map_publisher.block_until_complete()
 
 # Alice puts her public key somewhere for Bob to find later...
 alices_pubkey_bytes_saved_for_posterity = bytes(ALICE.stamp)

@@ -141,7 +141,7 @@ policy = alicia.grant(bob=doctor_strange,
                       m=m,
                       n=n,
                       expiration=policy_end_datetime)
-policy.publishing_mutex.block_until_complete()
+policy.treasure_map_publisher.block_until_complete()
 print("Done!")
 
 # For the demo, we need a way to share with Bob some additional info

@@ -123,7 +123,7 @@ class AliceInterface(CharacterPublicInterface):
                                  expiration=expiration,
                                  discover_on_this_thread=True)
 
-        new_policy.publishing_mutex.block_until_success_is_reasonably_likely()
+        new_policy.treasure_map_publisher.block_until_success_is_reasonably_likely()
 
         response_data = {'treasure_map': new_policy.treasure_map,
                          'policy_encrypting_key': new_policy.public_key,

@@ -351,7 +351,7 @@ class Alice(Character, BlockchainPolicyAuthor):
         self.add_active_policy(enacted_policy)
 
         if publish_treasure_map and block_until_success_is_reasonably_likely:
-            enacted_policy.publishing_mutex.block_until_success_is_reasonably_likely()
+            enacted_policy.treasure_map_publisher.block_until_success_is_reasonably_likely()
         return enacted_policy
 
     def get_policy_encrypting_key_from_label(self, label: bytes) -> UmbralPublicKey:

@@ -47,6 +47,7 @@ from nucypher.crypto.powers import DecryptingPower, SigningPower, TransactingPow
 from nucypher.crypto.utils import construct_policy_id
 from nucypher.network.exceptions import NodeSeemsToBeDown
 from nucypher.network.middleware import RestMiddleware
+from nucypher.utilities.concurrency import WorkerPool, AllAtOnceFactory
 from nucypher.utilities.logging import Logger
 
 

@@ -90,151 +91,48 @@ class Arrangement:
         return f"Arrangement(client_key={self.alice_verifying_key})"
 
 
-class NodeEngagementMutex:
-    """
-    TODO: Does this belong on middleware?
+class TreasureMapPublisher:
 
-    TODO: There are a couple of ways this can break.  If one fo the jobs hangs, the whole thing will hang.  Also,
-    if there are fewer successfully completed than percent_to_complete_before_release, the partial queue will never
-    release.
-
-    TODO: Make registry per... I guess Policy?  It's weird to be able to accidentally enact again.
-    """
-    log = Logger("Policy")
+    log = Logger('TreasureMapPublisher')
 
     def __init__(self,
-                 callable_to_engage,  # TODO: typing.Protocol
+                 worker,
                  nodes,
-                 network_middleware,
                  percent_to_complete_before_release=5,
-                 note=None,
                  threadpool_size=120,
-                 timeout=20,
-                 *args,
-                 **kwargs):
-        self.f = callable_to_engage
-        self.nodes = nodes
-        self.network_middleware = network_middleware
-        self.args = args
-        self.kwargs = kwargs
+                 timeout=20):
 
-        self.completed = {}
-        self.failed = {}
+        self._total = len(nodes)
+        self._block_until_this_many_are_complete = math.ceil(len(nodes) * percent_to_complete_before_release / 100)
+        self._worker_pool = WorkerPool(worker=worker,
+                                       value_factory=AllAtOnceFactory(nodes),
+                                       target_successes=self._block_until_this_many_are_complete,
+                                       timeout=timeout,
+                                       stagger_timeout=0,
+                                       threadpool_size=threadpool_size)
 
-        self._started = False
-        self._finished = False
-        self.timeout = timeout
-
-        self.percent_to_complete_before_release = percent_to_complete_before_release
-        self._partial_queue = Queue()
-        self._completion_queue = Queue()
-        self._block_until_this_many_are_complete = math.ceil(
-            len(nodes) * self.percent_to_complete_before_release / 100)
-        self.nodes_contacted_during_partial_block = False
-        self.when_complete = Deferred()  # TODO: Allow cancelling via KB Interrupt or some other way?
-
-        if note is None:
-            self._repr = f"{callable_to_engage} to {len(nodes)} nodes"
-        else:
-            self._repr = f"{note}: {callable_to_engage} to {len(nodes)} nodes"
-
-        self._threadpool = ThreadPool(minthreads=threadpool_size, maxthreads=threadpool_size, name=self._repr)
-        self.log.info(f"NEM spinning up {self._threadpool}")
-        self._threadpool.callInThread(self._bail_on_timeout)
-
-    def __repr__(self):
-        return self._repr
-
-    def _bail_on_timeout(self):
-        while True:
-            if self.when_complete.called:
-                return
-            duration = datetime.datetime.now() - self._started
-            if duration.seconds >= self.timeout:
-                try:
-                    self._threadpool.stop()
-                except AlreadyQuit:
-                    raise RuntimeError("Is there a race condition here?  If this line is being hit, it's a bug.")
-                raise RuntimeError(f"Timed out.  Nodes completed: {self.completed}")
-            time.sleep(.5)
-
-    def block_until_success_is_reasonably_likely(self):
-        """
-        https://www.youtube.com/watch?v=OkSLswPSq2o
-        """
-        if len(self.completed) < self._block_until_this_many_are_complete:
-            try:
-                completed_for_reasonable_likelihood_of_success = self._partial_queue.get(timeout=self.timeout)  # TODO: Shorter timeout here?
-            except Empty:
-                raise RuntimeError(f"Timed out.  Nodes completed: {self.completed}")
-            self.log.debug(f"{len(self.completed)} nodes were contacted while blocking for a little while.")
-            return completed_for_reasonable_likelihood_of_success
-        else:
-            return self.completed
-
-    def block_until_complete(self):
-        if self.total_disposed() < len(self.nodes):
-            try:
-                _ = self._completion_queue.get(timeout=self.timeout)  # Interesting opportuntiy to pass some data, like the list of contacted nodes above.
-            except Empty:
-                raise RuntimeError(f"Timed out.  Nodes completed: {self.completed}")
-        if not reactor.running and not self._threadpool.joined:
-            # If the reactor isn't running, the user *must* call this, because this is where we stop.
-            self._threadpool.stop()
-
-    def _handle_success(self, response, node):
-        if response.status_code == 201:
-            self.completed[node] = response
-        else:
-            assert False  # TODO: What happens if this is a 300 or 400 level response?  (A 500 response will propagate as an error and be handled in the errback chain.)
-        if self.nodes_contacted_during_partial_block:
-            self._consider_finalizing()
-        else:
-            if len(self.completed) >= self._block_until_this_many_are_complete:
-                contacted = tuple(self.completed.keys())
-                self.nodes_contacted_during_partial_block = contacted
-                self.log.debug(f"Blocked for a little while, completed {contacted} nodes")
-                self._partial_queue.put(contacted)
-        return response
-
-    def _handle_error(self, failure, node):
-        self.failed[node] = failure  # TODO: Add a failfast mode?
-        self._consider_finalizing()
-        self.log.warn(f"{node} failed: {failure}")
-
-    def total_disposed(self):
-        return len(self.completed) + len(self.failed)
-
-    def _consider_finalizing(self):
-        if not self._finished:
-            if self.total_disposed() == len(self.nodes):
-                # TODO: Consider whether this can possibly hang.
-                self._finished = True
-                if reactor.running:
-                    reactor.callInThread(self._threadpool.stop)
-                self._completion_queue.put(self.completed)
-                self.when_complete.callback(self.completed)
-                self.log.info(f"{self} finished.")
-        else:
-            raise RuntimeError("Already finished.")
-
-    def _engage_node(self, node):
-        maybe_coro = self.f(node, network_middleware=self.network_middleware, *self.args, **self.kwargs)
-
-        d = ensureDeferred(maybe_coro)
-        d.addCallback(self._handle_success, node)
-        d.addErrback(self._handle_error, node)
-        return d
+    @property
+    def completed(self):
+        # TODO: lock dict before copying?
+        return self._worker_pool.get_successes()
 
     def start(self):
-        if self._started:
-            raise RuntimeError("Already started.")
-        self._started = datetime.datetime.now()
-        self.log.info(f"NEM Starting {self._threadpool}")
-        for node in self.nodes:
-            self._threadpool.callInThread(self._engage_node, node)
-        self._threadpool.start()
+        self.log.info(f"TreasureMapPublisher starting")
+        self._worker_pool.start()
         if reactor.running:
             reactor.callInThread(self.block_until_complete)
 
+    def block_until_success_is_reasonably_likely(self):
+        # Note: `OutOfValues`/`TimedOut` may be raised here, which means we didn't even get to
+        # `percent_to_complete_before_release` successes. For now just letting it fire.
+        self._worker_pool.block_until_target_successes()
+        completed = self.completed
+        self.log.debug(f"The minimal amount of nodes ({len(completed)}) was contacted "
+                       "while blocking for treasure map publication.")
+        return completed
+
+    def block_until_complete(self):
+        self._worker_pool.join()
+
 
 class MergedReservoir:

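For orientation, a minimal usage sketch of the new class (illustrative only, not part of the commit; `nodes` and `publish` are placeholder names, with `publish` obeying the WorkerPool worker contract of raising on failure):

    publisher = TreasureMapPublisher(worker=publish,
                                     nodes=nodes,
                                     percent_to_complete_before_release=5,
                                     timeout=20)
    publisher.start()

    # Unblocks once ~5% of the nodes have confirmed; WorkerPool.TimedOut or
    # WorkerPool.OutOfValues can propagate if even that threshold is unreachable.
    confirmed = publisher.block_until_success_is_reasonably_likely()

    publisher.block_until_complete()  # joins the underlying WorkerPool
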
@@ -534,23 +432,34 @@ class Policy(ABC):
 
         return treasure_map
 
-    def _make_publishing_mutex(self,
-                               treasure_map: 'TreasureMap',
-                               network_middleware: RestMiddleware,
-                               ) -> NodeEngagementMutex:
-
-        async def put_treasure_map_on_node(node, network_middleware):
-            response = network_middleware.put_treasure_map_on_node(node=node,
-                                                                   map_payload=bytes(treasure_map))
-            return response
+    def _make_publisher(self,
+                        treasure_map: 'TreasureMap',
+                        network_middleware: RestMiddleware,
+                        ) -> TreasureMapPublisher:
 
         # TODO (#2516): remove hardcoding of 8 nodes
         self.alice.block_until_number_of_known_nodes_is(8, timeout=2, learn_on_this_thread=True)
         target_nodes = self.bob.matching_nodes_among(self.alice.known_nodes)
+        treasure_map_bytes = bytes(treasure_map)  # prevent the closure from holding the reference
 
-        return NodeEngagementMutex(callable_to_engage=put_treasure_map_on_node,
-                                   nodes=target_nodes,
-                                   network_middleware=network_middleware)
+        def put_treasure_map_on_node(node):
+            try:
+                response = network_middleware.put_treasure_map_on_node(node=node,
+                                                                       map_payload=treasure_map_bytes)
+            except Exception as e:
+                self.log.warn(f"Putting treasure map on {node} failed: {e}")
+                raise
+
+            if response.status_code == 201:
+                return response
+            else:
+                message = f"Putting treasure map on {node} failed with response status: {response.status}"
+                self.log.warn(message)
+                # TODO: What happens if this is a 300 or 400 level response?
+                raise Exception(message)
+
+        return TreasureMapPublisher(worker=put_treasure_map_on_node,
+                                    nodes=target_nodes)
 
     def enact(self,
               network_middleware: RestMiddleware,

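The worker closure above encodes the contract WorkerPool expects: a worker signals success by returning a value and failure by raising. A bare sketch of that contract, with a hypothetical `ping_node` call standing in for the middleware request:

    def example_worker(node):
        response = ping_node(node)  # hypothetical network call
        if response.status_code != 201:
            # Raising records `node` as failed in WorkerPool.get_failures()
            raise Exception(f"{node} returned {response.status_code}")
        return response             # stored under `node` in get_successes()
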
@@ -572,8 +481,8 @@ class Policy(ABC):
 
         treasure_map = self._make_treasure_map(network_middleware=network_middleware,
                                                arrangements=arrangements)
-        publishing_mutex = self._make_publishing_mutex(treasure_map=treasure_map,
-                                                       network_middleware=network_middleware)
+        treasure_map_publisher = self._make_publisher(treasure_map=treasure_map,
+                                                      network_middleware=network_middleware)
         revocation_kit = RevocationKit(treasure_map, self.alice.stamp)
 
         enacted_policy = EnactedPolicy(self._id,

@@ -581,7 +490,7 @@ class Policy(ABC):
                                        self.label,
                                        self.public_key,
                                        treasure_map,
-                                       publishing_mutex,
+                                       treasure_map_publisher,
                                        revocation_kit,
                                        self.alice.stamp)
 

@@ -730,7 +639,7 @@ class BlockchainPolicy(Policy):
     def _enact_arrangements(self,
                             network_middleware,
                             arrangements,
-                            publish_treasure_map=True) -> NodeEngagementMutex:
+                            publish_treasure_map=True) -> TreasureMapPublisher:
         transaction = self._publish_to_blockchain(list(arrangements))
         return super()._enact_arrangements(network_middleware=network_middleware,
                                            arrangements=arrangements,

@@ -756,7 +665,7 @@ class EnactedPolicy:
                  label: bytes,
                  public_key: UmbralPublicKey,
                  treasure_map: 'TreasureMap',
-                 publishing_mutex: NodeEngagementMutex,
+                 treasure_map_publisher: TreasureMapPublisher,
                  revocation_kit: RevocationKit,
                  alice_verifying_key: UmbralPublicKey,
                  ):

@@ -766,10 +675,10 @@ class EnactedPolicy:
         self.label = label
         self.public_key = public_key
         self.treasure_map = treasure_map
-        self.publishing_mutex = publishing_mutex
+        self.treasure_map_publisher = treasure_map_publisher
         self.revocation_kit = revocation_kit
         self.n = len(self.treasure_map.destinations)
         self.alice_verifying_key = alice_verifying_key
 
     def publish_treasure_map(self):
-        self.publishing_mutex.start()
+        self.treasure_map_publisher.start()

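Taken together with the call-site hunks above, the rename gives the following caller-side flow (a sketch assembled from the docs examples earlier in this diff; the grant() arguments are placeholders):

    policy = alice.grant(bob, label, m=m, n=n, expiration=policy_end_datetime)

    # Return early once success is reasonably likely...
    policy.treasure_map_publisher.block_until_success_is_reasonably_likely()

    # ...or wait for the full distribution to finish.
    policy.treasure_map_publisher.block_until_complete()
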
@@ -0,0 +1,304 @@
+"""
+This file is part of nucypher.
+
+nucypher is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+nucypher is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public License
+along with nucypher.  If not, see <https://www.gnu.org/licenses/>.
+"""
+
+import time
+from queue import Queue, Empty
+from threading import Thread, Event, Lock, Timer, get_ident
+from typing import Callable, List, Any, Optional, Dict
+
+from constant_sorrow.constants import PRODUCER_STOPPED, TIMEOUT_TRIGGERED
+from twisted._threads import AlreadyQuit
+from twisted.python.threadpool import ThreadPool
+
+
+class Success:
+    def __init__(self, value, result):
+        self.value = value
+        self.result = result
+
+
+class Failure:
+    def __init__(self, value, exception):
+        self.value = value
+        self.exception = exception
+
+
+class Cancelled(Exception):
+    pass
+
+
+class SetOnce:
+    """
+    A convenience wrapper for a value that can be set once (which can be waited on),
+    and cannot be overwritten (unless cleared).
+    """
+
+    def __init__(self):
+        self._lock = Lock()
+        self._set_event = Event()
+        self._value = None
+
+    def set(self, value):
+        with self._lock:
+            if not self._set_event.is_set():
+                self._value = value
+                self._set_event.set()
+
+    def is_set(self):
+        return self._set_event.is_set()
+
+    def get_and_clear(self):
+        with self._lock:
+            value = self._value
+            self._value = None
+            self._set_event.clear()
+            return value
+
+    def get(self):
+        self._set_event.wait()
+        return self._value
+
+
+class WorkerPool:
+    """
+    A generalized class that can start multiple workers in a thread pool with values
+    drawn from the given value factory object,
+    and wait for their completion and a given number of successes
+    (a worker returning something without throwing an exception).
+    """
+
+    class TimedOut(Exception):
+        "Raised if waiting for the target number of successes timed out."
+
+    class OutOfValues(Exception):
+        "Raised if the value factory is out of values, but the target number was not reached."
+
+    def __init__(self,
+                 worker: Callable[[Any], Any],
+                 value_factory: Callable[[int], Optional[List[Any]]],
+                 target_successes,
+                 timeout: float,
+                 stagger_timeout: float = 0,
+                 threadpool_size: int = None):
+
+        # TODO: make stagger_timeout a part of the value factory?
+
+        self._worker = worker
+        self._value_factory = value_factory
+        self._timeout = timeout
+        self._stagger_timeout = stagger_timeout
+        self._target_successes = target_successes
+
+        thread_pool_kwargs = {}
+        if threadpool_size is not None:
+            thread_pool_kwargs['minthreads'] = threadpool_size
+            thread_pool_kwargs['maxthreads'] = threadpool_size
+        self._threadpool = ThreadPool(**thread_pool_kwargs)
+
+        # These three tasks must be run in separate threads
+        # to avoid being blocked by workers in the thread pool.
+        self._bail_on_timeout_thread = Thread(target=self._bail_on_timeout)
+        self._produce_values_thread = Thread(target=self._produce_values)
+        self._process_results_thread = Thread(target=self._process_results)
+
+        self._successes = {}
+        self._failures = {}
+        self._started_tasks = 0
+        self._finished_tasks = 0
+
+        self._cancel_event = Event()
+        self._result_queue = Queue()
+        self._target_value = SetOnce()
+        self._unexpected_error = SetOnce()
+        self._results_lock = Lock()
+        self._stopped = False
+
+    def start(self):
+        # TODO: check if already started?
+        self._threadpool.start()
+        self._produce_values_thread.start()
+        self._process_results_thread.start()
+        self._bail_on_timeout_thread.start()
+
+    def cancel(self):
+        """
+        Cancels the tasks enqueued in the thread pool and stops the producer thread.
+        """
+        self._cancel_event.set()
+
+    def join(self):
+        """
+        Waits for all the threads to finish.
+        Can be called several times.
+        """
+
+        if self._stopped:
+            return  # or raise AlreadyStopped?
+
+        self._produce_values_thread.join()
+        self._process_results_thread.join()
+        self._bail_on_timeout_thread.join()
+
+        # protect from a possible race
+        try:
+            self._threadpool.stop()
+        except AlreadyQuit:
+            pass
+        self._stopped = True
+
+        if self._unexpected_error.is_set():
+            e = self._unexpected_error.get()
+            raise RuntimeError(f"Unexpected error in the producer thread: {e}")
+
+    def _sleep(self, timeout):
+        """
+        Sleeps for a given timeout, can be interrupted by a cancellation event.
+        """
+        if self._cancel_event.wait(timeout):
+            raise Cancelled
+
+    def block_until_target_successes(self) -> Dict:
+        """
+        Blocks until the target number of successes is reached.
+        Returns a dictionary of values matched to results.
+        Can be called several times.
+        """
+        if self._unexpected_error.is_set():
+            # So that we don't raise it again when join() is called
+            e = self._unexpected_error.get_and_clear()
+            raise RuntimeError(f"Unexpected error in the producer thread: {e}")
+
+        result = self._target_value.get()
+        if result == TIMEOUT_TRIGGERED:
+            raise self.TimedOut()
+        elif result == PRODUCER_STOPPED:
+            raise self.OutOfValues()
+        return result
+
+    def get_failures(self) -> Dict:
+        """
+        Get the current failures, as a dictionary of values to thrown exceptions.
+        """
+        with self._results_lock:
+            return dict(self._failures)
+
+    def get_successes(self) -> Dict:
+        """
+        Get the current successes, as a dictionary of values to worker return values.
+        """
+        with self._results_lock:
+            return dict(self._successes)
+
+    def _bail_on_timeout(self):
+        """
+        A service thread that cancels the pool on timeout.
+        """
+        if not self._cancel_event.wait(timeout=self._timeout):
+            self._target_value.set(TIMEOUT_TRIGGERED)
+        self._cancel_event.set()
+
+    def _worker_wrapper(self, value):
+        """
+        A wrapper that catches exceptions thrown by the worker
+        and sends the results to the processing thread.
+        """
+        try:
+            # If we're in the cancelled state, interrupt early
+            self._sleep(0)
+
+            result = self._worker(value)
+            self._result_queue.put(Success(value, result))
+        except Cancelled as e:
+            self._result_queue.put(e)
+        except BaseException as e:
+            self._result_queue.put(Failure(value, str(e)))
+
+    def _process_results(self):
+        """
+        A service thread that processes worker results
+        and waits for the target number of successes to be reached.
+        """
+        producer_stopped = False
+        success_event_reached = False
+        while True:
+            result = self._result_queue.get()
+
+            if result == PRODUCER_STOPPED:
+                producer_stopped = True
+            else:
+                self._finished_tasks += 1
+                if isinstance(result, Success):
+                    with self._results_lock:
+                        self._successes[result.value] = result.result
+                        len_successes = len(self._successes)
+                    if not success_event_reached and len_successes == self._target_successes:
+                        # A protection for the case of repeating values.
+                        # Only trigger the target value once.
+                        success_event_reached = True
+                        self._target_value.set(self.get_successes())
+                if isinstance(result, Failure):
+                    with self._results_lock:
+                        self._failures[result.value] = result.exception
+
+            if producer_stopped and self._finished_tasks == self._started_tasks:
+                self.cancel()  # to cancel the timeout thread
+                self._target_value.set(PRODUCER_STOPPED)
+                break
+
+    def _produce_values(self):
+        while True:
+            try:
+                with self._results_lock:
+                    len_successes = len(self._successes)
+                batch = self._value_factory(len_successes)
+                if not batch:
+                    break
+
+                self._started_tasks += len(batch)
+                for value in batch:
+                    # There is a possible race between `callInThread()` and `stop()`,
+                    # But we never execute them at the same time,
+                    # because `join()` checks that the producer thread is stopped.
+                    self._threadpool.callInThread(self._worker_wrapper, value)
+
+                self._sleep(self._stagger_timeout)
+
+            except Cancelled:
+                break
+
+            except BaseException as e:
+                self._unexpected_error.set(e)
+                self.cancel()
+                break
+
+        self._result_queue.put(PRODUCER_STOPPED)
+
+
+class AllAtOnceFactory:
+    """
+    A simple value factory that returns all its values in a single batch.
+    """
+
+    def __init__(self, values):
+        self.values = values
+        self._produced = False
+
+    def __call__(self, _successes):
+        if self._produced:
+            return None
+        else:
+            self._produced = True
+            return self.values

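The new module is self-contained, so its behavior can be sanity-checked outside nucypher. A small illustrative script (assumes only the module above; `flaky_worker` is an invented stand-in for a network call):

    import random
    import time

    from nucypher.utilities.concurrency import WorkerPool, AllAtOnceFactory

    def flaky_worker(value):
        time.sleep(random.uniform(0.1, 0.3))              # simulate latency
        if value % 2:
            raise Exception(f"worker for {value} failed")  # recorded as a failure
        return value * 10                                  # recorded as a success

    pool = WorkerPool(worker=flaky_worker,
                      value_factory=AllAtOnceFactory(list(range(20))),
                      target_successes=5,
                      timeout=10,
                      threadpool_size=10)
    pool.start()
    try:
        successes = pool.block_until_target_successes()  # {value: result}
        print(sorted(successes.items()))
    finally:
        pool.cancel()
        pool.join()
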
@@ -168,8 +168,8 @@ def test_treasure_map_cannot_be_duplicated(blockchain_ursulas,
                                         expiration=policy_end_datetime)
 
     matching_ursulas = blockchain_bob.matching_nodes_among(blockchain_ursulas)
-    completed_ursulas = policy.publishing_mutex.block_until_success_is_reasonably_likely()
-    # Ursulas in publishing_mutex are not real Ursulas, but just some metadata of remote ones.
+    completed_ursulas = policy.treasure_map_publisher.block_until_success_is_reasonably_likely()
+    # Ursulas in `treasure_map_publisher` are not real Ursulas, but just some metadata of remote ones.
     # We need a real one to access its datastore.
     first_completed_ursula = [ursula for ursula in matching_ursulas if ursula in completed_ursulas][0]
 

@@ -0,0 +1,331 @@
+import random
+import time
+from typing import Iterable, Tuple, List, Callable
+
+import pytest
+
+from nucypher.utilities.concurrency import WorkerPool, AllAtOnceFactory
+
+
+@pytest.fixture(scope='function')
+def join_worker_pool(request):
+    """
+    Makes sure the pool is properly joined at the end of the test,
+    so that one doesn't have to wrap the whole test in a try-finally block.
+    """
+    pool_to_join = None
+    def register(pool):
+        nonlocal pool_to_join
+        pool_to_join = pool
+    yield register
+    pool_to_join.join()
+
+
+class WorkerRule:
+    def __init__(self, fails: bool = False, timeout_min: float = 0, timeout_max: float = 0):
+        self.fails = fails
+        self.timeout_min = timeout_min
+        self.timeout_max = timeout_max
+
+
+class WorkerOutcome:
+    def __init__(self, fails: bool, timeout: float):
+        self.fails = fails
+        self.timeout = timeout
+
+    def __call__(self, value):
+        time.sleep(self.timeout)
+        if self.fails:
+            raise Exception(f"Worker for {value} failed")
+        else:
+            return value
+
+
+def generate_workers(rules: Iterable[Tuple[WorkerRule, int]], seed=None):
+    rng = random.Random(seed)
+    outcomes = []
+    for rule, quantity in rules:
+        for _ in range(quantity):
+            timeout = rng.uniform(rule.timeout_min, rule.timeout_max)
+            outcomes.append(WorkerOutcome(rule.fails, timeout))
+
+    rng.shuffle(outcomes)
+
+    values = list(range(len(outcomes)))
+
+    def worker(value):
+        return outcomes[value](value)
+
+    return {value: outcomes[value] for value in values}, worker
+
+
+def test_wait_for_successes(join_worker_pool):
+    """
+    Checks that `block_until_target_successes()` returns in time and gives all the successes,
+    if there were enough of them.
+    """
+
+    outcomes, worker = generate_workers(
+        [
+            (WorkerRule(timeout_min=0.5, timeout_max=1.5), 10),
+            (WorkerRule(fails=True, timeout_min=1, timeout_max=3), 20),
+        ],
+        seed=123)
+
+    factory = AllAtOnceFactory(list(outcomes))
+    pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=30)
+    join_worker_pool(pool)
+
+    t_start = time.monotonic()
+    pool.start()
+    successes = pool.block_until_target_successes()
+    t_end = time.monotonic()
+
+    failures = pool.get_failures()
+    assert all(outcomes[value].fails for value in failures)
+
+    assert len(successes) == 10
+
+    # We have more threads in the pool than the workers,
+    # so all the successful ones should be able to finish right away.
+    assert t_end - t_start < 2
+
+    # Should be able to do it several times
+    successes = pool.block_until_target_successes()
+    assert len(successes) == 10
+
+
+def test_wait_for_successes_out_of_values(join_worker_pool):
+    """
+    Checks that if there weren't enough successful workers, `block_until_target_successes()`
+    raises an exception when the value factory is exhausted.
+    """
+
+    outcomes, worker = generate_workers(
+        [
+            (WorkerRule(timeout_min=0.5, timeout_max=1.5), 9),
+            (WorkerRule(fails=True, timeout_min=0.5, timeout_max=1.5), 20),
+        ],
+        seed=123)
+
+    factory = AllAtOnceFactory(list(outcomes))
+    pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=15)
+    join_worker_pool(pool)
+
+    t_start = time.monotonic()
+    pool.start()
+    with pytest.raises(WorkerPool.OutOfValues):
+        successes = pool.block_until_target_successes()
+    t_end = time.monotonic()
+
+    # We have roughly 2 workers per thread, so it shouldn't take longer than 1.5s (max timeout) * 2
+    assert t_end - t_start < 4
+
+
+def test_wait_for_successes_timed_out(join_worker_pool):
+    """
+    Checks that if enough successful workers can't finish before the timeout, we get an exception.
+    """
+
+    outcomes, worker = generate_workers(
+        [
+            (WorkerRule(timeout_min=0, timeout_max=0.5), 9),
+            (WorkerRule(timeout_min=1.5, timeout_max=2.5), 1),
+            (WorkerRule(fails=True, timeout_min=1.5, timeout_max=2.5), 20),
+        ],
+        seed=123)
+
+    factory = AllAtOnceFactory(list(outcomes))
+    pool = WorkerPool(worker, factory, target_successes=10, timeout=1, threadpool_size=30)
+    join_worker_pool(pool)
+
+    t_start = time.monotonic()
+    pool.start()
+    with pytest.raises(WorkerPool.TimedOut):
+        successes = pool.block_until_target_successes()
+    t_end = time.monotonic()
+
+    # Even though timeout is 1, there are long-running workers which we can't interrupt.
+    assert t_end - t_start < 3
+
+
+def test_join(join_worker_pool):
+    """
+    Test joining the pool.
+    """
+
+    outcomes, worker = generate_workers(
+        [
+            (WorkerRule(timeout_min=0.5, timeout_max=1.5), 9),
+            (WorkerRule(fails=True, timeout_min=0.5, timeout_max=1.5), 20),
+        ],
+        seed=123)
+
+    factory = AllAtOnceFactory(list(outcomes))
+    pool = WorkerPool(worker, factory, target_successes=10, timeout=1, threadpool_size=30)
+    join_worker_pool(pool)
+
+    t_start = time.monotonic()
+    pool.start()
+    pool.join()
+    t_end = time.monotonic()
+
+    pool.join()  # should work the second time too
+
+    # Even though timeout is 1, there are long-running workers which we can't interrupt.
+    assert t_end - t_start < 3
+
+
+class BatchFactory:
+
+    def __init__(self, values):
+        self.values = values
+        self.batch_sizes = []
+
+    def __call__(self, successes):
+        if successes == 10:
+            return None
+        batch_size = 10 - successes
+        if len(self.values) >= batch_size:
+            batch = self.values[:batch_size]
+            self.batch_sizes.append(len(batch))
+            self.values = self.values[batch_size:]
+            return batch
+        elif len(self.values) > 0:
+            batch = self.values
+            self.batch_sizes.append(len(batch))
+            self.values = None
+            return batch
+        else:
+            return None
+
+
+def test_batched_value_generation(join_worker_pool):
+    """
+    Tests a value factory that gives out value batches in portions.
+    """
+
+    outcomes, worker = generate_workers(
+        [
+            (WorkerRule(timeout_min=0.5, timeout_max=1.5), 80),
+            (WorkerRule(fails=True, timeout_min=0.5, timeout_max=1.5), 80),
+        ],
+        seed=123)
+
+    factory = BatchFactory(list(outcomes))
+    pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=0.5)
+    join_worker_pool(pool)
+
+    t_start = time.monotonic()
+    pool.start()
+    successes = pool.block_until_target_successes()
+    pool.cancel()
+    pool.join()
+    t_end = time.monotonic()
+
+    assert len(successes) == 10
+
+    # Check that batch sizes in the factory were getting progressively smaller
+    # as the number of successes grew.
+    assert all(factory.batch_sizes[i] >= factory.batch_sizes[i+1]
+               for i in range(len(factory.batch_sizes) - 1))
+
+    # Since we canceled the pool, no more workers will be started and we will finish faster
+    assert t_end - t_start < 4
+
+    successes_copy = pool.get_successes()
+    failures_copy = pool.get_failures()
+
+    assert all(value in successes_copy for value in successes)
+
+
+def test_cancel_waiting_workers(join_worker_pool):
+    """
+    If we have a small pool and many workers, it is possible for workers to be enqueued
+    one after another in one thread.
+    We test that if we call `cancel()`, these enqueued workers are cancelled too.
+    """
+
+    outcomes, worker = generate_workers(
+        [
+            (WorkerRule(timeout_min=1, timeout_max=1), 100),
+        ],
+        seed=123)
+
+    factory = AllAtOnceFactory(list(outcomes))
+    pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10)
+    join_worker_pool(pool)
+
+    t_start = time.monotonic()
+    pool.start()
+    pool.block_until_target_successes()
+    pool.cancel()
+    pool.join()
+    t_end = time.monotonic()
+
+    # We have 10 threads in the pool and 100 workers that are all enqueued at once at the start.
+    # If we didn't check for the cancel condition, we would have to wait for 10 seconds.
+    # We get 10 successes after 1s and cancel the workers,
+    # but the next workers in each thread have already started, so we have to wait for another 1s.
+    assert t_end - t_start < 2.5
+
+
+class BuggyFactory:
+
+    def __init__(self, values):
+        self.values = values
+
+    def __call__(self, successes):
+        if self.values is not None:
+            values = self.values
+            self.values = None
+            return values
+        else:
+            raise Exception("Buggy factory")
+
+
+def test_buggy_factory_raises_on_block(join_worker_pool):
+    """
+    Tests that if there is an exception thrown in the value factory,
+    it is caught in the first call to `block_until_target_successes()`.
+    """
+
+    outcomes, worker = generate_workers(
+        [(WorkerRule(timeout_min=1, timeout_max=1), 100)],
+        seed=123)
+
+    factory = BuggyFactory(list(outcomes))
+
+    # Non-zero stagger timeout to make BuggyFactory raise its error only after 1.5s,
+    # so that we get enough successes for `block_until_target_successes()`.
+    pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=1.5)
+    join_worker_pool(pool)
+
+    pool.start()
+    time.sleep(2)  # wait for the stagger timeout to finish
+    with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
+        pool.block_until_target_successes()
+    # Further calls to `block_until_target_successes()` or `join()` don't throw the error.
+    pool.block_until_target_successes()
+    pool.cancel()
+    pool.join()
+
+
+def test_buggy_factory_raises_on_join(join_worker_pool):
+    """
+    Tests that if there is an exception thrown in the value factory,
+    it is caught in the first call to `join()`.
+    """
+
+    outcomes, worker = generate_workers(
+        [(WorkerRule(timeout_min=1, timeout_max=1), 100)],
+        seed=123)
+
+    factory = BuggyFactory(list(outcomes))
+    pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10)
+    join_worker_pool(pool)
+
+    pool.start()
+    pool.cancel()
+    with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
+        pool.join()
+    pool.join()

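BatchFactory above doubles as documentation of the value-factory protocol: the factory is called with the current success count and returns the next batch of values, or None when it has nothing more to produce. A minimal custom factory following the same protocol (hypothetical, for illustration only):

    class ChunkedFactory:
        """Feeds values to the pool five at a time until the target is met."""

        def __init__(self, values, target):
            self.values = list(values)
            self.target = target

        def __call__(self, successes):
            if successes >= self.target or not self.values:
                return None
            batch, self.values = self.values[:5], self.values[5:]
            return batch
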
@@ -241,7 +241,7 @@ def enacted_federated_policy(idle_federated_policy, federated_ursulas):
     # REST call happens here, as does population of TreasureMap.
     enacted_policy = idle_federated_policy.enact(network_middleware=network_middleware,
                                                  handpicked_ursulas=federated_ursulas)
-    enacted_policy.publishing_mutex.block_until_complete()
+    enacted_policy.treasure_map_publisher.block_until_complete()
 
     return enacted_policy
 

@@ -278,7 +278,7 @@ def enacted_blockchain_policy(idle_blockchain_policy, blockchain_ursulas):
     # REST call happens here, as does population of TreasureMap.
     enacted_policy = idle_blockchain_policy.enact(network_middleware=network_middleware,
                                                   handpicked_ursulas=list(blockchain_ursulas))
-    enacted_policy.publishing_mutex.block_until_complete()
+    enacted_policy.treasure_map_publisher.block_until_complete()
     return enacted_policy
 
 

@@ -140,7 +140,7 @@ def test_alice_verifies_ursula_just_in_time(fleet_of_highperf_mocked_ursulas,
     _POLICY_PRESERVER.append(policy)
 
 
-# @pytest_twisted.inlineCallbacks  # TODO: Why does this, in concert with yield policy.publishing_mutex.when_complete, hang?
+# @pytest_twisted.inlineCallbacks  # TODO: Why does this, in concert with yield policy.treasure_map_publisher.when_complete, hang?
 def test_mass_treasure_map_placement(fleet_of_highperf_mocked_ursulas,
                                      highperf_mocked_alice,
                                      highperf_mocked_bob):

@@ -186,21 +186,21 @@ def test_mass_treasure_map_placement(fleet_of_highperf_mocked_ursulas,
     # defer.setDebugging(True)
 
     # PART II: We block for a little while to ensure that the distribution is going well.
-    nodes_that_have_the_map_when_we_unblock = policy.publishing_mutex.block_until_success_is_reasonably_likely()
+    nodes_that_have_the_map_when_we_unblock = policy.treasure_map_publisher.block_until_success_is_reasonably_likely()
     little_while_ended_at = datetime.now()
 
     # The number of nodes having the map is at least the minimum to have unblocked.
-    assert len(nodes_that_have_the_map_when_we_unblock) >= policy.publishing_mutex._block_until_this_many_are_complete
+    assert len(nodes_that_have_the_map_when_we_unblock) >= policy.treasure_map_publisher._block_until_this_many_are_complete
 
     # The number of nodes having the map is approximately the number you'd expect from full utilization of Alice's publication threadpool.
     # TODO: This line fails sometimes because the loop goes too fast.
-    # assert len(nodes_that_have_the_map_when_we_unblock) == pytest.approx(policy.publishing_mutex._block_until_this_many_are_complete, .2)
+    # assert len(nodes_that_have_the_map_when_we_unblock) == pytest.approx(policy.treasure_map_publisher._block_until_this_many_are_complete, .2)
 
     # PART III: Having made proper assertions about the publication call and the first block, we allow the rest to
     # happen in the background and then ensure that each phase was timely.
 
     # This will block until the distribution is complete.
-    policy.publishing_mutex.block_until_complete()
+    policy.treasure_map_publisher.block_until_complete()
     complete_distribution_time = datetime.now() - started
     partial_blocking_duration = little_while_ended_at - started
     # Before Treasure Island (1741), this process took about 3 minutes.

@@ -213,7 +213,7 @@ def test_mass_treasure_map_placement(fleet_of_highperf_mocked_ursulas,
     # But with debuggers and other processes running on laptops, we give a little leeway.
 
     # We have the same number of successful responses as nodes we expected to have the map.
-    assert len(policy.publishing_mutex.completed) == len(nodes_we_expect_to_have_the_map)
+    assert len(policy.treasure_map_publisher.completed) == len(nodes_we_expect_to_have_the_map)
     nodes_that_got_the_map = sum(
         u._its_down_there_somewhere_let_me_take_another_look is True for u in nodes_we_expect_to_have_the_map)
     assert nodes_that_got_the_map == len(nodes_we_expect_to_have_the_map)