Merge pull request #2772 from piotr-roslaniec/handle-missing-ursulas

Porter - return meaningful error if there are not enough Ursulas
pull/2786/head
Derek Pierre 2021-08-26 09:19:53 -04:00 committed by GitHub
commit bcd6071e3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 95 additions and 41 deletions

View File

@ -0,0 +1 @@
Update WorkerPool error messages returned by Porter API.

View File

@ -35,6 +35,7 @@ from nucypher.control.interfaces import ControlInterface
from nucypher.control.specifications.exceptions import SpecificationError
from nucypher.exceptions import DevelopmentInstallationRequired
from nucypher.network.resources import get_static_resources
from nucypher.utilities.concurrency import WorkerPool, WorkerPoolException
from nucypher.utilities.logging import Logger, GlobalLoggerSettings
@ -350,6 +351,32 @@ class WebController(InterfaceControlServer):
#
# Unhandled Server Errors
#
except WorkerPoolException as e:
# special case since WorkerPoolException contain stack traces - not ideal for returning from REST endpoints
__exception_code = 500
if self.crash_on_error:
raise
if isinstance(e, WorkerPool.TimedOut):
message_prefix = f"Execution timed out after {e.timeout}s"
else:
message_prefix = f"Execution failed - no more values to try"
# get random failure for context
if e.failures:
value = list(e.failures)[0]
_, exception, _ = e.failures[value]
msg = f"{message_prefix} ({len(e.failures)} concurrent failures recorded); " \
f"for example, for {value}: {exception}"
else:
msg = message_prefix
return self.emitter.exception(
e=RuntimeError(msg),
log_level='warn',
response_code=__exception_code,
error_message=WebController._captured_status_codes[__exception_code])
except Exception as e:
__exception_code = 500
if self.crash_on_error:

View File

@ -254,7 +254,8 @@ class WebEmitter:
log_level: str = 'info',
response_code: int = 500):
message = f"{self} [{str(response_code)} - {error_message}] | ERROR: {str(e) or type(e).__name__}"
exception = f"{type(e).__name__}: {str(e)}" if str(e) else type(e).__name__
message = f"{self} [{str(response_code)} - {error_message}] | ERROR: {exception}"
logger = getattr(self.log, log_level)
# See #724 / 2156
message_cleaned_for_logger = message.replace("{", "<^<").replace("}", ">^>")

View File

@ -14,7 +14,7 @@ GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
from typing import Iterable, Optional, List
from typing import Iterable, List, Optional
from eth_typing import ChecksumAddress

View File

@ -16,12 +16,11 @@ along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
import io
import time
import traceback
from queue import Queue, Empty
from threading import Thread, Event, Lock, Timer, get_ident
from typing import Callable, List, Any, Optional, Dict
import sys
import traceback
from queue import Queue
from threading import Thread, Event, Lock
from typing import Callable, List, Any, Optional, Dict
from constant_sorrow.constants import PRODUCER_STOPPED, TIMEOUT_TRIGGERED
from twisted.python.threadpool import ThreadPool
@ -91,29 +90,37 @@ class Future:
return self._value.value
def format_failures(failures: Dict) -> str:
"""
Performs some basic formatting of the WorkerPool failures,
providing some context of why TimedOut/OutOfValues occurred.
"""
class WorkerPoolException(Exception):
"""Generalized exception class for WorkerPool failures."""
def __init__(self, message_prefix: str, failures: Dict):
self.failures = failures
if failures:
# Using one random failure to print the traceback.
# Most probably they're all the same anyway.
value = list(failures)[0]
_exception_cls, exception, tb = failures[value]
# craft message
msg = message_prefix
if self.failures:
# Using one random failure
# Most probably they're all the same anyway.
value = list(self.failures)[0]
_, exception, tb = self.failures[value]
f = io.StringIO()
traceback.print_tb(tb, file=f)
traceback_str = f.getvalue()
msg = (f"{message_prefix} ({len(self.failures)} failures recorded); "
f"for example, for {value}:\n"
f"{traceback_str}\n"
f"{exception}")
super().__init__(msg)
f = io.StringIO()
traceback.print_tb(tb, file=f)
traceback_str = f.getvalue()
def get_tracebacks(self) -> Dict[Any, str]:
"""Returns values and associated tracebacks of execution failures."""
exc_tracebacks = {}
for value, exc_info in self.failures.items():
_, exception, tb = exc_info
f = io.StringIO()
traceback.print_tb(tb, file=f)
exc_tracebacks[value] = f"{f.getvalue()}\n{exception}"
return (f"{len(failures)} total failures recorded;\n"
f"for example, for the value {value}:\n"
f"{traceback_str}\n"
f"{exception}")
else:
return "0 total failures recorded"
return exc_tracebacks
class WorkerPool:
@ -124,11 +131,18 @@ class WorkerPool:
(a worker returning something without throwing an exception).
"""
class TimedOut(Exception):
"Raised if waiting for the target number of successes timed out."
class TimedOut(WorkerPoolException):
"""Raised if waiting for the target number of successes timed out."""
def __init__(self, timeout: float, *args, **kwargs):
self.timeout = timeout
super().__init__(message_prefix=f"Execution timed out after {timeout}s",
*args, **kwargs)
class OutOfValues(Exception):
"Raised if the value factory is out of values, but the target number was not reached."
class OutOfValues(WorkerPoolException):
"""Raised if the value factory is out of values, but the target number was not reached."""
def __init__(self, *args, **kwargs):
super().__init__(message_prefix="Execution stopped before completion - not enough available values",
*args, **kwargs)
def __init__(self,
worker: Callable[[Any], Any],
@ -232,9 +246,9 @@ class WorkerPool:
result = self._target_value.get()
if result == TIMEOUT_TRIGGERED:
raise self.TimedOut(format_failures(self.get_failures()))
raise self.TimedOut(timeout=self._timeout, failures=self.get_failures())
elif result == PRODUCER_STOPPED:
raise self.OutOfValues(format_failures(self.get_failures()))
raise self.OutOfValues(failures=self.get_failures())
return result
def get_failures(self) -> Dict:

View File

@ -14,9 +14,9 @@
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
from typing import List, Optional, Sequence, NamedTuple
from typing import List, NamedTuple, Optional, Sequence
from constant_sorrow.constants import NO_CONTROL_PROTOCOL, NO_BLOCKCHAIN_CONNECTION
from constant_sorrow.constants import NO_BLOCKCHAIN_CONNECTION, NO_CONTROL_PROTOCOL
from eth_typing import ChecksumAddress
from flask import request, Response
@ -24,7 +24,7 @@ from nucypher.blockchain.eth.agents import ContractAgency, StakingEscrowAgent
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
from nucypher.blockchain.eth.registry import BaseContractRegistry, InMemoryContractRegistry
from nucypher.characters.lawful import Ursula
from nucypher.control.controllers import WebController, JSONRPCController
from nucypher.control.controllers import JSONRPCController, WebController
from nucypher.crypto.powers import DecryptingPower
from nucypher.crypto.umbral_adapter import PublicKey
from nucypher.network.nodes import Learner

View File

@ -17,7 +17,7 @@
import random
import time
from typing import Iterable, Tuple, List, Callable
from typing import Iterable, Tuple
import pytest
@ -157,11 +157,21 @@ def test_wait_for_successes_out_of_values(join_worker_pool):
message = str(exc_info.value)
assert "Execution stopped before completion - not enough available values" in message
# We had 20 workers set up to fail
assert "20 total failures recorded" in message
num_expected_failures = 20
assert f"{num_expected_failures} failures recorded" in message
# check tracebacks
tracebacks = exc_info.value.get_tracebacks()
assert len(tracebacks) == num_expected_failures
for value, traceback in tracebacks.items():
assert 'raise Exception(f"Worker for {value} failed")' in traceback
assert f'Worker for {value} failed' in traceback
# This will be the last line in the displayed traceback;
# That's where the worker actually failed.
# That's where the worker actually failed. (Worker for {value} failed)
assert 'raise Exception(f"Worker for {value} failed")' in message
@ -179,7 +189,8 @@ def test_wait_for_successes_timed_out(join_worker_pool):
seed=123)
factory = AllAtOnceFactory(list(outcomes))
pool = WorkerPool(worker, factory, target_successes=10, timeout=1, threadpool_size=30)
timeout = 1
pool = WorkerPool(worker, factory, target_successes=10, timeout=timeout, threadpool_size=30)
join_worker_pool(pool)
t_start = time.monotonic()
@ -194,7 +205,7 @@ def test_wait_for_successes_timed_out(join_worker_pool):
message = str(exc_info.value)
# None of the workers actually failed, they just timed out
assert "0 total failures recorded" in message
assert f"Execution timed out after {timeout}s" == message
def test_join(join_worker_pool):