Stop the threadpool when all the worker results are processed

pull/2557/head
Bogdan Opanchuk 2021-02-07 14:19:23 -08:00
parent bdf8c6a7e1
commit 7b4ab2a412
2 changed files with 33 additions and 26 deletions

View File

@ -21,7 +21,6 @@ from threading import Thread, Event, Lock, Timer, get_ident
from typing import Callable, List, Any, Optional, Dict
from constant_sorrow.constants import PRODUCER_STOPPED, TIMEOUT_TRIGGERED
from twisted._threads import AlreadyQuit
from twisted.python.threadpool import ThreadPool
@ -124,7 +123,8 @@ class WorkerPool:
self._target_value = SetOnce()
self._unexpected_error = SetOnce()
self._results_lock = Lock()
self._stopped = False
self._threadpool_stop_lock = Lock()
self._threadpool_stopped = False
def start(self):
# TODO: check if already started?
@ -139,29 +139,35 @@ class WorkerPool:
"""
self._cancel_event.set()
def _stop_threadpool(self):
# This can be called from multiple threads
# (`join()` itself can be called from multiple threads,
# and we also attempt to stop the pool from the `_process_results()` thread).
with self._threadpool_stop_lock:
if not self._threadpool_stopped:
self._threadpool.stop()
self._threadpool_stopped = True
def _check_for_unexpected_error(self):
if self._unexpected_error.is_set():
e = self._unexpected_error.get()
raise RuntimeError(f"Unexpected error in the producer thread: {e}")
def join(self):
"""
Waits for all the threads to finish.
Can be called several times.
"""
if self._stopped:
return # or raise AlreadyStopped?
self._produce_values_thread.join()
self._process_results_thread.join()
self._bail_on_timeout_thread.join()
# protect from a possible race
try:
self._threadpool.stop()
except AlreadyQuit:
pass
self._stopped = True
# In most cases `_threadpool` will be stopped by the `_process_results()` thread.
# But in case there's some unexpected bug in its code, we're making sure the pool is stopped
# to avoid the whole process hanging.
self._stop_threadpool()
if self._unexpected_error.is_set():
e = self._unexpected_error.get()
raise RuntimeError(f"Unexpected error in the producer thread: {e}")
self._check_for_unexpected_error()
def _sleep(self, timeout):
"""
@ -176,10 +182,7 @@ class WorkerPool:
Returns a dictionary of values matched to results.
Can be called several times.
"""
if self._unexpected_error.is_set():
# So that we don't raise it again when join() is called
e = self._unexpected_error.get_and_clear()
raise RuntimeError(f"Unexpected error in the producer thread: {e}")
self._check_for_unexpected_error()
result = self._target_value.get()
if result == TIMEOUT_TRIGGERED:
@ -258,6 +261,8 @@ class WorkerPool:
self._target_value.set(PRODUCER_STOPPED)
break
self._stop_threadpool()
def _produce_values(self):
while True:
try:

View File

@ -300,7 +300,7 @@ class BuggyFactory:
raise Exception("Buggy factory")
def test_buggy_factory_raises_on_block(join_worker_pool):
def test_buggy_factory_raises_on_block():
"""
Tests that if there is an exception thrown in the value factory,
it is caught in the first call to `block_until_target_successes()`.
@ -315,19 +315,21 @@ def test_buggy_factory_raises_on_block(join_worker_pool):
# Non-zero stagger timeout to make BuggyFactory raise its error only in 1.5s,
# So that we got enough successes for `block_until_target_successes()`.
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=1.5)
join_worker_pool(pool)
pool.start()
time.sleep(2) # wait for the stagger timeout to finish
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
pool.block_until_target_successes()
# Further calls to `block_until_target_successes()` or `join()` don't throw the error.
pool.block_until_target_successes()
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
pool.block_until_target_successes()
pool.cancel()
pool.join()
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
pool.join()
def test_buggy_factory_raises_on_join(join_worker_pool):
def test_buggy_factory_raises_on_join():
"""
Tests that if there is an exception thrown in the value factory,
it is caught in the first call to `join()`.
@ -339,10 +341,10 @@ def test_buggy_factory_raises_on_join(join_worker_pool):
factory = BuggyFactory(list(outcomes))
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10)
join_worker_pool(pool)
pool.start()
pool.cancel()
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
pool.join()
pool.join()
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
pool.join()