mirror of https://github.com/nucypher/nucypher.git
Makes network middleware loally certificate agnostic
parent
13758d14a0
commit
1b6b61e334
|
@ -1213,22 +1213,9 @@ class Ursula(Teacher, Character, Operator):
|
|||
# Parse node URI
|
||||
host, port, staking_provider_address = parse_node_uri(seed_uri)
|
||||
|
||||
# Fetch the hosts TLS certificate and read the common name
|
||||
try:
|
||||
certificate, _filepath = network_middleware.client.get_certificate(
|
||||
host=host, port=port
|
||||
)
|
||||
except NodeSeemsToBeDown as e:
|
||||
e.args += (f"While trying to load seednode {seed_uri}",)
|
||||
e.crash_right_now = True
|
||||
raise
|
||||
real_host = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[
|
||||
0
|
||||
].value
|
||||
|
||||
# Load the host as a potential seed node
|
||||
potential_seed_node = cls.from_rest_url(
|
||||
host=real_host,
|
||||
host=host,
|
||||
port=port,
|
||||
network_middleware=network_middleware,
|
||||
)
|
||||
|
|
|
@ -1,24 +1,15 @@
|
|||
import os
|
||||
import socket
|
||||
import ssl
|
||||
from http import HTTPStatus
|
||||
<<<<<<< HEAD
|
||||
from typing import Tuple, Union
|
||||
=======
|
||||
from typing import Optional, Sequence
|
||||
>>>>>>> 8233d2ce2 (removes certificate filepath handling)
|
||||
from typing import Tuple, Union
|
||||
|
||||
import time
|
||||
from constant_sorrow.constants import EXEMPT_FROM_VERIFICATION
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from nucypher_core import FleetStateChecksum, MetadataRequest, NodeMetadata
|
||||
|
||||
from nucypher import characters
|
||||
from nucypher.blockchain.eth.registry import ContractRegistry
|
||||
from nucypher.config.storages import NodeStorage
|
||||
from nucypher.utilities.logging import Logger
|
||||
from nucypher.utilities.certs import InMemoryCertSession
|
||||
from nucypher.utilities.logging import Logger
|
||||
|
||||
SSL_LOGGER = Logger('ssl-middleware')
|
||||
EXEMPT_FROM_VERIFICATION.bool_value(False)
|
||||
|
@ -43,7 +34,6 @@ class NucypherMiddlewareClient:
|
|||
self,
|
||||
eth_endpoint: Optional[str],
|
||||
registry: Optional[ContractRegistry] = None,
|
||||
storage: Optional[NodeStorage] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -52,39 +42,6 @@ class NucypherMiddlewareClient:
|
|||
|
||||
self.registry = registry
|
||||
self.eth_endpoint = eth_endpoint
|
||||
self.storage = storage or NodeStorage() # for certificate storage
|
||||
|
||||
def get_certificate(
|
||||
self,
|
||||
host,
|
||||
port,
|
||||
timeout=MIDDLEWARE_DEFAULT_CERTIFICATE_TIMEOUT,
|
||||
retry_attempts: int = 3,
|
||||
retry_rate: int = 2,
|
||||
current_attempt: int = 0,
|
||||
):
|
||||
socket.setdefaulttimeout(timeout) # Set Socket Timeout
|
||||
|
||||
try:
|
||||
SSL_LOGGER.debug(f"Fetching {host}:{port} TLS certificate")
|
||||
certificate_pem = ssl.get_server_certificate(addr=(host, port))
|
||||
certificate = ssl.PEM_cert_to_DER_cert(certificate_pem)
|
||||
|
||||
except socket.timeout:
|
||||
if current_attempt == retry_attempts:
|
||||
message = f"No Response from {host}:{port} after {retry_attempts} attempts"
|
||||
SSL_LOGGER.info(message)
|
||||
raise ConnectionRefusedError("No response from {}:{}".format(host, port))
|
||||
SSL_LOGGER.info(f"No Response from {host}:{port}. Retrying in {retry_rate} seconds...")
|
||||
time.sleep(retry_rate)
|
||||
return self.get_certificate(host, port, timeout, retry_attempts, retry_rate, current_attempt + 1)
|
||||
|
||||
except OSError:
|
||||
raise # TODO: #1835
|
||||
|
||||
certificate = x509.load_der_x509_certificate(certificate, backend=default_backend())
|
||||
filepath = self.storage.store_node_certificate(certificate=certificate, port=port)
|
||||
return certificate, filepath
|
||||
|
||||
@staticmethod
|
||||
def response_cleaner(response):
|
||||
|
|
|
@ -1,53 +1,112 @@
|
|||
import socket
|
||||
import ssl
|
||||
from typing import Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from requests import Session
|
||||
import time
|
||||
from requests import Session, Response
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.exceptions import RequestException
|
||||
from urllib3.poolmanager import PoolManager
|
||||
|
||||
|
||||
def parse_url(url) -> Tuple[str, int]:
|
||||
parsed = urlparse(url)
|
||||
hostname = parsed.hostname or ''
|
||||
port = parsed.port or 443
|
||||
return hostname, port
|
||||
|
||||
|
||||
class InMemoryCertAdapter(HTTPAdapter):
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Transport adapter that uses a cached certificate for HTTPS requests"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_duration: int = 3600,
|
||||
refresh_interval: int = 600,
|
||||
*args, **kwargs
|
||||
):
|
||||
self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
self.cert_cache = {}
|
||||
self.cache_expiry = {}
|
||||
self.cache_duration = cache_duration
|
||||
self.cache_refresh_interval = refresh_interval
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def init_poolmanager(self, *args, **kwargs):
|
||||
def init_poolmanager(self, *args, **kwargs) -> None:
|
||||
"""Override the default poolmanager to use the local SSL context"""
|
||||
self.poolmanager = PoolManager(*args, ssl_context=self.ssl_context, **kwargs)
|
||||
|
||||
def update_cert(self, cert_pem):
|
||||
def set_active_cert(self, cert_pem: str) -> None:
|
||||
"""Set the active certificate for the SSL context"""
|
||||
self.ssl_context.load_verify_locations(cadata=cert_pem)
|
||||
|
||||
def get_cached_cert(self, hostname: str, port: int) -> str:
|
||||
return self.cert_cache.get((hostname, port))
|
||||
|
||||
def set_cached_cert(self, hostname: str, port: int, cert_pem: str) -> None:
|
||||
self.cert_cache[(hostname, port)] = cert_pem
|
||||
self.cache_expiry[(hostname, port)] = time.time() + self.cache_duration
|
||||
|
||||
def is_cert_expired(self, hostname: str, port: int) -> bool:
|
||||
return ((hostname, port) in self.cache_expiry
|
||||
and time.time() > self.cache_expiry[(hostname, port)])
|
||||
|
||||
def should_cache_now(self, hostname: str, port: int) -> bool:
|
||||
return (
|
||||
(hostname, port) not in self.cache_expiry
|
||||
or time.time()
|
||||
> self.cache_expiry[(hostname, port)]
|
||||
- self.cache_refresh_interval
|
||||
)
|
||||
|
||||
|
||||
class InMemoryCertSession(Session):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.adapter = InMemoryCertAdapter()
|
||||
self.mount("https://", self.adapter)
|
||||
|
||||
def send(self, *args, **kwargs):
|
||||
def send(self, *args, **kwargs) -> Response:
|
||||
"""
|
||||
Override the send to check if the certificate should be refreshed
|
||||
and to refresh it if needed before sending the request.
|
||||
"""
|
||||
|
||||
# Parse the URL to extract hostname and port
|
||||
url = kwargs.get('url', args[0].url)
|
||||
parsed = urlparse(url)
|
||||
hostname = parsed.hostname or ''
|
||||
port = parsed.port or 443
|
||||
hostname, port = parse_url(url=url)
|
||||
|
||||
# Fetch and update the certificate
|
||||
cert_pem = self.fetch_server_cert(hostname, port)
|
||||
self.adapter.update_cert(cert_pem)
|
||||
if self.adapter.should_cache_now(hostname=hostname, port=port):
|
||||
self.refresh_certificate(hostname=hostname, port=port)
|
||||
|
||||
# Perform the actual request
|
||||
return super().send(*args, **kwargs)
|
||||
try:
|
||||
# Perform the actual request
|
||||
response = super().send(*args, **kwargs)
|
||||
except RequestException:
|
||||
# reconnect and retry once
|
||||
self.refresh_certificate(hostname=hostname, port=port)
|
||||
response = super().send(*args, **kwargs)
|
||||
|
||||
return response
|
||||
|
||||
def refresh_certificate(self, hostname: str, port: int) -> None:
|
||||
cert_pem = self.__fetch_server_cert(hostname, port)
|
||||
self.adapter.set_cached_cert(
|
||||
hostname=hostname,
|
||||
port=port,
|
||||
cert_pem=cert_pem
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def fetch_server_cert(hostname, port):
|
||||
def __fetch_server_cert(hostname: str, port: int) -> str:
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
with socket.create_connection((hostname, port)) as sock:
|
||||
with context.wrap_socket(sock, server_hostname=hostname) as ssock:
|
||||
cert_bin = ssock.getpeercert(True)
|
||||
with context.wrap_socket(sock, server_hostname=hostname) as wrapped_sock:
|
||||
cert_bin = wrapped_sock.getpeercert(binary_form=True)
|
||||
cert_pem = ssl.DER_cert_to_PEM_cert(cert_bin)
|
||||
return cert_pem
|
||||
|
||||
|
|
|
@ -103,14 +103,6 @@ def temp_dir_path():
|
|||
yield Path(temp_dir.name)
|
||||
temp_dir.cleanup()
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def certificates_tempdir():
|
||||
custom_filepath = '/tmp/nucypher-test-certificates-'
|
||||
cert_tmpdir = tempfile.TemporaryDirectory(prefix=custom_filepath)
|
||||
yield Path(cert_tmpdir.name)
|
||||
cert_tmpdir.cleanup()
|
||||
|
||||
#
|
||||
# Accounts
|
||||
#
|
||||
|
@ -389,7 +381,6 @@ def get_random_checksum_address():
|
|||
@pytest.fixture(scope="module")
|
||||
def fleet_of_highperf_mocked_ursulas(ursula_test_config, request, testerchain):
|
||||
mocks = (
|
||||
mock_cert_storage,
|
||||
mock_cert_loading,
|
||||
mock_rest_app_creation,
|
||||
mock_cert_generation,
|
||||
|
@ -452,7 +443,7 @@ def highperf_mocked_alice(
|
|||
reload_metadata=False,
|
||||
)
|
||||
|
||||
with mock_cert_storage, mock_verify_node, mock_message_verification, mock_keep_learning:
|
||||
with mock_verify_node, mock_message_verification, mock_keep_learning:
|
||||
alice = config.produce(known_nodes=list(fleet_of_highperf_mocked_ursulas)[:1])
|
||||
yield alice
|
||||
# TODO: Where does this really, truly belong?
|
||||
|
@ -473,7 +464,7 @@ def highperf_mocked_bob(fleet_of_highperf_mocked_ursulas):
|
|||
reload_metadata=False,
|
||||
)
|
||||
|
||||
with mock_cert_storage, mock_verify_node, mock_record_fleet_state, mock_keep_learning:
|
||||
with mock_verify_node, mock_record_fleet_state, mock_keep_learning:
|
||||
bob = config.produce(known_nodes=list(fleet_of_highperf_mocked_ursulas)[:1])
|
||||
yield bob
|
||||
bob._learning_task.stop()
|
||||
|
|
|
@ -89,7 +89,6 @@ def mock_requests(mocker):
|
|||
@pytest.fixture(autouse=True)
|
||||
def mock_client(mocker):
|
||||
cert, pk = generate_self_signed_certificate(host=MOCK_IP_ADDRESS)
|
||||
mocker.patch.object(NucypherMiddlewareClient, 'get_certificate', return_value=(cert, Path()))
|
||||
yield mocker.patch.object(NucypherMiddlewareClient, 'invoke_method', return_value=Dummy.GoodResponse)
|
||||
|
||||
|
||||
|
|
|
@ -68,10 +68,6 @@ class _TestMiddlewareClient(NucypherMiddlewareClient):
|
|||
def clean_params(self, request_kwargs):
|
||||
request_kwargs["query_string"] = request_kwargs.pop("params", {})
|
||||
|
||||
def get_certificate(self, port, *args, **kwargs):
|
||||
ursula = self._get_ursula_by_port(port)
|
||||
return ursula.certificate, Path()
|
||||
|
||||
|
||||
class MockRestMiddleware(RestMiddleware):
|
||||
_ursulas = None
|
||||
|
|
Loading…
Reference in New Issue