mirror of https://github.com/nucypher/nucypher.git
Merge pull request #2744 from fjarri/worker-pool-errors
More informative `WorkerPool` errorspull/2748/head
commit
3a299f6893
|
@ -0,0 +1 @@
|
|||
Added a more informative error message for ``WorkerPool`` exceptions.
|
|
@ -15,8 +15,9 @@ You should have received a copy of the GNU Affero General Public License
|
|||
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
|
||||
import io
|
||||
import time
|
||||
import traceback
|
||||
from queue import Queue, Empty
|
||||
from threading import Thread, Event, Lock, Timer, get_ident
|
||||
from typing import Callable, List, Any, Optional, Dict
|
||||
|
@ -90,6 +91,31 @@ class Future:
|
|||
return self._value.value
|
||||
|
||||
|
||||
def format_failures(failures: Dict) -> str:
|
||||
"""
|
||||
Performs some basic formatting of the WorkerPool failures,
|
||||
providing some context of why TimedOut/OutOfValues occurred.
|
||||
"""
|
||||
|
||||
if failures:
|
||||
# Using one random failure to print the traceback.
|
||||
# Most probably they're all the same anyway.
|
||||
value = list(failures)[0]
|
||||
_exception_cls, exception, tb = failures[value]
|
||||
|
||||
f = io.StringIO()
|
||||
traceback.print_tb(tb, file=f)
|
||||
traceback_str = f.getvalue()
|
||||
|
||||
return (f"{len(failures)} total failures recorded;\n"
|
||||
f"for example, for the value {value}:\n"
|
||||
f"{traceback_str}\n"
|
||||
f"{exception}")
|
||||
|
||||
else:
|
||||
return "0 total failures recorded"
|
||||
|
||||
|
||||
class WorkerPool:
|
||||
"""
|
||||
A generalized class that can start multiple workers in a thread pool with values
|
||||
|
@ -206,9 +232,9 @@ class WorkerPool:
|
|||
|
||||
result = self._target_value.get()
|
||||
if result == TIMEOUT_TRIGGERED:
|
||||
raise self.TimedOut()
|
||||
raise self.TimedOut(format_failures(self.get_failures()))
|
||||
elif result == PRODUCER_STOPPED:
|
||||
raise self.OutOfValues()
|
||||
raise self.OutOfValues(format_failures(self.get_failures()))
|
||||
return result
|
||||
|
||||
def get_failures(self) -> Dict:
|
||||
|
|
|
@ -131,13 +131,22 @@ def test_wait_for_successes_out_of_values(join_worker_pool):
|
|||
|
||||
t_start = time.monotonic()
|
||||
pool.start()
|
||||
with pytest.raises(WorkerPool.OutOfValues):
|
||||
with pytest.raises(WorkerPool.OutOfValues) as exc_info:
|
||||
successes = pool.block_until_target_successes()
|
||||
t_end = time.monotonic()
|
||||
|
||||
# We have roughly 2 workers per thread, so it shouldn't take longer than 1.5s (max timeout) * 2
|
||||
assert t_end - t_start < 4
|
||||
|
||||
message = str(exc_info.value)
|
||||
|
||||
# We had 20 workers set up to fail
|
||||
assert "20 total failures recorded" in message
|
||||
|
||||
# This will be the last line in the displayed traceback;
|
||||
# That's where the worker actually failed.
|
||||
assert 'raise Exception(f"Worker for {value} failed")' in message
|
||||
|
||||
|
||||
def test_wait_for_successes_timed_out(join_worker_pool):
|
||||
"""
|
||||
|
@ -158,13 +167,18 @@ def test_wait_for_successes_timed_out(join_worker_pool):
|
|||
|
||||
t_start = time.monotonic()
|
||||
pool.start()
|
||||
with pytest.raises(WorkerPool.TimedOut):
|
||||
with pytest.raises(WorkerPool.TimedOut) as exc_info:
|
||||
successes = pool.block_until_target_successes()
|
||||
t_end = time.monotonic()
|
||||
|
||||
# Even though timeout is 1, there are long-running workers which we can't interupt.
|
||||
assert t_end - t_start < 3
|
||||
|
||||
message = str(exc_info.value)
|
||||
|
||||
# None of the workers actually failed, they just timed out
|
||||
assert "0 total failures recorded" in message
|
||||
|
||||
|
||||
def test_join(join_worker_pool):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue