Merge pull request #2908 from derekpierre/optimize-jit-certs

Optimizations for SSL, Porter etc.
pull/2927/head
Derek Pierre 2022-04-22 13:41:23 -04:00 committed by GitHub
commit ea7d8b2b4e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 214 additions and 83 deletions

View File

@ -514,6 +514,7 @@ Some common returned status codes you may encounter are:
- ``400 BAD REQUEST`` -- The server cannot or will not process the request due to something that is perceived to
be a client error (e.g., malformed request syntax, invalid request message framing, or deceptive request routing).
- ``401 UNAUTHORIZED`` -- Authentication is required and the request has failed to provide valid authentication credentials.
- ``404 NOT FOUND`` -- Request could not be completed because requested resources could not be found.
- ``500 INTERNAL SERVER ERROR`` -- The server encountered an unexpected condition that prevented it from
fulfilling the request.

View File

@ -0,0 +1,5 @@
SSL Certificate fetching and Porter optimizations
- Middleware should try the cached SSL certificate for a node first and, if the request fails, fetch the node's up-to-date SSL certificate
- Short-circuit WorkerPool background execution once sufficient successful executions occur
- Don't limit WorkerPool size; this has consequences when smaller samples of ursulas are performed; allow threadpool to be flexible by using default min/max
- Return more comprehensive error information for failed WorkerPool execution

View File

@ -29,6 +29,7 @@ from hendrix.deploy.base import HendrixDeploy
from hendrix.deploy.tls import HendrixDeployTLS
from twisted.internet import reactor, stdio
from nucypher.cli.processes import JSONRPCLineReceiver
from nucypher.config.constants import MAX_UPLOAD_CONTENT_LENGTH
from nucypher.control.emitters import StdoutEmitter, JSONRPCStdoutEmitter, WebEmitter
from nucypher.control.interfaces import ControlInterface
@ -250,6 +251,7 @@ class WebController(InterfaceControlServer):
_captured_status_codes = {200: 'OK',
400: 'BAD REQUEST',
404: 'NOT FOUND',
500: 'INTERNAL SERVER ERROR'}
def test_client(self):
@ -302,8 +304,23 @@ class WebController(InterfaceControlServer):
def __call__(self, *args, **kwargs):
return self.handle_request(*args, **kwargs)
def handle_request(self, method_name, control_request, *args, **kwargs) -> Response:
@staticmethod
def json_response_from_worker_pool_exception(exception):
    """
    Build a JSON-serializable summary of a failed worker pool execution.

    :param exception: the raised exception; its string form becomes the
        ``failure_message`` and its ``failures`` mapping (value -> exc_info
        triple) is expanded into one entry per failed value.
    :return: dict with ``failure_message`` and, only when failures were
        recorded, a ``failures`` list of ``{'value': ..., 'error': ...}`` entries.
    """
    json_response = {
        'failure_message': str(exception)
    }
    if exception.failures:
        # Only the exception instance (index 1 of each exc_info triple) is reported;
        # the traceback element is not included in the response.
        # An exception with an empty message falls back to its class name so that
        # no entry carries a blank error.
        json_response['failures'] = [
            {
                'value': value,
                'error': str(exc_info[1]) or type(exc_info[1]).__name__
            }
            for value, exc_info in exception.failures.items()
        ]
    return json_response
def handle_request(self, method_name, control_request, *args, **kwargs) -> Response:
_400_exceptions = (SpecificationError,
TypeError,
JSONDecodeError,
@ -336,47 +353,26 @@ class WebController(InterfaceControlServer):
error_message=WebController._captured_status_codes[__exception_code])
#
# Server Errors
# Execution Errors
#
except SpecificationError as e:
__exception_code = 500
except WorkerPoolException as e:
# special case since WorkerPoolException contains multiple stack traces
# - not ideal for returning from REST endpoints
__exception_code = 404
if self.crash_on_error:
raise
return self.emitter.exception(
e=e,
log_level='critical',
json_response_from_exception = self.json_response_from_worker_pool_exception(e)
return self.emitter.exception_with_response(
json_error_response=json_response_from_exception,
e=RuntimeError(json_response_from_exception['failure_message']),
error_message=WebController._captured_status_codes[__exception_code],
response_code=__exception_code,
error_message=WebController._captured_status_codes[__exception_code])
log_level='warn')
#
# Unhandled Server Errors
#
except WorkerPoolException as e:
# special case since WorkerPoolException contain stack traces - not ideal for returning from REST endpoints
__exception_code = 500
if self.crash_on_error:
raise
if isinstance(e, WorkerPool.TimedOut):
message_prefix = f"Execution timed out after {e.timeout}s"
else:
message_prefix = f"Execution failed - no more values to try"
# get random failure for context
if e.failures:
value = list(e.failures)[0]
_, exception, _ = e.failures[value]
msg = f"{message_prefix} ({len(e.failures)} concurrent failures recorded); " \
f"for example, for {value}: {exception}"
else:
msg = message_prefix
return self.emitter.exception(
e=RuntimeError(msg),
log_level='warn',
response_code=__exception_code,
error_message=WebController._captured_status_codes[__exception_code])
except Exception as e:
__exception_code = 500
if self.crash_on_error:
@ -392,4 +388,4 @@ class WebController(InterfaceControlServer):
#
else:
self.log.debug(f"{method_name} [200 - OK]")
return self.emitter.respond(response=response)
return self.emitter.respond(json_response=response)

View File

@ -243,6 +243,13 @@ class WebEmitter:
self.log = Logger('web-emitter')
def _log_exception(self, e, error_message, log_level, response_code):
    """Log a one-line summary of exception ``e`` using the logger method named by ``log_level``."""
    error_name = type(e).__name__
    error_text = str(e)
    # Append the exception text only when there is any; otherwise the class name stands alone.
    exception = f"{error_name}: {error_text}" if error_text else error_name
    message = f"{self} [{str(response_code)} - {error_message}] | ERROR: {exception}"
    log_method = getattr(self.log, log_level)
    # The message is passed through Logger.escape_format_string before being logged.
    log_method(Logger.escape_format_string(message))
@staticmethod
def assemble_response(response: dict) -> dict:
response_data = {'result': response,
@ -255,25 +262,35 @@ class WebEmitter:
log_level: str = 'info',
response_code: int = 500):
exception = f"{type(e).__name__}: {str(e)}" if str(e) else type(e).__name__
message = f"{self} [{str(response_code)} - {error_message}] | ERROR: {exception}"
logger = getattr(self.log, log_level)
# See #724 / 2156
message_cleaned_for_logger = message.replace("{", "<^<").replace("}", ">^>")
logger(message_cleaned_for_logger)
self._log_exception(e, error_message, log_level, response_code)
if self.crash_on_error:
raise e
response_message = str(e) or type(e).__name__
return self.sink(response_message, status=response_code)
def respond(self, response) -> Response:
assembled_response = self.assemble_response(response=response)
def exception_with_response(self,
json_error_response,
e,
error_message: str,
response_code: int,
log_level: str = 'info'):
self._log_exception(e, error_message, log_level, response_code)
if self.crash_on_error:
raise e
assembled_response = self.assemble_response(response=json_error_response)
serialized_response = WebEmitter.transport_serializer(assembled_response)
# ---------- HTTP OUTPUT
response = self.sink(response=serialized_response, status=HTTPStatus.OK, content_type="application/javascript")
return response
json_response = self.sink(response=serialized_response, status=response_code, content_type="application/json")
return json_response
def respond(self, json_response) -> Response:
    """Assemble ``json_response``, serialize it, and send it through the sink with a 200 OK JSON content type."""
    envelope = self.assemble_response(response=json_response)
    payload = WebEmitter.transport_serializer(envelope)
    return self.sink(response=payload, status=HTTPStatus.OK, content_type="application/json")
def get_stream(self, *args, **kwargs):
return null_stream()

View File

@ -30,6 +30,7 @@ from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.x509 import Certificate
from nucypher_core import MetadataRequest, FleetStateChecksum, NodeMetadata
from requests.exceptions import SSLError
from nucypher.blockchain.eth.registry import BaseContractRegistry
from nucypher.config.storages import ForgetfulNodeStorage
@ -148,15 +149,13 @@ class NucypherMiddlewareClient:
endpoint = f"https://{host}:{port}/{path}"
method = getattr(http_client, method_name)
# Fetch SSL certificate
try:
certificate, filepath = self.get_certificate(host=host, port=port)
except NodeSeemsToBeDown as e:
raise RestMiddleware.Unreachable(message=f'Node {node_or_sprout} {host}:{port} is unreachable: {e}')
# Send request
response = self.invoke_method(method, endpoint, verify=filepath, *args, **kwargs)
response = self._execute_method(node_or_sprout,
host,
port,
method,
endpoint,
*args,
**kwargs)
# Handle response
cleaned_response = self.response_cleaner(response)
if cleaned_response.status_code >= 300:
@ -183,6 +182,40 @@ class NucypherMiddlewareClient:
return method_wrapper
def _execute_method(self,
                    node_or_sprout,
                    host,
                    port,
                    method,
                    endpoint,
                    *args, **kwargs):
    """
    Send a request to a node, preferring its cached SSL certificate.

    If a certificate file for ``host:port`` already exists at the path given by
    storage, the request is first attempted with it. When that attempt raises
    ``SSLError`` (or no cached file exists), a certificate is obtained through
    ``get_certificate`` and the request is sent with that one.

    :param node_or_sprout: node being contacted; only used in the error message.
    :param host: node host, used to locate or fetch the certificate.
    :param port: node port, used to locate or fetch the certificate.
    :param method: client method handed to ``invoke_method``.
    :param endpoint: URL handed to ``invoke_method``.
    :return: whatever ``invoke_method`` returns.
    :raises RestMiddleware.Unreachable: if ``get_certificate`` raises ``NodeSeemsToBeDown``.
    """
    # Use existing cached SSL certificate or fetch fresh copy and retry
    cached_cert_filepath = Path(self.storage.generate_certificate_filepath(host=host, port=port))
    if cached_cert_filepath.exists():
        # already cached - try it
        try:
            # successful use of cached certificate returns immediately
            return self.invoke_method(method, endpoint, verify=cached_cert_filepath,
                                      *args, **kwargs)
        except SSLError as e:
            # ignore this exception - it probably means that our cached cert is not up-to-date;
            # fall through to fetching a fresh one.
            SSL_LOGGER.debug(f"Cached cert for {host}:{port} is invalid {e}")

    # Fetch fresh copy of SSL certificate
    try:
        _certificate, filepath = self.get_certificate(host=host, port=port)
    except NodeSeemsToBeDown as e:
        # chain explicitly so the original cause is preserved as __cause__
        raise RestMiddleware.Unreachable(
            message=f'Node {node_or_sprout} {host}:{port} is unreachable: {e}') from e

    # Send request
    return self.invoke_method(method, endpoint, verify=filepath,
                              *args, **kwargs)
def node_selector(self, node):
return node.rest_url(), self.library

View File

@ -138,8 +138,7 @@ class Policy(ABC):
value_factory=value_factory,
target_successes=self.shares,
timeout=timeout,
stagger_timeout=1,
threadpool_size=self.shares
stagger_timeout=1
)
worker_pool.start()
try:

View File

@ -98,17 +98,7 @@ class WorkerPoolException(Exception):
# craft message
msg = message_prefix
if self.failures:
# Using one random failure
# Most probably they're all the same anyway.
value = list(self.failures)[0]
_, exception, tb = self.failures[value]
f = io.StringIO()
traceback.print_tb(tb, file=f)
traceback_str = f.getvalue()
msg = (f"{message_prefix} ({len(self.failures)} failures recorded); "
f"for example, for {value}:\n"
f"{traceback_str}\n"
f"{exception}")
msg = f"{message_prefix} ({len(self.failures)} failures recorded)"
super().__init__(msg)
def get_tracebacks(self) -> Dict[Any, str]:
@ -316,8 +306,13 @@ class WorkerPool:
with self._results_lock:
self._failures[result.value] = result.exc_info
if success_event_reached:
# no need to continue processing results
self.cancel() # to cancel the timeout thread
break
if producer_stopped and self._finished_tasks == self._started_tasks:
self.cancel() # to cancel the timeout thread
self.cancel() # to cancel the timeout thread
self._target_value.set(PRODUCER_STOPPED)
break

View File

@ -66,7 +66,7 @@ the Pipe for PRE Application network operations
_LONG_LEARNING_DELAY = 30
_ROUNDS_WITHOUT_NODES_AFTER_WHICH_TO_SLOW_DOWN = 25
DEFAULT_EXECUTION_TIMEOUT = 10 # 10s
DEFAULT_EXECUTION_TIMEOUT = 15 # 15s
DEFAULT_PORT = 9155
@ -85,6 +85,7 @@ the Pipe for PRE Application network operations
federated_only: bool = False,
node_class: object = Ursula,
eth_provider_uri: str = None,
execution_timeout: int = DEFAULT_EXECUTION_TIMEOUT,
*args, **kwargs):
self.federated_only = federated_only
@ -104,6 +105,7 @@ the Pipe for PRE Application network operations
super().__init__(save_metadata=True, domain=domain, node_class=node_class, *args, **kwargs)
self.log = Logger(self.__class__.__name__)
self.execution_timeout = execution_timeout
# Controller Interface
self.interface = self._interface_class(porter=self)
@ -137,18 +139,22 @@ the Pipe for PRE Application network operations
raise
self.block_until_number_of_known_nodes_is(quantity,
timeout=self.DEFAULT_EXECUTION_TIMEOUT,
timeout=self.execution_timeout,
learn_on_this_thread=True,
eager=True)
worker_pool = WorkerPool(worker=get_ursula_info,
value_factory=value_factory,
target_successes=quantity,
timeout=self.DEFAULT_EXECUTION_TIMEOUT,
stagger_timeout=1,
threadpool_size=quantity)
timeout=self.execution_timeout,
stagger_timeout=1)
worker_pool.start()
successes = worker_pool.block_until_target_successes()
try:
successes = worker_pool.block_until_target_successes()
finally:
worker_pool.cancel()
# don't wait for it to stop by "joining" - too slow...
ursulas_info = successes.values()
return list(ursulas_info)
@ -170,7 +176,7 @@ the Pipe for PRE Application network operations
if self.federated_only:
sample_size = quantity - (len(include_ursulas) if include_ursulas else 0)
if not self.block_until_number_of_known_nodes_is(sample_size,
timeout=self.DEFAULT_EXECUTION_TIMEOUT,
timeout=self.execution_timeout,
learn_on_this_thread=True):
raise ValueError("Unable to learn about sufficient Ursulas")
return make_federated_staker_reservoir(known_nodes=self.known_nodes,

View File

@ -170,10 +170,6 @@ def test_wait_for_successes_out_of_values(join_worker_pool):
assert 'raise Exception(f"Operator for {value} failed")' in traceback
assert f'Operator for {value} failed' in traceback
# This will be the last line in the displayed traceback;
# That's where the worker actually failed. (Operator for {value} failed)
assert 'raise Exception(f"Operator for {value} failed")' in message
def test_wait_for_successes_timed_out(join_worker_pool):
"""
@ -354,9 +350,12 @@ def test_buggy_factory_raises_on_block():
factory = BuggyFactory(list(outcomes))
# Non-zero stagger timeout to make BuggyFactory raise its error only in 1.5s,
# So that we got enough successes for `block_until_target_successes()`.
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=1.5)
# WorkerPool short-circuits once it has sufficient successes. Therefore,
# the stagger timeout needs to be less than the worker timeout,
# since BuggyFactory only fails when a subsequent batch is requested.
# Once the subsequent batch is requested, BuggyFactory raises an error,
# causing WorkerPool to fail.
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=0.75)
pool.start()
time.sleep(2) # wait for the stagger timeout to finish

View File

@ -458,6 +458,7 @@ def federated_porter(federated_ursulas):
known_nodes=federated_ursulas,
verify_node_bonding=False,
federated_only=True,
execution_timeout=2,
network_middleware=MockRestMiddleware())
yield porter
porter.stop_learning_loop()
@ -471,6 +472,7 @@ def blockchain_porter(blockchain_ursulas, testerchain, test_registry):
known_nodes=blockchain_ursulas,
eth_provider_uri=TEST_ETH_PROVIDER_URI,
registry=test_registry,
execution_timeout=2,
network_middleware=MockRestMiddleware())
yield porter
porter.stop_learning_loop()

View File

@ -0,0 +1,78 @@
"""
This file is part of nucypher.
nucypher is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
nucypher is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
import json
import sys
from flask import Response, request
from nucypher.control.controllers import WebController
from nucypher.utilities.concurrency import WorkerPoolException
from nucypher.utilities.porter.control.interfaces import PorterInterface
def test_web_controller_handling_worker_pool_exception(mocker):
    """A WorkerPoolException raised by the interface comes back as a 404 JSON response listing every failure."""
    interface_impl = mocker.Mock()

    num_failures = 3
    message_prefix = "Execution failed because test designed that way"

    def get_ursulas_method(*args, **kwargs):
        # Record real exc_info triples, one per simulated failed value.
        failures = {}
        for i in range(num_failures):
            try:
                raise ValueError(f'error_{i}')
            except ValueError:
                failures[f"value_{i}"] = sys.exc_info()

        raise WorkerPoolException(message_prefix=message_prefix, failures=failures)

    interface_impl.get_ursulas.side_effect = get_ursulas_method

    controller = WebController(app_name="web_controller_app_test",
                               crash_on_error=False,
                               # too lazy to create test schema - use existing one
                               interface=PorterInterface(porter=interface_impl))
    control_transport = controller.make_control_transport()

    @control_transport.route('/get_ursulas', methods=['GET'])
    def get_ursulas() -> Response:
        """Porter control endpoint for sampling Ursulas on behalf of Alice."""
        response = controller(method_name='get_ursulas', control_request=request)
        return response

    client = controller.test_client()

    get_ursulas_params = {
        'quantity': 5,
    }
    response = client.get('/get_ursulas', data=json.dumps(get_ursulas_params))

    assert response.status_code == 404
    assert response.content_type == 'application/json'

    response_data = json.loads(response.data)
    assert message_prefix in response_data['result']['failure_message']

    response_failures = response_data['result']['failures']
    # compare against the configured count rather than a repeated literal
    assert len(response_failures) == num_failures

    values = [f"value_{i}" for i in range(num_failures)]
    errors = [f"error_{i}" for i in range(num_failures)]
    for failure in response_failures:
        assert failure['value'] in values
        assert failure['error'] in errors

        # remove checked entry so a duplicate would fail the membership checks above
        values.remove(failure['value'])
        errors.remove(failure['error'])