Merge pull request #2557 from fjarri/self-destruct-worker-pool

Stop the threadpool when all the worker results are processed
pull/2569/head
KPrasch 2021-02-15 11:47:47 -08:00 committed by GitHub
commit 7c6094f951
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 76 additions and 49 deletions

View File

@ -0,0 +1 @@
Prevent the process from hanging in cases where the main thread finishes before the treasure map publisher does

View File

@ -326,7 +326,7 @@ class Policy(ABC):
if len(accepted_arrangements) < self.n:
rejected_proposals = "\n".join(f"{address}: {exception}" for address, exception in failures.items())
rejected_proposals = "\n".join(f"{address}: {value}" for address, (type_, value, traceback) in failures.items())
self.log.debug(
"Could not find enough Ursulas to accept proposals.\n"

View File

@ -19,9 +19,9 @@ import time
from queue import Queue, Empty
from threading import Thread, Event, Lock, Timer, get_ident
from typing import Callable, List, Any, Optional, Dict
import sys
from constant_sorrow.constants import PRODUCER_STOPPED, TIMEOUT_TRIGGERED
from twisted._threads import AlreadyQuit
from twisted.python.threadpool import ThreadPool
@ -30,20 +30,28 @@ class Success:
self.value = value
self.result = result
class Failure:
def __init__(self, value, exception):
def __init__(self, value, exc_info):
self.value = value
self.exception = exception
self.exc_info = exc_info
class Cancelled(Exception):
    """Raised to (and by) workers to signal that the pool's work was cancelled before completion."""
    pass
class SetOnce:
class FutureResult:
    """
    The payload of a ``Future``: either a regular return value,
    or ``sys.exc_info()``-style exception information (but normally not both).
    """

    def __init__(self, value=None, exc_info=None):
        # Plain data container; both fields default to None.
        self.value, self.exc_info = value, exc_info
class Future:
"""
A convenience wrapper for a value that can be set once (which can be waited on),
and cannot be overwritten (unless cleared).
A simplified future object. Can be set to some value (all further sets are ignored),
can be waited on.
"""
def __init__(self):
@ -51,25 +59,34 @@ class SetOnce:
self._set_event = Event()
self._value = None
def set(self, value):
def _set(self, value):
with self._lock:
if not self._set_event.is_set():
self._value = value
self._set_event.set()
def set(self, value):
self._set(FutureResult(value=value))
def set_exception(self):
exc_info = sys.exc_info()
self._set(FutureResult(exc_info=exc_info))
def is_set(self):
return self._set_event.is_set()
def get_and_clear(self):
with self._lock:
value = self._value
self._value = None
self._set_event.clear()
return value
def get(self):
self._set_event.wait()
return self._value
if self._value.exc_info is not None:
(exc_type, exc_value, exc_traceback) = self._value.exc_info
if exc_value is None:
exc_value = exc_type()
if exc_value.__traceback__ is not exc_traceback:
raise exc_value.with_traceback(exc_traceback)
raise exc_value
else:
return self._value.value
class WorkerPool:
@ -121,10 +138,11 @@ class WorkerPool:
self._cancel_event = Event()
self._result_queue = Queue()
self._target_value = SetOnce()
self._unexpected_error = SetOnce()
self._target_value = Future()
self._producer_error = Future()
self._results_lock = Lock()
self._stopped = False
self._threadpool_stop_lock = Lock()
self._threadpool_stopped = False
def start(self):
# TODO: check if already started?
@ -139,29 +157,36 @@ class WorkerPool:
"""
self._cancel_event.set()
def _stop_threadpool(self):
    """
    Stop the underlying Twisted threadpool exactly once.

    Idempotent: safe to call repeatedly and concurrently.
    """
    # This can be called from multiple threads
    # (`join()` itself can be called from multiple threads,
    # and we also attempt to stop the pool from the `_process_results()` thread).
    with self._threadpool_stop_lock:
        if not self._threadpool_stopped:
            self._threadpool.stop()
            # Flag guards against a second `stop()` call
            # (ThreadPool.stop() is not safe to invoke twice).
            self._threadpool_stopped = True
def _check_for_producer_error(self):
    """
    Re-raise any exception recorded by the producer thread, if one was set.

    No-op when the producer finished without error.
    """
    # Check for any unexpected exceptions in the producer thread
    if self._producer_error.is_set():
        # Will raise if Future was set with an exception
        # (Future.get() re-raises stored exc_info instead of returning it).
        self._producer_error.get()
def join(self):
"""
Waits for all the threads to finish.
Can be called several times.
"""
if self._stopped:
return # or raise AlreadyStopped?
self._produce_values_thread.join()
self._process_results_thread.join()
self._bail_on_timeout_thread.join()
# protect from a possible race
try:
self._threadpool.stop()
except AlreadyQuit:
pass
self._stopped = True
# In most cases `_threadpool` will be stopped by the `_process_results()` thread.
# But in case there's some unexpected bug in its code, we're making sure the pool is stopped
# to avoid the whole process hanging.
self._stop_threadpool()
if self._unexpected_error.is_set():
e = self._unexpected_error.get()
raise RuntimeError(f"Unexpected error in the producer thread: {e}")
self._check_for_producer_error()
def _sleep(self, timeout):
"""
@ -176,10 +201,7 @@ class WorkerPool:
Returns a dictionary of values matched to results.
Can be called several times.
"""
if self._unexpected_error.is_set():
# So that we don't raise it again when join() is called
e = self._unexpected_error.get_and_clear()
raise RuntimeError(f"Unexpected error in the producer thread: {e}")
self._check_for_producer_error()
result = self._target_value.get()
if result == TIMEOUT_TRIGGERED:
@ -224,7 +246,7 @@ class WorkerPool:
except Cancelled as e:
self._result_queue.put(e)
except BaseException as e:
self._result_queue.put(Failure(value, str(e)))
self._result_queue.put(Failure(value, sys.exc_info()))
def _process_results(self):
"""
@ -251,13 +273,15 @@ class WorkerPool:
self._target_value.set(self.get_successes())
if isinstance(result, Failure):
with self._results_lock:
self._failures[result.value] = result.exception
self._failures[result.value] = result.exc_info
if producer_stopped and self._finished_tasks == self._started_tasks:
self.cancel() # to cancel the timeout thread
self._target_value.set(PRODUCER_STOPPED)
break
self._stop_threadpool()
def _produce_values(self):
while True:
try:
@ -279,8 +303,8 @@ class WorkerPool:
except Cancelled:
break
except BaseException as e:
self._unexpected_error.set(e)
except BaseException:
self._producer_error.set_exception()
self.cancel()
break

View File

@ -300,7 +300,7 @@ class BuggyFactory:
raise Exception("Buggy factory")
def test_buggy_factory_raises_on_block(join_worker_pool):
def test_buggy_factory_raises_on_block():
"""
Tests that if there is an exception thrown in the value factory,
it is caught in the first call to `block_until_target_successes()`.
@ -315,19 +315,21 @@ def test_buggy_factory_raises_on_block(join_worker_pool):
# Non-zero stagger timeout to make BuggyFactory raise its error only in 1.5s,
# So that we got enough successes for `block_until_target_successes()`.
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=1.5)
join_worker_pool(pool)
pool.start()
time.sleep(2) # wait for the stagger timeout to finish
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
with pytest.raises(Exception, match="Buggy factory"):
pool.block_until_target_successes()
# Further calls to `block_until_target_successes()` or `join()` don't throw the error.
pool.block_until_target_successes()
with pytest.raises(Exception, match="Buggy factory"):
pool.block_until_target_successes()
pool.cancel()
pool.join()
with pytest.raises(Exception, match="Buggy factory"):
pool.join()
def test_buggy_factory_raises_on_join(join_worker_pool):
def test_buggy_factory_raises_on_join():
"""
Tests that if there is an exception thrown in the value factory,
it is caught in the first call to `join()`.
@ -339,10 +341,10 @@ def test_buggy_factory_raises_on_join(join_worker_pool):
factory = BuggyFactory(list(outcomes))
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10)
join_worker_pool(pool)
pool.start()
pool.cancel()
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
with pytest.raises(Exception, match="Buggy factory"):
pool.join()
with pytest.raises(Exception, match="Buggy factory"):
pool.join()
pool.join()