Makes network middleware loally certificate agnostic

remotes/origin/v7.4.x
Kieran Prasch 2023-10-22 00:47:54 +02:00 committed by Derek Pierre
parent 13758d14a0
commit 1b6b61e334
6 changed files with 83 additions and 94 deletions

View File

@ -1213,22 +1213,9 @@ class Ursula(Teacher, Character, Operator):
# Parse node URI
host, port, staking_provider_address = parse_node_uri(seed_uri)
# Fetch the hosts TLS certificate and read the common name
try:
certificate, _filepath = network_middleware.client.get_certificate(
host=host, port=port
)
except NodeSeemsToBeDown as e:
e.args += (f"While trying to load seednode {seed_uri}",)
e.crash_right_now = True
raise
real_host = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[
0
].value
# Load the host as a potential seed node
potential_seed_node = cls.from_rest_url(
host=real_host,
host=host,
port=port,
network_middleware=network_middleware,
)

View File

@ -1,24 +1,15 @@
import os
import socket
import ssl
from http import HTTPStatus
<<<<<<< HEAD
from typing import Tuple, Union
=======
from typing import Optional, Sequence
>>>>>>> 8233d2ce2 (removes certificate filepath handling)
from typing import Tuple, Union
import time
from constant_sorrow.constants import EXEMPT_FROM_VERIFICATION
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from nucypher_core import FleetStateChecksum, MetadataRequest, NodeMetadata
from nucypher import characters
from nucypher.blockchain.eth.registry import ContractRegistry
from nucypher.config.storages import NodeStorage
from nucypher.utilities.logging import Logger
from nucypher.utilities.certs import InMemoryCertSession
from nucypher.utilities.logging import Logger
SSL_LOGGER = Logger('ssl-middleware')
EXEMPT_FROM_VERIFICATION.bool_value(False)
@ -43,7 +34,6 @@ class NucypherMiddlewareClient:
self,
eth_endpoint: Optional[str],
registry: Optional[ContractRegistry] = None,
storage: Optional[NodeStorage] = None,
*args,
**kwargs,
):
@ -52,39 +42,6 @@ class NucypherMiddlewareClient:
self.registry = registry
self.eth_endpoint = eth_endpoint
self.storage = storage or NodeStorage() # for certificate storage
def get_certificate(
self,
host,
port,
timeout=MIDDLEWARE_DEFAULT_CERTIFICATE_TIMEOUT,
retry_attempts: int = 3,
retry_rate: int = 2,
current_attempt: int = 0,
):
socket.setdefaulttimeout(timeout) # Set Socket Timeout
try:
SSL_LOGGER.debug(f"Fetching {host}:{port} TLS certificate")
certificate_pem = ssl.get_server_certificate(addr=(host, port))
certificate = ssl.PEM_cert_to_DER_cert(certificate_pem)
except socket.timeout:
if current_attempt == retry_attempts:
message = f"No Response from {host}:{port} after {retry_attempts} attempts"
SSL_LOGGER.info(message)
raise ConnectionRefusedError("No response from {}:{}".format(host, port))
SSL_LOGGER.info(f"No Response from {host}:{port}. Retrying in {retry_rate} seconds...")
time.sleep(retry_rate)
return self.get_certificate(host, port, timeout, retry_attempts, retry_rate, current_attempt + 1)
except OSError:
raise # TODO: #1835
certificate = x509.load_der_x509_certificate(certificate, backend=default_backend())
filepath = self.storage.store_node_certificate(certificate=certificate, port=port)
return certificate, filepath
@staticmethod
def response_cleaner(response):

View File

@ -1,53 +1,112 @@
import socket
import ssl
from typing import Tuple
from urllib.parse import urlparse
from requests import Session
import time
from requests import Session, Response
from requests.adapters import HTTPAdapter
from requests.exceptions import RequestException
from urllib3.poolmanager import PoolManager
def parse_url(url) -> Tuple[str, int]:
parsed = urlparse(url)
hostname = parsed.hostname or ''
port = parsed.port or 443
return hostname, port
class InMemoryCertAdapter(HTTPAdapter):
def __init__(self, *args, **kwargs):
"""Transport adapter that uses a cached certificate for HTTPS requests"""
def __init__(
self,
cache_duration: int = 3600,
refresh_interval: int = 600,
*args, **kwargs
):
self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.cert_cache = {}
self.cache_expiry = {}
self.cache_duration = cache_duration
self.cache_refresh_interval = refresh_interval
super().__init__(*args, **kwargs)
def init_poolmanager(self, *args, **kwargs):
def init_poolmanager(self, *args, **kwargs) -> None:
"""Override the default poolmanager to use the local SSL context"""
self.poolmanager = PoolManager(*args, ssl_context=self.ssl_context, **kwargs)
def update_cert(self, cert_pem):
def set_active_cert(self, cert_pem: str) -> None:
"""Set the active certificate for the SSL context"""
self.ssl_context.load_verify_locations(cadata=cert_pem)
def get_cached_cert(self, hostname: str, port: int) -> str:
return self.cert_cache.get((hostname, port))
def set_cached_cert(self, hostname: str, port: int, cert_pem: str) -> None:
self.cert_cache[(hostname, port)] = cert_pem
self.cache_expiry[(hostname, port)] = time.time() + self.cache_duration
def is_cert_expired(self, hostname: str, port: int) -> bool:
return ((hostname, port) in self.cache_expiry
and time.time() > self.cache_expiry[(hostname, port)])
def should_cache_now(self, hostname: str, port: int) -> bool:
return (
(hostname, port) not in self.cache_expiry
or time.time()
> self.cache_expiry[(hostname, port)]
- self.cache_refresh_interval
)
class InMemoryCertSession(Session):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self):
super().__init__()
self.adapter = InMemoryCertAdapter()
self.mount("https://", self.adapter)
def send(self, *args, **kwargs):
def send(self, *args, **kwargs) -> Response:
"""
Override the send to check if the certificate should be refreshed
and to refresh it if needed before sending the request.
"""
# Parse the URL to extract hostname and port
url = kwargs.get('url', args[0].url)
parsed = urlparse(url)
hostname = parsed.hostname or ''
port = parsed.port or 443
hostname, port = parse_url(url=url)
# Fetch and update the certificate
cert_pem = self.fetch_server_cert(hostname, port)
self.adapter.update_cert(cert_pem)
if self.adapter.should_cache_now(hostname=hostname, port=port):
self.refresh_certificate(hostname=hostname, port=port)
# Perform the actual request
return super().send(*args, **kwargs)
try:
# Perform the actual request
response = super().send(*args, **kwargs)
except RequestException:
# reconnect and retry once
self.refresh_certificate(hostname=hostname, port=port)
response = super().send(*args, **kwargs)
return response
def refresh_certificate(self, hostname: str, port: int) -> None:
cert_pem = self.__fetch_server_cert(hostname, port)
self.adapter.set_cached_cert(
hostname=hostname,
port=port,
cert_pem=cert_pem
)
@staticmethod
def fetch_server_cert(hostname, port):
def __fetch_server_cert(hostname: str, port: int) -> str:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
with socket.create_connection((hostname, port)) as sock:
with context.wrap_socket(sock, server_hostname=hostname) as ssock:
cert_bin = ssock.getpeercert(True)
with context.wrap_socket(sock, server_hostname=hostname) as wrapped_sock:
cert_bin = wrapped_sock.getpeercert(binary_form=True)
cert_pem = ssl.DER_cert_to_PEM_cert(cert_bin)
return cert_pem

View File

@ -103,14 +103,6 @@ def temp_dir_path():
yield Path(temp_dir.name)
temp_dir.cleanup()
@pytest.fixture(scope='function')
def certificates_tempdir():
custom_filepath = '/tmp/nucypher-test-certificates-'
cert_tmpdir = tempfile.TemporaryDirectory(prefix=custom_filepath)
yield Path(cert_tmpdir.name)
cert_tmpdir.cleanup()
#
# Accounts
#
@ -389,7 +381,6 @@ def get_random_checksum_address():
@pytest.fixture(scope="module")
def fleet_of_highperf_mocked_ursulas(ursula_test_config, request, testerchain):
mocks = (
mock_cert_storage,
mock_cert_loading,
mock_rest_app_creation,
mock_cert_generation,
@ -452,7 +443,7 @@ def highperf_mocked_alice(
reload_metadata=False,
)
with mock_cert_storage, mock_verify_node, mock_message_verification, mock_keep_learning:
with mock_verify_node, mock_message_verification, mock_keep_learning:
alice = config.produce(known_nodes=list(fleet_of_highperf_mocked_ursulas)[:1])
yield alice
# TODO: Where does this really, truly belong?
@ -473,7 +464,7 @@ def highperf_mocked_bob(fleet_of_highperf_mocked_ursulas):
reload_metadata=False,
)
with mock_cert_storage, mock_verify_node, mock_record_fleet_state, mock_keep_learning:
with mock_verify_node, mock_record_fleet_state, mock_keep_learning:
bob = config.produce(known_nodes=list(fleet_of_highperf_mocked_ursulas)[:1])
yield bob
bob._learning_task.stop()

View File

@ -89,7 +89,6 @@ def mock_requests(mocker):
@pytest.fixture(autouse=True)
def mock_client(mocker):
cert, pk = generate_self_signed_certificate(host=MOCK_IP_ADDRESS)
mocker.patch.object(NucypherMiddlewareClient, 'get_certificate', return_value=(cert, Path()))
yield mocker.patch.object(NucypherMiddlewareClient, 'invoke_method', return_value=Dummy.GoodResponse)

View File

@ -68,10 +68,6 @@ class _TestMiddlewareClient(NucypherMiddlewareClient):
def clean_params(self, request_kwargs):
request_kwargs["query_string"] = request_kwargs.pop("params", {})
def get_certificate(self, port, *args, **kwargs):
ursula = self._get_ursula_by_port(port)
return ursula.certificate, Path()
class MockRestMiddleware(RestMiddleware):
_ursulas = None