mirror of https://github.com/nucypher/nucypher.git
Stop the threadpool when all the worker results are processed
parent
bdf8c6a7e1
commit
7b4ab2a412
|
@ -21,7 +21,6 @@ from threading import Thread, Event, Lock, Timer, get_ident
|
||||||
from typing import Callable, List, Any, Optional, Dict
|
from typing import Callable, List, Any, Optional, Dict
|
||||||
|
|
||||||
from constant_sorrow.constants import PRODUCER_STOPPED, TIMEOUT_TRIGGERED
|
from constant_sorrow.constants import PRODUCER_STOPPED, TIMEOUT_TRIGGERED
|
||||||
from twisted._threads import AlreadyQuit
|
|
||||||
from twisted.python.threadpool import ThreadPool
|
from twisted.python.threadpool import ThreadPool
|
||||||
|
|
||||||
|
|
||||||
|
@ -124,7 +123,8 @@ class WorkerPool:
|
||||||
self._target_value = SetOnce()
|
self._target_value = SetOnce()
|
||||||
self._unexpected_error = SetOnce()
|
self._unexpected_error = SetOnce()
|
||||||
self._results_lock = Lock()
|
self._results_lock = Lock()
|
||||||
self._stopped = False
|
self._threadpool_stop_lock = Lock()
|
||||||
|
self._threadpool_stopped = False
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
# TODO: check if already started?
|
# TODO: check if already started?
|
||||||
|
@ -139,29 +139,35 @@ class WorkerPool:
|
||||||
"""
|
"""
|
||||||
self._cancel_event.set()
|
self._cancel_event.set()
|
||||||
|
|
||||||
|
def _stop_threadpool(self):
|
||||||
|
# This can be called from multiple threads
|
||||||
|
# (`join()` itself can be called from multiple threads,
|
||||||
|
# and we also attempt to stop the pool from the `_process_results()` thread).
|
||||||
|
with self._threadpool_stop_lock:
|
||||||
|
if not self._threadpool_stopped:
|
||||||
|
self._threadpool.stop()
|
||||||
|
self._threadpool_stopped = True
|
||||||
|
|
||||||
|
def _check_for_unexpected_error(self):
|
||||||
|
if self._unexpected_error.is_set():
|
||||||
|
e = self._unexpected_error.get()
|
||||||
|
raise RuntimeError(f"Unexpected error in the producer thread: {e}")
|
||||||
|
|
||||||
def join(self):
|
def join(self):
|
||||||
"""
|
"""
|
||||||
Waits for all the threads to finish.
|
Waits for all the threads to finish.
|
||||||
Can be called several times.
|
Can be called several times.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self._stopped:
|
|
||||||
return # or raise AlreadyStopped?
|
|
||||||
|
|
||||||
self._produce_values_thread.join()
|
self._produce_values_thread.join()
|
||||||
self._process_results_thread.join()
|
self._process_results_thread.join()
|
||||||
self._bail_on_timeout_thread.join()
|
self._bail_on_timeout_thread.join()
|
||||||
|
|
||||||
# protect from a possible race
|
# In most cases `_threadpool` will be stopped by the `_process_results()` thread.
|
||||||
try:
|
# But in case there's some unexpected bug in its code, we're making sure the pool is stopped
|
||||||
self._threadpool.stop()
|
# to avoid the whole process hanging.
|
||||||
except AlreadyQuit:
|
self._stop_threadpool()
|
||||||
pass
|
|
||||||
self._stopped = True
|
|
||||||
|
|
||||||
if self._unexpected_error.is_set():
|
self._check_for_unexpected_error()
|
||||||
e = self._unexpected_error.get()
|
|
||||||
raise RuntimeError(f"Unexpected error in the producer thread: {e}")
|
|
||||||
|
|
||||||
def _sleep(self, timeout):
|
def _sleep(self, timeout):
|
||||||
"""
|
"""
|
||||||
|
@ -176,10 +182,7 @@ class WorkerPool:
|
||||||
Returns a dictionary of values matched to results.
|
Returns a dictionary of values matched to results.
|
||||||
Can be called several times.
|
Can be called several times.
|
||||||
"""
|
"""
|
||||||
if self._unexpected_error.is_set():
|
self._check_for_unexpected_error()
|
||||||
# So that we don't raise it again when join() is called
|
|
||||||
e = self._unexpected_error.get_and_clear()
|
|
||||||
raise RuntimeError(f"Unexpected error in the producer thread: {e}")
|
|
||||||
|
|
||||||
result = self._target_value.get()
|
result = self._target_value.get()
|
||||||
if result == TIMEOUT_TRIGGERED:
|
if result == TIMEOUT_TRIGGERED:
|
||||||
|
@ -258,6 +261,8 @@ class WorkerPool:
|
||||||
self._target_value.set(PRODUCER_STOPPED)
|
self._target_value.set(PRODUCER_STOPPED)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
self._stop_threadpool()
|
||||||
|
|
||||||
def _produce_values(self):
|
def _produce_values(self):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -300,7 +300,7 @@ class BuggyFactory:
|
||||||
raise Exception("Buggy factory")
|
raise Exception("Buggy factory")
|
||||||
|
|
||||||
|
|
||||||
def test_buggy_factory_raises_on_block(join_worker_pool):
|
def test_buggy_factory_raises_on_block():
|
||||||
"""
|
"""
|
||||||
Tests that if there is an exception thrown in the value factory,
|
Tests that if there is an exception thrown in the value factory,
|
||||||
it is caught in the first call to `block_until_target_successes()`.
|
it is caught in the first call to `block_until_target_successes()`.
|
||||||
|
@ -315,19 +315,21 @@ def test_buggy_factory_raises_on_block(join_worker_pool):
|
||||||
# Non-zero stagger timeout to make BuggyFactory raise its error only in 1.5s,
|
# Non-zero stagger timeout to make BuggyFactory raise its error only in 1.5s,
|
||||||
# So that we got enough successes for `block_until_target_successes()`.
|
# So that we got enough successes for `block_until_target_successes()`.
|
||||||
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=1.5)
|
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=1.5)
|
||||||
join_worker_pool(pool)
|
|
||||||
|
|
||||||
pool.start()
|
pool.start()
|
||||||
time.sleep(2) # wait for the stagger timeout to finish
|
time.sleep(2) # wait for the stagger timeout to finish
|
||||||
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
|
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
|
||||||
pool.block_until_target_successes()
|
pool.block_until_target_successes()
|
||||||
# Further calls to `block_until_target_successes()` or `join()` don't throw the error.
|
# Further calls to `block_until_target_successes()` or `join()` don't throw the error.
|
||||||
pool.block_until_target_successes()
|
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
|
||||||
|
pool.block_until_target_successes()
|
||||||
pool.cancel()
|
pool.cancel()
|
||||||
pool.join()
|
|
||||||
|
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
|
||||||
|
pool.join()
|
||||||
|
|
||||||
|
|
||||||
def test_buggy_factory_raises_on_join(join_worker_pool):
|
def test_buggy_factory_raises_on_join():
|
||||||
"""
|
"""
|
||||||
Tests that if there is an exception thrown in the value factory,
|
Tests that if there is an exception thrown in the value factory,
|
||||||
it is caught in the first call to `join()`.
|
it is caught in the first call to `join()`.
|
||||||
|
@ -339,10 +341,10 @@ def test_buggy_factory_raises_on_join(join_worker_pool):
|
||||||
|
|
||||||
factory = BuggyFactory(list(outcomes))
|
factory = BuggyFactory(list(outcomes))
|
||||||
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10)
|
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10)
|
||||||
join_worker_pool(pool)
|
|
||||||
|
|
||||||
pool.start()
|
pool.start()
|
||||||
pool.cancel()
|
pool.cancel()
|
||||||
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
|
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
|
||||||
pool.join()
|
pool.join()
|
||||||
pool.join()
|
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
|
||||||
|
pool.join()
|
||||||
|
|
Loading…
Reference in New Issue