mirror of https://github.com/nucypher/nucypher.git
Don't reuse the same SSL context. Use a new SSL context per connection so that the ca data can be updated accordingly.
parent
c4019d165e
commit
755cbd1511
|
@ -68,27 +68,44 @@ class CertificateCache:
|
|||
)
|
||||
|
||||
|
||||
class SelfSignedPoolManager(PoolManager):
|
||||
def __init__(self, certificate_cache: CertificateCache, *args, **kwargs):
|
||||
self.certificate_cache = certificate_cache
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def connection_from_url(self, url, pool_kwargs=None):
|
||||
if not pool_kwargs:
|
||||
pool_kwargs = {}
|
||||
ssl_context = pool_kwargs.get("ssl_context")
|
||||
if not ssl_context:
|
||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ssl_context.verify_mode = ssl.CERT_REQUIRED
|
||||
ssl_context.check_hostname = False
|
||||
pool_kwargs["ssl_context"] = ssl_context
|
||||
|
||||
parsed = urlparse(url)
|
||||
host, port = parsed.hostname, parsed.port
|
||||
cached_certificate = self.certificate_cache.get(Address(host, port))
|
||||
if cached_certificate:
|
||||
ssl_context.load_verify_locations(cadata=cached_certificate)
|
||||
|
||||
return super().connection_from_url(url, pool_kwargs=pool_kwargs)
|
||||
|
||||
|
||||
class SelfSignedCertificateAdapter(HTTPAdapter):
|
||||
"""An adapter that verifies self-signed certificates in memory only."""
|
||||
|
||||
log = logging.Logger(__name__)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
|
||||
self.ssl_context.check_hostname = False
|
||||
def __init__(self, certificate_cache: CertificateCache, *args, **kwargs):
|
||||
self.certificate_cache = certificate_cache
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def init_poolmanager(self, *args, **kwargs) -> None:
|
||||
"""Override the default poolmanager to use the local SSL context."""
|
||||
self.poolmanager = PoolManager(*args, ssl_context=self.ssl_context, **kwargs)
|
||||
|
||||
def trust_certificate(self, certificate: Certificate) -> None:
|
||||
"""Accept the given certificate as trusted."""
|
||||
try:
|
||||
self.ssl_context.load_verify_locations(cadata=certificate)
|
||||
except ssl.SSLError as e:
|
||||
self.log.debug(f"Failed to load certificate {e}.")
|
||||
"""Override the default poolmanager to use the certificate cache."""
|
||||
self.poolmanager = SelfSignedPoolManager(
|
||||
self.certificate_cache, *args, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class P2PSession(Session):
|
||||
|
@ -97,8 +114,8 @@ class P2PSession(Session):
|
|||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.adapter = SelfSignedCertificateAdapter()
|
||||
self.cache = CertificateCache()
|
||||
self.certificate_cache = CertificateCache()
|
||||
self.adapter = SelfSignedCertificateAdapter(self.certificate_cache)
|
||||
self.mount("https://", self.adapter)
|
||||
|
||||
@classmethod
|
||||
|
@ -110,8 +127,7 @@ class P2PSession(Session):
|
|||
return Address(hostname, parsed.port or cls._DEFAULT_PORT)
|
||||
|
||||
def __retry_send(self, address, request, *args, **kwargs) -> Response:
|
||||
certificate = self._refresh_certificate(address)
|
||||
self.adapter.trust_certificate(certificate=certificate)
|
||||
self._refresh_certificate(address)
|
||||
try:
|
||||
return super().send(request, *args, **kwargs)
|
||||
except RequestException as e:
|
||||
|
@ -129,8 +145,7 @@ class P2PSession(Session):
|
|||
"""
|
||||
|
||||
address = self._resolve_address(url=request.url) # resolves dns
|
||||
certificate = self.__get_or_refresh_certificate(address) # cache by resolved ip
|
||||
self.adapter.trust_certificate(certificate=certificate)
|
||||
self.__ensure_certificate_cached_and_uptodate(address) # cache by resolved ip
|
||||
url = _replace_with_resolved_address(url=request.url, resolved_address=address)
|
||||
request.url = url # replace the hostname with the resolved IP address
|
||||
try:
|
||||
|
@ -139,13 +154,13 @@ class P2PSession(Session):
|
|||
self.adapter.log.debug(f"Request failed due to {e}, retrying...")
|
||||
return self.__retry_send(address, request, *args, **kwargs)
|
||||
|
||||
def __get_or_refresh_certificate(self, address: Address) -> Certificate:
|
||||
if self.cache.should_cache_now(address):
|
||||
def __ensure_certificate_cached_and_uptodate(self, address: Address) -> Certificate:
|
||||
if self.certificate_cache.should_cache_now(address):
|
||||
return self._refresh_certificate(address)
|
||||
certificate = self.cache.get(address)
|
||||
certificate = self.certificate_cache.get(address)
|
||||
return certificate
|
||||
|
||||
def _refresh_certificate(self, address: Address) -> Certificate:
|
||||
certificate = _fetch_server_cert(address)
|
||||
self.cache.set(address, certificate)
|
||||
self.certificate_cache.set(address, certificate)
|
||||
return certificate
|
||||
|
|
|
@ -41,8 +41,7 @@ def cache():
|
|||
|
||||
@pytest.fixture
|
||||
def adapter(cache):
|
||||
_adapter = SelfSignedCertificateAdapter()
|
||||
_adapter.cert_cache = cache
|
||||
_adapter = SelfSignedCertificateAdapter(certificate_cache=cache)
|
||||
return _adapter
|
||||
|
||||
|
||||
|
@ -81,7 +80,6 @@ def test_cache_cert(cache):
|
|||
|
||||
|
||||
def test_send_request(session, mocker):
|
||||
mocker.patch.object(SelfSignedCertificateAdapter, "trust_certificate")
|
||||
mocked_refresh = mocker.patch.object(
|
||||
session, "_refresh_certificate", return_value=MOCK_CERT
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue