mirror of https://github.com/nucypher/nucypher.git
customizes certificate handling request session adapter for ursula p2p services
parent
7ad2635b83
commit
0f05d3d4f5
|
@ -1,10 +1,8 @@
|
|||
|
||||
|
||||
import os
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import Union
|
||||
|
||||
import time
|
||||
from constant_sorrow.constants import UNKNOWN_DEVELOPMENT_CHAIN_ID
|
||||
from cytoolz.dicttoolz import dissoc
|
||||
from eth_account import Account
|
||||
|
@ -63,7 +61,8 @@ LOCAL_CHAINS = {
|
|||
5777: "Ganache/TestRPC"
|
||||
}
|
||||
|
||||
# TODO: This list is incomplete, but it suffices for the moment - See #1857
|
||||
# This list is not exhaustive,
|
||||
# but is sufficient for the current needs of the project.
|
||||
POA_CHAINS = {
|
||||
4, # Rinkeby
|
||||
5, # Goerli
|
||||
|
@ -572,4 +571,4 @@ class EthereumTesterClient(EthereumClient):
|
|||
return signature_and_stuff['signature']
|
||||
|
||||
def parse_transaction_data(self, transaction):
|
||||
return transaction._certificates # TODO: See https://github.com/ethereum/eth-tester/issues/173
|
||||
return transaction.data # See https://github.com/ethereum/eth-tester/issues/173
|
||||
|
|
|
@ -8,7 +8,7 @@ from nucypher_core import FleetStateChecksum, MetadataRequest, NodeMetadata
|
|||
|
||||
from nucypher import characters
|
||||
from nucypher.blockchain.eth.registry import ContractRegistry
|
||||
from nucypher.utilities.certs import InMemoryCertSession
|
||||
from nucypher.utilities.certs import P2PSession
|
||||
from nucypher.utilities.logging import Logger
|
||||
|
||||
SSL_LOGGER = Logger('ssl-middleware')
|
||||
|
@ -27,7 +27,7 @@ MIDDLEWARE_DEFAULT_CERTIFICATE_TIMEOUT = os.getenv(
|
|||
|
||||
|
||||
class NucypherMiddlewareClient:
|
||||
library = InMemoryCertSession()
|
||||
library = P2PSession()
|
||||
timeout = MIDDLEWARE_DEFAULT_CONNECT_TIMEOUT
|
||||
|
||||
def __init__(
|
||||
|
|
|
@ -1,18 +1,15 @@
|
|||
import socket
|
||||
import ssl
|
||||
from _socket import gethostbyname
|
||||
from typing import NamedTuple, Dict
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
import time
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from _socket import gethostbyname
|
||||
from requests import Session, Response, PreparedRequest
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.exceptions import RequestException
|
||||
from urllib3 import PoolManager
|
||||
|
||||
from nucypher.crypto.tls import _read_tls_certificate
|
||||
from nucypher.utilities import logging
|
||||
|
||||
Certificate = str
|
||||
|
@ -23,9 +20,38 @@ class Address(NamedTuple):
|
|||
port: int
|
||||
|
||||
|
||||
class CertificateCache:
|
||||
def _replace_hostname_with_ip(url: str, ip_address: str) -> str:
|
||||
"""Replace the hostname in the URL with the provided IP address."""
|
||||
parsed_url = urlparse(url)
|
||||
return urlunparse((
|
||||
parsed_url.scheme,
|
||||
f"{ip_address}:{parsed_url.port or ''}",
|
||||
parsed_url.path,
|
||||
parsed_url.params,
|
||||
parsed_url.query,
|
||||
parsed_url.fragment
|
||||
))
|
||||
|
||||
DEFAULT_DURATION = 3600
|
||||
|
||||
def _fetch_server_cert(address: Address) -> Certificate:
|
||||
"""Fetch the server certificate from the given address."""
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
with socket.create_connection(address) as sock:
|
||||
with context.wrap_socket(sock, server_hostname=address.hostname) as ssock:
|
||||
sock.close() # close the insecure socket
|
||||
certificate_bin = ssock.getpeercert(binary_form=True)
|
||||
|
||||
certificate = Certificate(ssl.DER_cert_to_PEM_cert(certificate_bin))
|
||||
return certificate
|
||||
|
||||
|
||||
class CertificateCache:
|
||||
"""Cache for https certificates."""
|
||||
|
||||
DEFAULT_DURATION = 3600 # seconds
|
||||
DEFAULT_REFRESH_INTERVAL = 600
|
||||
|
||||
def __init__(
|
||||
|
@ -58,118 +84,90 @@ class CertificateCache:
|
|||
)
|
||||
|
||||
|
||||
class InMemoryCertAdapter(HTTPAdapter):
|
||||
class SelfSignedCertificateAdapter(HTTPAdapter):
|
||||
"""An adapter that verifies self-signed certificates in memory only."""
|
||||
|
||||
log = logging.Logger(__name__)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
self.ssl_context.verify_mode = ssl.CERT_REQUIRED # Enforce certificate verification
|
||||
self.ssl_context.check_hostname = False # Disable hostname checking
|
||||
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
|
||||
self.ssl_context.check_hostname = True
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def init_poolmanager(self, *args, **kwargs) -> None:
|
||||
"""Override the default poolmanager to use the local SSL context."""
|
||||
self.poolmanager = PoolManager(*args, ssl_context=self.ssl_context, **kwargs)
|
||||
|
||||
def accept_certificate(self, certificate: Certificate) -> None:
|
||||
def trust_certificate(self, certificate: Certificate) -> None:
|
||||
"""Accept the given certificate as trusted."""
|
||||
try:
|
||||
self.ssl_context.load_verify_locations(cadata=certificate)
|
||||
except ssl.SSLError as e:
|
||||
self.log.debug(f"Failed to load certificate {e}.")
|
||||
|
||||
|
||||
class InMemoryCertSession(Session):
|
||||
|
||||
class P2PSession(Session):
|
||||
_DEFAULT_HOSTNAME = ''
|
||||
_DEFAULT_PORT = 443
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.adapter = InMemoryCertAdapter()
|
||||
self.adapter = SelfSignedCertificateAdapter()
|
||||
self.cache = CertificateCache()
|
||||
self.mount("https://", self.adapter)
|
||||
|
||||
@classmethod
|
||||
def _parse_url(cls, url) -> Address:
|
||||
def _resolve_address(cls, url) -> Address:
|
||||
"""parse the URL and return the hostname and port as an Address named tuple."""
|
||||
parsed = urlparse(url)
|
||||
hostname = parsed.hostname or cls._DEFAULT_HOSTNAME
|
||||
hostname = gethostbyname(hostname) # resolve DNS
|
||||
return Address(
|
||||
parsed.hostname or cls._DEFAULT_HOSTNAME,
|
||||
hostname,
|
||||
parsed.port or cls._DEFAULT_PORT
|
||||
)
|
||||
|
||||
def __retry_send(self, address, request, *args, **kwargs) -> Response:
|
||||
certificate = self._refresh_certificate(address)
|
||||
self.adapter.accept_certificate(certificate=certificate)
|
||||
self.adapter.trust_certificate(certificate=certificate)
|
||||
try:
|
||||
return super().send(request, *args, **kwargs)
|
||||
except RequestException as e:
|
||||
self.adapter.log.debug(f"Request failed due to {e}, giving up.")
|
||||
raise
|
||||
|
||||
def extract_ip_from_certificate(self, certificate_pem: str) -> str:
|
||||
"""
|
||||
Extract IP address from the Subject Alternative Name (SAN) field of the certificate.
|
||||
"""
|
||||
certificate = x509.load_pem_x509_certificate(certificate_pem.encode(), default_backend())
|
||||
try:
|
||||
san = certificate.extensions.get_extension_for_class(x509.SubjectAlternativeName)
|
||||
except x509.ExtensionNotFound:
|
||||
raise ValueError("No SAN extension found in certificate")
|
||||
|
||||
# Check if an IP address is listed in SAN and return the first one
|
||||
for ip in san.value.get_values_for_type(x509.IPAddress):
|
||||
return str(ip)
|
||||
|
||||
def replace_hostname_with_ip(self, url: str, ip_address: str) -> str:
|
||||
"""
|
||||
Replace the hostname in the URL with the provided IP address.
|
||||
"""
|
||||
parsed_url = urlparse(url)
|
||||
# Reconstruct the URL with the IP address instead of the hostname
|
||||
return urlunparse((
|
||||
parsed_url.scheme,
|
||||
f"{ip_address}:{parsed_url.port}",
|
||||
parsed_url.path,
|
||||
parsed_url.params,
|
||||
parsed_url.query,
|
||||
parsed_url.fragment
|
||||
))
|
||||
|
||||
def send(self, request: PreparedRequest, *args, **kwargs) -> Response:
|
||||
address = self._parse_url(url=request.url)
|
||||
certificate = self.get_or_refresh_certificate(address)
|
||||
self.adapter.accept_certificate(certificate=certificate)
|
||||
url = self.replace_hostname_with_ip(
|
||||
"""
|
||||
Intercept the request, prefetch the host's certificate,
|
||||
and redirect the request to the certificate's resolved IP address.
|
||||
|
||||
This embedded DNS resolution is necessary because the host's certificate
|
||||
may contain an IP address in the Subject Alternative Name (SAN) field,
|
||||
but the hostname in the URL may not resolve to the same IP address.
|
||||
"""
|
||||
|
||||
address = self._resolve_address(url=request.url) # resolves dns
|
||||
certificate = self.__get_or_refresh_certificate(address) # cache by resolved ip
|
||||
self.adapter.trust_certificate(certificate=certificate)
|
||||
url = _replace_hostname_with_ip(
|
||||
url=request.url,
|
||||
ip_address=gethostbyname(address.hostname)
|
||||
ip_address=address.hostname
|
||||
)
|
||||
request.url = url
|
||||
request.url = url # replace the hostname with the resolved IP address
|
||||
try:
|
||||
return super().send(request, *args, **kwargs)
|
||||
except RequestException as e:
|
||||
self.adapter.log.debug(f"Request failed due to {e}, retrying...")
|
||||
return self.__retry_send(address, request, *args, **kwargs)
|
||||
|
||||
def get_or_refresh_certificate(self, address: Address) -> Certificate:
|
||||
def __get_or_refresh_certificate(self, address: Address) -> Certificate:
|
||||
if self.cache.should_cache_now(address):
|
||||
return self._refresh_certificate(address)
|
||||
certificate = self.cache.get(address)
|
||||
return certificate
|
||||
|
||||
def _refresh_certificate(self, address: Address) -> Certificate:
|
||||
certificate = self.__fetch_server_cert(address)
|
||||
certificate = _fetch_server_cert(address)
|
||||
self.cache.set(address, certificate)
|
||||
return certificate
|
||||
|
||||
@staticmethod
|
||||
def __fetch_server_cert(address: Address) -> Certificate:
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
with socket.create_connection(address) as sock:
|
||||
with context.wrap_socket(sock, server_hostname=address.hostname) as ssock:
|
||||
sock.close() # close the insecure socket
|
||||
certificate_bin = ssock.getpeercert(binary_form=True)
|
||||
|
||||
certificate = Certificate(ssl.DER_cert_to_PEM_cert(certificate_bin))
|
||||
return certificate
|
||||
|
|
|
@ -5,14 +5,14 @@ import pytest
|
|||
from requests import Session, RequestException
|
||||
|
||||
from nucypher.utilities.certs import (
|
||||
InMemoryCertAdapter,
|
||||
InMemoryCertSession,
|
||||
SelfSignedCertificateAdapter,
|
||||
P2PSession,
|
||||
CertificateCache,
|
||||
Address
|
||||
)
|
||||
|
||||
# Define test URLs
|
||||
VALID_URL = "https://example.com"
|
||||
VALID_URL = "https://lynx.nucypher.network:9151/status"
|
||||
INVALID_URL = "https://nonexistent-domain.com"
|
||||
|
||||
MOCK_CERT = """-----BEGIN CERTIFICATE-----
|
||||
|
@ -28,20 +28,20 @@ def cache():
|
|||
|
||||
@pytest.fixture
|
||||
def adapter(cache):
|
||||
_adapter = InMemoryCertAdapter()
|
||||
_adapter = SelfSignedCertificateAdapter()
|
||||
_adapter.cert_cache = cache
|
||||
return _adapter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(adapter):
|
||||
s = InMemoryCertSession()
|
||||
s = P2PSession()
|
||||
s.adapter = adapter # Use the same adapter instance
|
||||
return s
|
||||
|
||||
|
||||
def test_init_adapter(cache, adapter):
|
||||
assert isinstance(adapter, InMemoryCertAdapter)
|
||||
assert isinstance(adapter, SelfSignedCertificateAdapter)
|
||||
|
||||
|
||||
def test_cert_cache_set_get():
|
||||
|
@ -68,7 +68,7 @@ def test_cache_cert(cache):
|
|||
|
||||
|
||||
def test_send_request(session, mocker):
|
||||
mocker.patch.object(InMemoryCertAdapter, 'load_certificate')
|
||||
mocker.patch.object(SelfSignedCertificateAdapter, 'trust_certificate')
|
||||
mocked_refresh = mocker.patch.object(session, '_refresh_certificate', return_value=MOCK_CERT)
|
||||
mocker.patch.object(Session, 'send', return_value='response')
|
||||
response = session.send(mocker.Mock(url=VALID_URL))
|
||||
|
@ -78,7 +78,7 @@ def test_send_request(session, mocker):
|
|||
|
||||
def test_https_request_with_cert_caching():
|
||||
# Create a session with certificate caching
|
||||
session = InMemoryCertSession()
|
||||
session = P2PSession()
|
||||
|
||||
# Send a request (it should succeed)
|
||||
response = session.get(VALID_URL)
|
||||
|
@ -91,14 +91,14 @@ def test_https_request_with_cert_caching():
|
|||
|
||||
def test_https_request_with_cert_refresh():
|
||||
# Create a session with certificate caching
|
||||
session = InMemoryCertSession()
|
||||
session = P2PSession()
|
||||
|
||||
# Send a request (it should succeed)
|
||||
response = session.get(VALID_URL)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Manually expire the cached certificate
|
||||
hostname, port = InMemoryCertSession._parse_url(VALID_URL)
|
||||
hostname, port = P2PSession._resolve_address(VALID_URL)
|
||||
session.cache._expirations[(hostname, port)] = 0
|
||||
|
||||
# Send another request to the same URL (it should refresh the certificate)
|
||||
|
|
Loading…
Reference in New Issue