customizes certificate handling request session adapter for ursula p2p services

remotes/origin/v7.4.x
Kieran Prasch 2023-10-23 11:38:20 +02:00 committed by Derek Pierre
parent 7ad2635b83
commit 0f05d3d4f5
4 changed files with 81 additions and 84 deletions

View File

@ -1,10 +1,8 @@
import os
import time
from functools import cached_property
from typing import Union
import time
from constant_sorrow.constants import UNKNOWN_DEVELOPMENT_CHAIN_ID
from cytoolz.dicttoolz import dissoc
from eth_account import Account
@ -63,7 +61,8 @@ LOCAL_CHAINS = {
5777: "Ganache/TestRPC"
}
# TODO: This list is incomplete, but it suffices for the moment - See #1857
# This list is not exhaustive,
# but is sufficient for the current needs of the project.
POA_CHAINS = {
4, # Rinkeby
5, # Goerli
@ -572,4 +571,4 @@ class EthereumTesterClient(EthereumClient):
return signature_and_stuff['signature']
def parse_transaction_data(self, transaction):
return transaction._certificates # TODO: See https://github.com/ethereum/eth-tester/issues/173
return transaction.data # See https://github.com/ethereum/eth-tester/issues/173

View File

@ -8,7 +8,7 @@ from nucypher_core import FleetStateChecksum, MetadataRequest, NodeMetadata
from nucypher import characters
from nucypher.blockchain.eth.registry import ContractRegistry
from nucypher.utilities.certs import InMemoryCertSession
from nucypher.utilities.certs import P2PSession
from nucypher.utilities.logging import Logger
SSL_LOGGER = Logger('ssl-middleware')
@ -27,7 +27,7 @@ MIDDLEWARE_DEFAULT_CERTIFICATE_TIMEOUT = os.getenv(
class NucypherMiddlewareClient:
library = InMemoryCertSession()
library = P2PSession()
timeout = MIDDLEWARE_DEFAULT_CONNECT_TIMEOUT
def __init__(

View File

@ -1,18 +1,15 @@
import socket
import ssl
from _socket import gethostbyname
from typing import NamedTuple, Dict
from urllib.parse import urlparse, urlunparse
import time
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from _socket import gethostbyname
from requests import Session, Response, PreparedRequest
from requests.adapters import HTTPAdapter
from requests.exceptions import RequestException
from urllib3 import PoolManager
from nucypher.crypto.tls import _read_tls_certificate
from nucypher.utilities import logging
Certificate = str
@ -23,9 +20,38 @@ class Address(NamedTuple):
port: int
class CertificateCache:
def _replace_hostname_with_ip(url: str, ip_address: str) -> str:
"""Replace the hostname in the URL with the provided IP address."""
parsed_url = urlparse(url)
return urlunparse((
parsed_url.scheme,
f"{ip_address}:{parsed_url.port or ''}",
parsed_url.path,
parsed_url.params,
parsed_url.query,
parsed_url.fragment
))
DEFAULT_DURATION = 3600
def _fetch_server_cert(address: Address) -> Certificate:
"""Fetch the server certificate from the given address."""
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
with socket.create_connection(address) as sock:
with context.wrap_socket(sock, server_hostname=address.hostname) as ssock:
sock.close() # close the insecure socket
certificate_bin = ssock.getpeercert(binary_form=True)
certificate = Certificate(ssl.DER_cert_to_PEM_cert(certificate_bin))
return certificate
class CertificateCache:
"""Cache for https certificates."""
DEFAULT_DURATION = 3600 # seconds
DEFAULT_REFRESH_INTERVAL = 600
def __init__(
@ -58,118 +84,90 @@ class CertificateCache:
)
class InMemoryCertAdapter(HTTPAdapter):
class SelfSignedCertificateAdapter(HTTPAdapter):
"""An adapter that verifies self-signed certificates in memory only."""
log = logging.Logger(__name__)
def __init__(self, *args, **kwargs):
self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.ssl_context.verify_mode = ssl.CERT_REQUIRED # Enforce certificate verification
self.ssl_context.check_hostname = False # Disable hostname checking
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
self.ssl_context.check_hostname = True
super().__init__(*args, **kwargs)
def init_poolmanager(self, *args, **kwargs) -> None:
"""Override the default poolmanager to use the local SSL context."""
self.poolmanager = PoolManager(*args, ssl_context=self.ssl_context, **kwargs)
def accept_certificate(self, certificate: Certificate) -> None:
def trust_certificate(self, certificate: Certificate) -> None:
"""Accept the given certificate as trusted."""
try:
self.ssl_context.load_verify_locations(cadata=certificate)
except ssl.SSLError as e:
self.log.debug(f"Failed to load certificate {e}.")
class InMemoryCertSession(Session):
class P2PSession(Session):
_DEFAULT_HOSTNAME = ''
_DEFAULT_PORT = 443
def __init__(self):
super().__init__()
self.adapter = InMemoryCertAdapter()
self.adapter = SelfSignedCertificateAdapter()
self.cache = CertificateCache()
self.mount("https://", self.adapter)
@classmethod
def _parse_url(cls, url) -> Address:
def _resolve_address(cls, url) -> Address:
"""parse the URL and return the hostname and port as an Address named tuple."""
parsed = urlparse(url)
hostname = parsed.hostname or cls._DEFAULT_HOSTNAME
hostname = gethostbyname(hostname) # resolve DNS
return Address(
parsed.hostname or cls._DEFAULT_HOSTNAME,
hostname,
parsed.port or cls._DEFAULT_PORT
)
def __retry_send(self, address, request, *args, **kwargs) -> Response:
certificate = self._refresh_certificate(address)
self.adapter.accept_certificate(certificate=certificate)
self.adapter.trust_certificate(certificate=certificate)
try:
return super().send(request, *args, **kwargs)
except RequestException as e:
self.adapter.log.debug(f"Request failed due to {e}, giving up.")
raise
def extract_ip_from_certificate(self, certificate_pem: str) -> str:
"""
Extract IP address from the Subject Alternative Name (SAN) field of the certificate.
"""
certificate = x509.load_pem_x509_certificate(certificate_pem.encode(), default_backend())
try:
san = certificate.extensions.get_extension_for_class(x509.SubjectAlternativeName)
except x509.ExtensionNotFound:
raise ValueError("No SAN extension found in certificate")
# Check if an IP address is listed in SAN and return the first one
for ip in san.value.get_values_for_type(x509.IPAddress):
return str(ip)
def replace_hostname_with_ip(self, url: str, ip_address: str) -> str:
"""
Replace the hostname in the URL with the provided IP address.
"""
parsed_url = urlparse(url)
# Reconstruct the URL with the IP address instead of the hostname
return urlunparse((
parsed_url.scheme,
f"{ip_address}:{parsed_url.port}",
parsed_url.path,
parsed_url.params,
parsed_url.query,
parsed_url.fragment
))
def send(self, request: PreparedRequest, *args, **kwargs) -> Response:
address = self._parse_url(url=request.url)
certificate = self.get_or_refresh_certificate(address)
self.adapter.accept_certificate(certificate=certificate)
url = self.replace_hostname_with_ip(
"""
Intercept the request, prefetch the host's certificate,
and redirect the request to the certificate's resolved IP address.
This embedded DNS resolution is necessary because the host's certificate
may contain an IP address in the Subject Alternative Name (SAN) field,
but the hostname in the URL may not resolve to the same IP address.
"""
address = self._resolve_address(url=request.url) # resolves dns
certificate = self.__get_or_refresh_certificate(address) # cache by resolved ip
self.adapter.trust_certificate(certificate=certificate)
url = _replace_hostname_with_ip(
url=request.url,
ip_address=gethostbyname(address.hostname)
ip_address=address.hostname
)
request.url = url
request.url = url # replace the hostname with the resolved IP address
try:
return super().send(request, *args, **kwargs)
except RequestException as e:
self.adapter.log.debug(f"Request failed due to {e}, retrying...")
return self.__retry_send(address, request, *args, **kwargs)
def get_or_refresh_certificate(self, address: Address) -> Certificate:
def __get_or_refresh_certificate(self, address: Address) -> Certificate:
if self.cache.should_cache_now(address):
return self._refresh_certificate(address)
certificate = self.cache.get(address)
return certificate
def _refresh_certificate(self, address: Address) -> Certificate:
certificate = self.__fetch_server_cert(address)
certificate = _fetch_server_cert(address)
self.cache.set(address, certificate)
return certificate
@staticmethod
def __fetch_server_cert(address: Address) -> Certificate:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
with socket.create_connection(address) as sock:
with context.wrap_socket(sock, server_hostname=address.hostname) as ssock:
sock.close() # close the insecure socket
certificate_bin = ssock.getpeercert(binary_form=True)
certificate = Certificate(ssl.DER_cert_to_PEM_cert(certificate_bin))
return certificate

View File

@ -5,14 +5,14 @@ import pytest
from requests import Session, RequestException
from nucypher.utilities.certs import (
InMemoryCertAdapter,
InMemoryCertSession,
SelfSignedCertificateAdapter,
P2PSession,
CertificateCache,
Address
)
# Define test URLs
VALID_URL = "https://example.com"
VALID_URL = "https://lynx.nucypher.network:9151/status"
INVALID_URL = "https://nonexistent-domain.com"
MOCK_CERT = """-----BEGIN CERTIFICATE-----
@ -28,20 +28,20 @@ def cache():
@pytest.fixture
def adapter(cache):
_adapter = InMemoryCertAdapter()
_adapter = SelfSignedCertificateAdapter()
_adapter.cert_cache = cache
return _adapter
@pytest.fixture
def session(adapter):
s = InMemoryCertSession()
s = P2PSession()
s.adapter = adapter # Use the same adapter instance
return s
def test_init_adapter(cache, adapter):
assert isinstance(adapter, InMemoryCertAdapter)
assert isinstance(adapter, SelfSignedCertificateAdapter)
def test_cert_cache_set_get():
@ -68,7 +68,7 @@ def test_cache_cert(cache):
def test_send_request(session, mocker):
mocker.patch.object(InMemoryCertAdapter, 'load_certificate')
mocker.patch.object(SelfSignedCertificateAdapter, 'trust_certificate')
mocked_refresh = mocker.patch.object(session, '_refresh_certificate', return_value=MOCK_CERT)
mocker.patch.object(Session, 'send', return_value='response')
response = session.send(mocker.Mock(url=VALID_URL))
@ -78,7 +78,7 @@ def test_send_request(session, mocker):
def test_https_request_with_cert_caching():
# Create a session with certificate caching
session = InMemoryCertSession()
session = P2PSession()
# Send a request (it should succeed)
response = session.get(VALID_URL)
@ -91,14 +91,14 @@ def test_https_request_with_cert_caching():
def test_https_request_with_cert_refresh():
# Create a session with certificate caching
session = InMemoryCertSession()
session = P2PSession()
# Send a request (it should succeed)
response = session.get(VALID_URL)
assert response.status_code == 200
# Manually expire the cached certificate
hostname, port = InMemoryCertSession._parse_url(VALID_URL)
hostname, port = P2PSession._resolve_address(VALID_URL)
session.cache._expirations[(hostname, port)] = 0
# Send another request to the same URL (it should refresh the certificate)