Propagate the actual exception from the producer thread instead of serializing it into string

pull/2557/head
Bogdan Opanchuk 2021-02-09 16:01:58 -08:00
parent 359ad45a2f
commit 09c476e2da
2 changed files with 41 additions and 16 deletions

View File

@ -19,6 +19,7 @@ import time
from queue import Queue, Empty
from threading import Thread, Event, Lock, Timer, get_ident
from typing import Callable, List, Any, Optional, Dict
import sys
from constant_sorrow.constants import PRODUCER_STOPPED, TIMEOUT_TRIGGERED
from twisted.python.threadpool import ThreadPool
@ -40,6 +41,13 @@ class Cancelled(Exception):
pass
class FutureResult:
def __init__(self, value=None, exc_info=None):
self.value = value
self.exc_info = exc_info
class Future:
"""
A simplified future object. Can be set to some value (all further sets are ignored),
@ -51,18 +59,34 @@ class Future:
self._set_event = Event()
self._value = None
def set(self, value):
def _set(self, value):
with self._lock:
if not self._set_event.is_set():
self._value = value
self._set_event.set()
def set(self, value):
self._set(FutureResult(value=value))
def set_exception(self):
exc_info = sys.exc_info()
self._set(FutureResult(exc_info=exc_info))
def is_set(self):
return self._set_event.is_set()
def get(self):
self._set_event.wait()
return self._value
if self._value.exc_info is not None:
(exc_type, exc_value, exc_traceback) = self._value.exc_info
if exc_value is None:
exc_value = exc_type()
if exc_value.__traceback__ is not exc_traceback:
raise exc_value.with_traceback(exc_traceback)
raise exc_value
else:
return self._value.value
class WorkerPool:
@ -115,7 +139,7 @@ class WorkerPool:
self._cancel_event = Event()
self._result_queue = Queue()
self._target_value = Future()
self._unexpected_error = Future()
self._producer_error = Future()
self._results_lock = Lock()
self._threadpool_stop_lock = Lock()
self._threadpool_stopped = False
@ -142,10 +166,11 @@ class WorkerPool:
self._threadpool.stop()
self._threadpool_stopped = True
def _check_for_unexpected_error(self):
if self._unexpected_error.is_set():
e = self._unexpected_error.get()
raise RuntimeError(f"Unexpected error in the producer thread: {e}")
def _check_for_producer_error(self):
# Check for any unexpected exceptions in the producer thread
if self._producer_error.is_set():
# Will raise if Future was set with an exception
self._producer_error.get()
def join(self):
"""
@ -161,7 +186,7 @@ class WorkerPool:
# to avoid the whole process hanging.
self._stop_threadpool()
self._check_for_unexpected_error()
self._check_for_producer_error()
def _sleep(self, timeout):
"""
@ -176,7 +201,7 @@ class WorkerPool:
Returns a dictionary of values matched to results.
Can be called several times.
"""
self._check_for_unexpected_error()
self._check_for_producer_error()
result = self._target_value.get()
if result == TIMEOUT_TRIGGERED:
@ -278,8 +303,8 @@ class WorkerPool:
except Cancelled:
break
except BaseException as e:
self._unexpected_error.set(e)
except BaseException:
self._producer_error.set_exception()
self.cancel()
break

View File

@ -318,14 +318,14 @@ def test_buggy_factory_raises_on_block():
pool.start()
time.sleep(2) # wait for the stagger timeout to finish
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
with pytest.raises(Exception, match="Buggy factory"):
pool.block_until_target_successes()
# Further calls to `block_until_target_successes()` or `join()` don't throw the error.
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
with pytest.raises(Exception, match="Buggy factory"):
pool.block_until_target_successes()
pool.cancel()
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
with pytest.raises(Exception, match="Buggy factory"):
pool.join()
@ -344,7 +344,7 @@ def test_buggy_factory_raises_on_join():
pool.start()
pool.cancel()
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
with pytest.raises(Exception, match="Buggy factory"):
pool.join()
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
with pytest.raises(Exception, match="Buggy factory"):
pool.join()