Stop the threadpool when all the worker results are processed

pull/2557/head
Bogdan Opanchuk 2021-02-07 14:19:23 -08:00
parent bdf8c6a7e1
commit 7b4ab2a412
2 changed files with 33 additions and 26 deletions

View File

@ -21,7 +21,6 @@ from threading import Thread, Event, Lock, Timer, get_ident
from typing import Callable, List, Any, Optional, Dict from typing import Callable, List, Any, Optional, Dict
from constant_sorrow.constants import PRODUCER_STOPPED, TIMEOUT_TRIGGERED from constant_sorrow.constants import PRODUCER_STOPPED, TIMEOUT_TRIGGERED
from twisted._threads import AlreadyQuit
from twisted.python.threadpool import ThreadPool from twisted.python.threadpool import ThreadPool
@ -124,7 +123,8 @@ class WorkerPool:
self._target_value = SetOnce() self._target_value = SetOnce()
self._unexpected_error = SetOnce() self._unexpected_error = SetOnce()
self._results_lock = Lock() self._results_lock = Lock()
self._stopped = False self._threadpool_stop_lock = Lock()
self._threadpool_stopped = False
def start(self): def start(self):
# TODO: check if already started? # TODO: check if already started?
@ -139,29 +139,35 @@ class WorkerPool:
""" """
self._cancel_event.set() self._cancel_event.set()
def _stop_threadpool(self):
# This can be called from multiple threads
# (`join()` itself can be called from multiple threads,
# and we also attempt to stop the pool from the `_process_results()` thread).
with self._threadpool_stop_lock:
if not self._threadpool_stopped:
self._threadpool.stop()
self._threadpool_stopped = True
def _check_for_unexpected_error(self):
if self._unexpected_error.is_set():
e = self._unexpected_error.get()
raise RuntimeError(f"Unexpected error in the producer thread: {e}")
def join(self): def join(self):
""" """
Waits for all the threads to finish. Waits for all the threads to finish.
Can be called several times. Can be called several times.
""" """
if self._stopped:
return # or raise AlreadyStopped?
self._produce_values_thread.join() self._produce_values_thread.join()
self._process_results_thread.join() self._process_results_thread.join()
self._bail_on_timeout_thread.join() self._bail_on_timeout_thread.join()
# protect from a possible race # In most cases `_threadpool` will be stopped by the `_process_results()` thread.
try: # But in case there's some unexpected bug in its code, we're making sure the pool is stopped
self._threadpool.stop() # to avoid the whole process hanging.
except AlreadyQuit: self._stop_threadpool()
pass
self._stopped = True
if self._unexpected_error.is_set(): self._check_for_unexpected_error()
e = self._unexpected_error.get()
raise RuntimeError(f"Unexpected error in the producer thread: {e}")
def _sleep(self, timeout): def _sleep(self, timeout):
""" """
@ -176,10 +182,7 @@ class WorkerPool:
Returns a dictionary of values matched to results. Returns a dictionary of values matched to results.
Can be called several times. Can be called several times.
""" """
if self._unexpected_error.is_set(): self._check_for_unexpected_error()
# So that we don't raise it again when join() is called
e = self._unexpected_error.get_and_clear()
raise RuntimeError(f"Unexpected error in the producer thread: {e}")
result = self._target_value.get() result = self._target_value.get()
if result == TIMEOUT_TRIGGERED: if result == TIMEOUT_TRIGGERED:
@ -258,6 +261,8 @@ class WorkerPool:
self._target_value.set(PRODUCER_STOPPED) self._target_value.set(PRODUCER_STOPPED)
break break
self._stop_threadpool()
def _produce_values(self): def _produce_values(self):
while True: while True:
try: try:

View File

@ -300,7 +300,7 @@ class BuggyFactory:
raise Exception("Buggy factory") raise Exception("Buggy factory")
def test_buggy_factory_raises_on_block(join_worker_pool): def test_buggy_factory_raises_on_block():
""" """
Tests that if there is an exception thrown in the value factory, Tests that if there is an exception thrown in the value factory,
it is caught in the first call to `block_until_target_successes()`. it is caught in the first call to `block_until_target_successes()`.
@ -315,19 +315,21 @@ def test_buggy_factory_raises_on_block(join_worker_pool):
# Non-zero stagger timeout to make BuggyFactory raise its error only in 1.5s, # Non-zero stagger timeout to make BuggyFactory raise its error only in 1.5s,
# So that we got enough successes for `block_until_target_successes()`. # So that we got enough successes for `block_until_target_successes()`.
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=1.5) pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=1.5)
join_worker_pool(pool)
pool.start() pool.start()
time.sleep(2) # wait for the stagger timeout to finish time.sleep(2) # wait for the stagger timeout to finish
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"): with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
pool.block_until_target_successes() pool.block_until_target_successes()
# Further calls to `block_until_target_successes()` or `join()` don't throw the error. # Further calls to `block_until_target_successes()` or `join()` raise the same error again.
pool.block_until_target_successes() with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
pool.block_until_target_successes()
pool.cancel() pool.cancel()
pool.join()
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
pool.join()
def test_buggy_factory_raises_on_join(join_worker_pool): def test_buggy_factory_raises_on_join():
""" """
Tests that if there is an exception thrown in the value factory, Tests that if there is an exception thrown in the value factory,
it is caught in the first call to `join()`. it is caught in the first call to `join()`.
@ -339,10 +341,10 @@ def test_buggy_factory_raises_on_join(join_worker_pool):
factory = BuggyFactory(list(outcomes)) factory = BuggyFactory(list(outcomes))
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10) pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10)
join_worker_pool(pool)
pool.start() pool.start()
pool.cancel() pool.cancel()
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"): with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
pool.join() pool.join()
pool.join() with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
pool.join()