mirror of https://github.com/nucypher/nucypher.git
customizes certificate handling request session adapter for ursula p2p services
parent
7ad2635b83
commit
0f05d3d4f5
|
@ -1,10 +1,8 @@
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
import time
|
||||||
from constant_sorrow.constants import UNKNOWN_DEVELOPMENT_CHAIN_ID
|
from constant_sorrow.constants import UNKNOWN_DEVELOPMENT_CHAIN_ID
|
||||||
from cytoolz.dicttoolz import dissoc
|
from cytoolz.dicttoolz import dissoc
|
||||||
from eth_account import Account
|
from eth_account import Account
|
||||||
|
@ -63,7 +61,8 @@ LOCAL_CHAINS = {
|
||||||
5777: "Ganache/TestRPC"
|
5777: "Ganache/TestRPC"
|
||||||
}
|
}
|
||||||
|
|
||||||
# TODO: This list is incomplete, but it suffices for the moment - See #1857
|
# This list is not exhaustive,
|
||||||
|
# but is sufficient for the current needs of the project.
|
||||||
POA_CHAINS = {
|
POA_CHAINS = {
|
||||||
4, # Rinkeby
|
4, # Rinkeby
|
||||||
5, # Goerli
|
5, # Goerli
|
||||||
|
@ -572,4 +571,4 @@ class EthereumTesterClient(EthereumClient):
|
||||||
return signature_and_stuff['signature']
|
return signature_and_stuff['signature']
|
||||||
|
|
||||||
def parse_transaction_data(self, transaction):
|
def parse_transaction_data(self, transaction):
|
||||||
return transaction._certificates # TODO: See https://github.com/ethereum/eth-tester/issues/173
|
return transaction.data # See https://github.com/ethereum/eth-tester/issues/173
|
||||||
|
|
|
@ -8,7 +8,7 @@ from nucypher_core import FleetStateChecksum, MetadataRequest, NodeMetadata
|
||||||
|
|
||||||
from nucypher import characters
|
from nucypher import characters
|
||||||
from nucypher.blockchain.eth.registry import ContractRegistry
|
from nucypher.blockchain.eth.registry import ContractRegistry
|
||||||
from nucypher.utilities.certs import InMemoryCertSession
|
from nucypher.utilities.certs import P2PSession
|
||||||
from nucypher.utilities.logging import Logger
|
from nucypher.utilities.logging import Logger
|
||||||
|
|
||||||
SSL_LOGGER = Logger('ssl-middleware')
|
SSL_LOGGER = Logger('ssl-middleware')
|
||||||
|
@ -27,7 +27,7 @@ MIDDLEWARE_DEFAULT_CERTIFICATE_TIMEOUT = os.getenv(
|
||||||
|
|
||||||
|
|
||||||
class NucypherMiddlewareClient:
|
class NucypherMiddlewareClient:
|
||||||
library = InMemoryCertSession()
|
library = P2PSession()
|
||||||
timeout = MIDDLEWARE_DEFAULT_CONNECT_TIMEOUT
|
timeout = MIDDLEWARE_DEFAULT_CONNECT_TIMEOUT
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -1,18 +1,15 @@
|
||||||
import socket
|
import socket
|
||||||
import ssl
|
import ssl
|
||||||
from _socket import gethostbyname
|
|
||||||
from typing import NamedTuple, Dict
|
from typing import NamedTuple, Dict
|
||||||
from urllib.parse import urlparse, urlunparse
|
from urllib.parse import urlparse, urlunparse
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from cryptography import x509
|
from _socket import gethostbyname
|
||||||
from cryptography.hazmat.backends import default_backend
|
|
||||||
from requests import Session, Response, PreparedRequest
|
from requests import Session, Response, PreparedRequest
|
||||||
from requests.adapters import HTTPAdapter
|
from requests.adapters import HTTPAdapter
|
||||||
from requests.exceptions import RequestException
|
from requests.exceptions import RequestException
|
||||||
from urllib3 import PoolManager
|
from urllib3 import PoolManager
|
||||||
|
|
||||||
from nucypher.crypto.tls import _read_tls_certificate
|
|
||||||
from nucypher.utilities import logging
|
from nucypher.utilities import logging
|
||||||
|
|
||||||
Certificate = str
|
Certificate = str
|
||||||
|
@ -23,9 +20,38 @@ class Address(NamedTuple):
|
||||||
port: int
|
port: int
|
||||||
|
|
||||||
|
|
||||||
class CertificateCache:
|
def _replace_hostname_with_ip(url: str, ip_address: str) -> str:
|
||||||
|
"""Replace the hostname in the URL with the provided IP address."""
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
return urlunparse((
|
||||||
|
parsed_url.scheme,
|
||||||
|
f"{ip_address}:{parsed_url.port or ''}",
|
||||||
|
parsed_url.path,
|
||||||
|
parsed_url.params,
|
||||||
|
parsed_url.query,
|
||||||
|
parsed_url.fragment
|
||||||
|
))
|
||||||
|
|
||||||
DEFAULT_DURATION = 3600
|
|
||||||
|
def _fetch_server_cert(address: Address) -> Certificate:
|
||||||
|
"""Fetch the server certificate from the given address."""
|
||||||
|
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||||
|
context.check_hostname = False
|
||||||
|
context.verify_mode = ssl.CERT_NONE
|
||||||
|
|
||||||
|
with socket.create_connection(address) as sock:
|
||||||
|
with context.wrap_socket(sock, server_hostname=address.hostname) as ssock:
|
||||||
|
sock.close() # close the insecure socket
|
||||||
|
certificate_bin = ssock.getpeercert(binary_form=True)
|
||||||
|
|
||||||
|
certificate = Certificate(ssl.DER_cert_to_PEM_cert(certificate_bin))
|
||||||
|
return certificate
|
||||||
|
|
||||||
|
|
||||||
|
class CertificateCache:
|
||||||
|
"""Cache for https certificates."""
|
||||||
|
|
||||||
|
DEFAULT_DURATION = 3600 # seconds
|
||||||
DEFAULT_REFRESH_INTERVAL = 600
|
DEFAULT_REFRESH_INTERVAL = 600
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -58,118 +84,90 @@ class CertificateCache:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InMemoryCertAdapter(HTTPAdapter):
|
class SelfSignedCertificateAdapter(HTTPAdapter):
|
||||||
|
"""An adapter that verifies self-signed certificates in memory only."""
|
||||||
|
|
||||||
log = logging.Logger(__name__)
|
log = logging.Logger(__name__)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||||
self.ssl_context.verify_mode = ssl.CERT_REQUIRED # Enforce certificate verification
|
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
|
||||||
self.ssl_context.check_hostname = False # Disable hostname checking
|
self.ssl_context.check_hostname = True
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def init_poolmanager(self, *args, **kwargs) -> None:
|
def init_poolmanager(self, *args, **kwargs) -> None:
|
||||||
|
"""Override the default poolmanager to use the local SSL context."""
|
||||||
self.poolmanager = PoolManager(*args, ssl_context=self.ssl_context, **kwargs)
|
self.poolmanager = PoolManager(*args, ssl_context=self.ssl_context, **kwargs)
|
||||||
|
|
||||||
def accept_certificate(self, certificate: Certificate) -> None:
|
def trust_certificate(self, certificate: Certificate) -> None:
|
||||||
|
"""Accept the given certificate as trusted."""
|
||||||
try:
|
try:
|
||||||
self.ssl_context.load_verify_locations(cadata=certificate)
|
self.ssl_context.load_verify_locations(cadata=certificate)
|
||||||
except ssl.SSLError as e:
|
except ssl.SSLError as e:
|
||||||
self.log.debug(f"Failed to load certificate {e}.")
|
self.log.debug(f"Failed to load certificate {e}.")
|
||||||
|
|
||||||
|
|
||||||
class InMemoryCertSession(Session):
|
class P2PSession(Session):
|
||||||
|
|
||||||
_DEFAULT_HOSTNAME = ''
|
_DEFAULT_HOSTNAME = ''
|
||||||
_DEFAULT_PORT = 443
|
_DEFAULT_PORT = 443
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.adapter = InMemoryCertAdapter()
|
self.adapter = SelfSignedCertificateAdapter()
|
||||||
self.cache = CertificateCache()
|
self.cache = CertificateCache()
|
||||||
self.mount("https://", self.adapter)
|
self.mount("https://", self.adapter)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _parse_url(cls, url) -> Address:
|
def _resolve_address(cls, url) -> Address:
|
||||||
|
"""parse the URL and return the hostname and port as an Address named tuple."""
|
||||||
parsed = urlparse(url)
|
parsed = urlparse(url)
|
||||||
|
hostname = parsed.hostname or cls._DEFAULT_HOSTNAME
|
||||||
|
hostname = gethostbyname(hostname) # resolve DNS
|
||||||
return Address(
|
return Address(
|
||||||
parsed.hostname or cls._DEFAULT_HOSTNAME,
|
hostname,
|
||||||
parsed.port or cls._DEFAULT_PORT
|
parsed.port or cls._DEFAULT_PORT
|
||||||
)
|
)
|
||||||
|
|
||||||
def __retry_send(self, address, request, *args, **kwargs) -> Response:
|
def __retry_send(self, address, request, *args, **kwargs) -> Response:
|
||||||
certificate = self._refresh_certificate(address)
|
certificate = self._refresh_certificate(address)
|
||||||
self.adapter.accept_certificate(certificate=certificate)
|
self.adapter.trust_certificate(certificate=certificate)
|
||||||
try:
|
try:
|
||||||
return super().send(request, *args, **kwargs)
|
return super().send(request, *args, **kwargs)
|
||||||
except RequestException as e:
|
except RequestException as e:
|
||||||
self.adapter.log.debug(f"Request failed due to {e}, giving up.")
|
self.adapter.log.debug(f"Request failed due to {e}, giving up.")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def extract_ip_from_certificate(self, certificate_pem: str) -> str:
|
|
||||||
"""
|
|
||||||
Extract IP address from the Subject Alternative Name (SAN) field of the certificate.
|
|
||||||
"""
|
|
||||||
certificate = x509.load_pem_x509_certificate(certificate_pem.encode(), default_backend())
|
|
||||||
try:
|
|
||||||
san = certificate.extensions.get_extension_for_class(x509.SubjectAlternativeName)
|
|
||||||
except x509.ExtensionNotFound:
|
|
||||||
raise ValueError("No SAN extension found in certificate")
|
|
||||||
|
|
||||||
# Check if an IP address is listed in SAN and return the first one
|
|
||||||
for ip in san.value.get_values_for_type(x509.IPAddress):
|
|
||||||
return str(ip)
|
|
||||||
|
|
||||||
def replace_hostname_with_ip(self, url: str, ip_address: str) -> str:
|
|
||||||
"""
|
|
||||||
Replace the hostname in the URL with the provided IP address.
|
|
||||||
"""
|
|
||||||
parsed_url = urlparse(url)
|
|
||||||
# Reconstruct the URL with the IP address instead of the hostname
|
|
||||||
return urlunparse((
|
|
||||||
parsed_url.scheme,
|
|
||||||
f"{ip_address}:{parsed_url.port}",
|
|
||||||
parsed_url.path,
|
|
||||||
parsed_url.params,
|
|
||||||
parsed_url.query,
|
|
||||||
parsed_url.fragment
|
|
||||||
))
|
|
||||||
|
|
||||||
def send(self, request: PreparedRequest, *args, **kwargs) -> Response:
|
def send(self, request: PreparedRequest, *args, **kwargs) -> Response:
|
||||||
address = self._parse_url(url=request.url)
|
"""
|
||||||
certificate = self.get_or_refresh_certificate(address)
|
Intercept the request, prefetch the host's certificate,
|
||||||
self.adapter.accept_certificate(certificate=certificate)
|
and redirect the request to the certificate's resolved IP address.
|
||||||
url = self.replace_hostname_with_ip(
|
|
||||||
|
This embedded DNS resolution is necessary because the host's certificate
|
||||||
|
may contain an IP address in the Subject Alternative Name (SAN) field,
|
||||||
|
but the hostname in the URL may not resolve to the same IP address.
|
||||||
|
"""
|
||||||
|
|
||||||
|
address = self._resolve_address(url=request.url) # resolves dns
|
||||||
|
certificate = self.__get_or_refresh_certificate(address) # cache by resolved ip
|
||||||
|
self.adapter.trust_certificate(certificate=certificate)
|
||||||
|
url = _replace_hostname_with_ip(
|
||||||
url=request.url,
|
url=request.url,
|
||||||
ip_address=gethostbyname(address.hostname)
|
ip_address=address.hostname
|
||||||
)
|
)
|
||||||
request.url = url
|
request.url = url # replace the hostname with the resolved IP address
|
||||||
try:
|
try:
|
||||||
return super().send(request, *args, **kwargs)
|
return super().send(request, *args, **kwargs)
|
||||||
except RequestException as e:
|
except RequestException as e:
|
||||||
self.adapter.log.debug(f"Request failed due to {e}, retrying...")
|
self.adapter.log.debug(f"Request failed due to {e}, retrying...")
|
||||||
return self.__retry_send(address, request, *args, **kwargs)
|
return self.__retry_send(address, request, *args, **kwargs)
|
||||||
|
|
||||||
def get_or_refresh_certificate(self, address: Address) -> Certificate:
|
def __get_or_refresh_certificate(self, address: Address) -> Certificate:
|
||||||
if self.cache.should_cache_now(address):
|
if self.cache.should_cache_now(address):
|
||||||
return self._refresh_certificate(address)
|
return self._refresh_certificate(address)
|
||||||
certificate = self.cache.get(address)
|
certificate = self.cache.get(address)
|
||||||
return certificate
|
return certificate
|
||||||
|
|
||||||
def _refresh_certificate(self, address: Address) -> Certificate:
|
def _refresh_certificate(self, address: Address) -> Certificate:
|
||||||
certificate = self.__fetch_server_cert(address)
|
certificate = _fetch_server_cert(address)
|
||||||
self.cache.set(address, certificate)
|
self.cache.set(address, certificate)
|
||||||
return certificate
|
return certificate
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def __fetch_server_cert(address: Address) -> Certificate:
|
|
||||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
|
||||||
context.check_hostname = False
|
|
||||||
context.verify_mode = ssl.CERT_NONE
|
|
||||||
|
|
||||||
with socket.create_connection(address) as sock:
|
|
||||||
with context.wrap_socket(sock, server_hostname=address.hostname) as ssock:
|
|
||||||
sock.close() # close the insecure socket
|
|
||||||
certificate_bin = ssock.getpeercert(binary_form=True)
|
|
||||||
|
|
||||||
certificate = Certificate(ssl.DER_cert_to_PEM_cert(certificate_bin))
|
|
||||||
return certificate
|
|
||||||
|
|
|
@ -5,14 +5,14 @@ import pytest
|
||||||
from requests import Session, RequestException
|
from requests import Session, RequestException
|
||||||
|
|
||||||
from nucypher.utilities.certs import (
|
from nucypher.utilities.certs import (
|
||||||
InMemoryCertAdapter,
|
SelfSignedCertificateAdapter,
|
||||||
InMemoryCertSession,
|
P2PSession,
|
||||||
CertificateCache,
|
CertificateCache,
|
||||||
Address
|
Address
|
||||||
)
|
)
|
||||||
|
|
||||||
# Define test URLs
|
# Define test URLs
|
||||||
VALID_URL = "https://example.com"
|
VALID_URL = "https://lynx.nucypher.network:9151/status"
|
||||||
INVALID_URL = "https://nonexistent-domain.com"
|
INVALID_URL = "https://nonexistent-domain.com"
|
||||||
|
|
||||||
MOCK_CERT = """-----BEGIN CERTIFICATE-----
|
MOCK_CERT = """-----BEGIN CERTIFICATE-----
|
||||||
|
@ -28,20 +28,20 @@ def cache():
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def adapter(cache):
|
def adapter(cache):
|
||||||
_adapter = InMemoryCertAdapter()
|
_adapter = SelfSignedCertificateAdapter()
|
||||||
_adapter.cert_cache = cache
|
_adapter.cert_cache = cache
|
||||||
return _adapter
|
return _adapter
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def session(adapter):
|
def session(adapter):
|
||||||
s = InMemoryCertSession()
|
s = P2PSession()
|
||||||
s.adapter = adapter # Use the same adapter instance
|
s.adapter = adapter # Use the same adapter instance
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
def test_init_adapter(cache, adapter):
|
def test_init_adapter(cache, adapter):
|
||||||
assert isinstance(adapter, InMemoryCertAdapter)
|
assert isinstance(adapter, SelfSignedCertificateAdapter)
|
||||||
|
|
||||||
|
|
||||||
def test_cert_cache_set_get():
|
def test_cert_cache_set_get():
|
||||||
|
@ -68,7 +68,7 @@ def test_cache_cert(cache):
|
||||||
|
|
||||||
|
|
||||||
def test_send_request(session, mocker):
|
def test_send_request(session, mocker):
|
||||||
mocker.patch.object(InMemoryCertAdapter, 'load_certificate')
|
mocker.patch.object(SelfSignedCertificateAdapter, 'trust_certificate')
|
||||||
mocked_refresh = mocker.patch.object(session, '_refresh_certificate', return_value=MOCK_CERT)
|
mocked_refresh = mocker.patch.object(session, '_refresh_certificate', return_value=MOCK_CERT)
|
||||||
mocker.patch.object(Session, 'send', return_value='response')
|
mocker.patch.object(Session, 'send', return_value='response')
|
||||||
response = session.send(mocker.Mock(url=VALID_URL))
|
response = session.send(mocker.Mock(url=VALID_URL))
|
||||||
|
@ -78,7 +78,7 @@ def test_send_request(session, mocker):
|
||||||
|
|
||||||
def test_https_request_with_cert_caching():
|
def test_https_request_with_cert_caching():
|
||||||
# Create a session with certificate caching
|
# Create a session with certificate caching
|
||||||
session = InMemoryCertSession()
|
session = P2PSession()
|
||||||
|
|
||||||
# Send a request (it should succeed)
|
# Send a request (it should succeed)
|
||||||
response = session.get(VALID_URL)
|
response = session.get(VALID_URL)
|
||||||
|
@ -91,14 +91,14 @@ def test_https_request_with_cert_caching():
|
||||||
|
|
||||||
def test_https_request_with_cert_refresh():
|
def test_https_request_with_cert_refresh():
|
||||||
# Create a session with certificate caching
|
# Create a session with certificate caching
|
||||||
session = InMemoryCertSession()
|
session = P2PSession()
|
||||||
|
|
||||||
# Send a request (it should succeed)
|
# Send a request (it should succeed)
|
||||||
response = session.get(VALID_URL)
|
response = session.get(VALID_URL)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
# Manually expire the cached certificate
|
# Manually expire the cached certificate
|
||||||
hostname, port = InMemoryCertSession._parse_url(VALID_URL)
|
hostname, port = P2PSession._resolve_address(VALID_URL)
|
||||||
session.cache._expirations[(hostname, port)] = 0
|
session.cache._expirations[(hostname, port)] = 0
|
||||||
|
|
||||||
# Send another request to the same URL (it should refresh the certificate)
|
# Send another request to the same URL (it should refresh the certificate)
|
||||||
|
|
Loading…
Reference in New Issue