From 0f05d3d4f510142d91f365d7295d1e1146f1334b Mon Sep 17 00:00:00 2001 From: Kieran Prasch Date: Mon, 23 Oct 2023 11:38:20 +0200 Subject: [PATCH] customizes certificate handling request session adapter for ursula p2p services --- nucypher/blockchain/eth/clients.py | 9 +- nucypher/network/middleware.py | 4 +- nucypher/utilities/certs.py | 132 ++++++++++++++--------------- tests/unit/test_memory_certs.py | 20 ++--- 4 files changed, 81 insertions(+), 84 deletions(-) diff --git a/nucypher/blockchain/eth/clients.py b/nucypher/blockchain/eth/clients.py index 9d256ffe2..2f8ca74df 100644 --- a/nucypher/blockchain/eth/clients.py +++ b/nucypher/blockchain/eth/clients.py @@ -1,10 +1,8 @@ - - import os -import time from functools import cached_property from typing import Union +import time from constant_sorrow.constants import UNKNOWN_DEVELOPMENT_CHAIN_ID from cytoolz.dicttoolz import dissoc from eth_account import Account @@ -63,7 +61,8 @@ LOCAL_CHAINS = { 5777: "Ganache/TestRPC" } -# TODO: This list is incomplete, but it suffices for the moment - See #1857 +# This list is not exhaustive, +# but is sufficient for the current needs of the project. POA_CHAINS = { 4, # Rinkeby 5, # Goerli @@ -572,4 +571,4 @@ class EthereumTesterClient(EthereumClient): return signature_and_stuff['signature'] def parse_transaction_data(self, transaction): - return transaction._certificates # TODO: See https://github.com/ethereum/eth-tester/issues/173 + return transaction.data # See https://github.com/ethereum/eth-tester/issues/173 diff --git a/nucypher/network/middleware.py b/nucypher/network/middleware.py index 9e0ae3cd8..c73350b74 100644 --- a/nucypher/network/middleware.py +++ b/nucypher/network/middleware.py @@ -8,7 +8,7 @@ from nucypher_core import FleetStateChecksum, MetadataRequest, NodeMetadata from nucypher import characters from nucypher.blockchain.eth.registry import ContractRegistry -from nucypher.utilities.certs import InMemoryCertSession +from nucypher.utilities.certs import P2PSession from nucypher.utilities.logging import Logger SSL_LOGGER = Logger('ssl-middleware') @@ -27,7 +27,7 @@ MIDDLEWARE_DEFAULT_CERTIFICATE_TIMEOUT = os.getenv( class NucypherMiddlewareClient: - library = InMemoryCertSession() + library = P2PSession() timeout = MIDDLEWARE_DEFAULT_CONNECT_TIMEOUT def __init__( diff --git a/nucypher/utilities/certs.py b/nucypher/utilities/certs.py index e24568a90..2c49546c6 100644 --- a/nucypher/utilities/certs.py +++ b/nucypher/utilities/certs.py @@ -1,18 +1,15 @@ import socket import ssl -from _socket import gethostbyname from typing import NamedTuple, Dict from urllib.parse import urlparse, urlunparse import time -from cryptography import x509 -from cryptography.hazmat.backends import default_backend +from _socket import gethostbyname from requests import Session, Response, PreparedRequest from requests.adapters import HTTPAdapter from requests.exceptions import RequestException from urllib3 import PoolManager -from nucypher.crypto.tls import _read_tls_certificate from nucypher.utilities import logging Certificate = str @@ -23,9 +20,38 @@ class Address(NamedTuple): port: int -class CertificateCache: +def _replace_hostname_with_ip(url: str, ip_address: str) -> str: + """Replace the hostname in the URL with the provided IP address.""" + parsed_url = urlparse(url) + return urlunparse(( + parsed_url.scheme, + f"{ip_address}:{parsed_url.port or ''}", + parsed_url.path, + parsed_url.params, + parsed_url.query, + parsed_url.fragment + )) - DEFAULT_DURATION = 3600 + +def _fetch_server_cert(address: Address) -> Certificate: + """Fetch the server certificate from the given address.""" + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + with socket.create_connection(address) as sock: + with context.wrap_socket(sock, server_hostname=address.hostname) as ssock: + sock.close() # close the insecure socket + certificate_bin = ssock.getpeercert(binary_form=True) + + certificate = Certificate(ssl.DER_cert_to_PEM_cert(certificate_bin)) + return certificate + + +class CertificateCache: + """Cache for https certificates.""" + + DEFAULT_DURATION = 3600 # seconds DEFAULT_REFRESH_INTERVAL = 600 def __init__( @@ -58,118 +84,90 @@ class CertificateCache: ) -class InMemoryCertAdapter(HTTPAdapter): +class SelfSignedCertificateAdapter(HTTPAdapter): + """An adapter that verifies self-signed certificates in memory only.""" + log = logging.Logger(__name__) def __init__(self, *args, **kwargs): self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - self.ssl_context.verify_mode = ssl.CERT_REQUIRED # Enforce certificate verification - self.ssl_context.check_hostname = False # Disable hostname checking + self.ssl_context.verify_mode = ssl.CERT_REQUIRED + self.ssl_context.check_hostname = True super().__init__(*args, **kwargs) def init_poolmanager(self, *args, **kwargs) -> None: + """Override the default poolmanager to use the local SSL context.""" self.poolmanager = PoolManager(*args, ssl_context=self.ssl_context, **kwargs) - def accept_certificate(self, certificate: Certificate) -> None: + def trust_certificate(self, certificate: Certificate) -> None: + """Accept the given certificate as trusted.""" try: self.ssl_context.load_verify_locations(cadata=certificate) except ssl.SSLError as e: self.log.debug(f"Failed to load certificate {e}.") -class InMemoryCertSession(Session): - +class P2PSession(Session): _DEFAULT_HOSTNAME = '' _DEFAULT_PORT = 443 def __init__(self): super().__init__() - self.adapter = InMemoryCertAdapter() + self.adapter = SelfSignedCertificateAdapter() self.cache = CertificateCache() self.mount("https://", self.adapter) @classmethod - def _parse_url(cls, url) -> Address: + def _resolve_address(cls, url) -> Address: + """parse the URL and return the hostname and port as an Address named tuple.""" parsed = urlparse(url) + hostname = parsed.hostname or cls._DEFAULT_HOSTNAME + hostname = gethostbyname(hostname) # resolve DNS return Address( - parsed.hostname or cls._DEFAULT_HOSTNAME, + hostname, parsed.port or cls._DEFAULT_PORT ) def __retry_send(self, address, request, *args, **kwargs) -> Response: certificate = self._refresh_certificate(address) - self.adapter.accept_certificate(certificate=certificate) + self.adapter.trust_certificate(certificate=certificate) try: return super().send(request, *args, **kwargs) except RequestException as e: self.adapter.log.debug(f"Request failed due to {e}, giving up.") raise - def extract_ip_from_certificate(self, certificate_pem: str) -> str: - """ - Extract IP address from the Subject Alternative Name (SAN) field of the certificate. - """ - certificate = x509.load_pem_x509_certificate(certificate_pem.encode(), default_backend()) - try: - san = certificate.extensions.get_extension_for_class(x509.SubjectAlternativeName) - except x509.ExtensionNotFound: - raise ValueError("No SAN extension found in certificate") - - # Check if an IP address is listed in SAN and return the first one - for ip in san.value.get_values_for_type(x509.IPAddress): - return str(ip) - - def replace_hostname_with_ip(self, url: str, ip_address: str) -> str: - """ - Replace the hostname in the URL with the provided IP address. - """ - parsed_url = urlparse(url) - # Reconstruct the URL with the IP address instead of the hostname - return urlunparse(( - parsed_url.scheme, - f"{ip_address}:{parsed_url.port}", - parsed_url.path, - parsed_url.params, - parsed_url.query, - parsed_url.fragment - )) - def send(self, request: PreparedRequest, *args, **kwargs) -> Response: - address = self._parse_url(url=request.url) - certificate = self.get_or_refresh_certificate(address) - self.adapter.accept_certificate(certificate=certificate) - url = self.replace_hostname_with_ip( + """ + Intercept the request, prefetch the host's certificate, + and redirect the request to the certificate's resolved IP address. + + This embedded DNS resolution is necessary because the host's certificate + may contain an IP address in the Subject Alternative Name (SAN) field, + but the hostname in the URL may not resolve to the same IP address. + """ + + address = self._resolve_address(url=request.url) # resolves dns + certificate = self.__get_or_refresh_certificate(address) # cache by resolved ip + self.adapter.trust_certificate(certificate=certificate) + url = _replace_hostname_with_ip( url=request.url, - ip_address=gethostbyname(address.hostname) + ip_address=address.hostname ) - request.url = url + request.url = url # replace the hostname with the resolved IP address try: return super().send(request, *args, **kwargs) except RequestException as e: self.adapter.log.debug(f"Request failed due to {e}, retrying...") return self.__retry_send(address, request, *args, **kwargs) - def get_or_refresh_certificate(self, address: Address) -> Certificate: + def __get_or_refresh_certificate(self, address: Address) -> Certificate: if self.cache.should_cache_now(address): return self._refresh_certificate(address) certificate = self.cache.get(address) return certificate def _refresh_certificate(self, address: Address) -> Certificate: - certificate = self.__fetch_server_cert(address) + certificate = _fetch_server_cert(address) self.cache.set(address, certificate) return certificate - - @staticmethod - def __fetch_server_cert(address: Address) -> Certificate: - context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - context.check_hostname = False - context.verify_mode = ssl.CERT_NONE - - with socket.create_connection(address) as sock: - with context.wrap_socket(sock, server_hostname=address.hostname) as ssock: - sock.close() # close the insecure socket - certificate_bin = ssock.getpeercert(binary_form=True) - - certificate = Certificate(ssl.DER_cert_to_PEM_cert(certificate_bin)) - return certificate diff --git a/tests/unit/test_memory_certs.py b/tests/unit/test_memory_certs.py index 46b59410a..461984d50 100644 --- a/tests/unit/test_memory_certs.py +++ b/tests/unit/test_memory_certs.py @@ -5,14 +5,14 @@ import pytest from requests import Session, RequestException from nucypher.utilities.certs import ( - InMemoryCertAdapter, - InMemoryCertSession, + SelfSignedCertificateAdapter, + P2PSession, CertificateCache, Address ) # Define test URLs -VALID_URL = "https://example.com" +VALID_URL = "https://lynx.nucypher.network:9151/status" INVALID_URL = "https://nonexistent-domain.com" MOCK_CERT = """-----BEGIN CERTIFICATE----- @@ -28,20 +28,20 @@ def cache(): @pytest.fixture def adapter(cache): - _adapter = InMemoryCertAdapter() + _adapter = SelfSignedCertificateAdapter() _adapter.cert_cache = cache return _adapter @pytest.fixture def session(adapter): - s = InMemoryCertSession() + s = P2PSession() s.adapter = adapter # Use the same adapter instance return s def test_init_adapter(cache, adapter): - assert isinstance(adapter, InMemoryCertAdapter) + assert isinstance(adapter, SelfSignedCertificateAdapter) def test_cert_cache_set_get(): @@ -68,7 +68,7 @@ def test_cache_cert(cache): def test_send_request(session, mocker): - mocker.patch.object(InMemoryCertAdapter, 'load_certificate') + mocker.patch.object(SelfSignedCertificateAdapter, 'trust_certificate') mocked_refresh = mocker.patch.object(session, '_refresh_certificate', return_value=MOCK_CERT) mocker.patch.object(Session, 'send', return_value='response') response = session.send(mocker.Mock(url=VALID_URL)) @@ -78,7 +78,7 @@ def test_send_request(session, mocker): def test_https_request_with_cert_caching(): # Create a session with certificate caching - session = InMemoryCertSession() + session = P2PSession() # Send a request (it should succeed) response = session.get(VALID_URL) @@ -91,14 +91,14 @@ def test_https_request_with_cert_caching(): def test_https_request_with_cert_refresh(): # Create a session with certificate caching - session = InMemoryCertSession() + session = P2PSession() # Send a request (it should succeed) response = session.get(VALID_URL) assert response.status_code == 200 # Manually expire the cached certificate - hostname, port = InMemoryCertSession._parse_url(VALID_URL) + hostname, port = P2PSession._resolve_address(VALID_URL) session.cache._expirations[(hostname, port)] = 0 # Send another request to the same URL (it should refresh the certificate)