Don't reuse the same SSL context. Use a new SSL context per connection so that the ca data can be updated accordingly.

remotes/origin/v7.4.x
derekpierre 2024-02-09 06:45:07 -05:00 committed by Derek Pierre
parent c4019d165e
commit 755cbd1511
2 changed files with 39 additions and 26 deletions

View File

@ -68,27 +68,44 @@ class CertificateCache:
)
class SelfSignedPoolManager(PoolManager):
def __init__(self, certificate_cache: CertificateCache, *args, **kwargs):
self.certificate_cache = certificate_cache
super().__init__(*args, **kwargs)
def connection_from_url(self, url, pool_kwargs=None):
if not pool_kwargs:
pool_kwargs = {}
ssl_context = pool_kwargs.get("ssl_context")
if not ssl_context:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.verify_mode = ssl.CERT_REQUIRED
ssl_context.check_hostname = False
pool_kwargs["ssl_context"] = ssl_context
parsed = urlparse(url)
host, port = parsed.hostname, parsed.port
cached_certificate = self.certificate_cache.get(Address(host, port))
if cached_certificate:
ssl_context.load_verify_locations(cadata=cached_certificate)
return super().connection_from_url(url, pool_kwargs=pool_kwargs)
class SelfSignedCertificateAdapter(HTTPAdapter):
"""An adapter that verifies self-signed certificates in memory only."""
log = logging.Logger(__name__)
def __init__(self, *args, **kwargs):
self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
self.ssl_context.check_hostname = False
def __init__(self, certificate_cache: CertificateCache, *args, **kwargs):
self.certificate_cache = certificate_cache
super().__init__(*args, **kwargs)
def init_poolmanager(self, *args, **kwargs) -> None:
"""Override the default poolmanager to use the local SSL context."""
self.poolmanager = PoolManager(*args, ssl_context=self.ssl_context, **kwargs)
def trust_certificate(self, certificate: Certificate) -> None:
"""Accept the given certificate as trusted."""
try:
self.ssl_context.load_verify_locations(cadata=certificate)
except ssl.SSLError as e:
self.log.debug(f"Failed to load certificate {e}.")
"""Override the default poolmanager to use the certificate cache."""
self.poolmanager = SelfSignedPoolManager(
self.certificate_cache, *args, **kwargs
)
class P2PSession(Session):
@ -97,8 +114,8 @@ class P2PSession(Session):
def __init__(self):
super().__init__()
self.adapter = SelfSignedCertificateAdapter()
self.cache = CertificateCache()
self.certificate_cache = CertificateCache()
self.adapter = SelfSignedCertificateAdapter(self.certificate_cache)
self.mount("https://", self.adapter)
@classmethod
@ -110,8 +127,7 @@ class P2PSession(Session):
return Address(hostname, parsed.port or cls._DEFAULT_PORT)
def __retry_send(self, address, request, *args, **kwargs) -> Response:
certificate = self._refresh_certificate(address)
self.adapter.trust_certificate(certificate=certificate)
self._refresh_certificate(address)
try:
return super().send(request, *args, **kwargs)
except RequestException as e:
@ -129,8 +145,7 @@ class P2PSession(Session):
"""
address = self._resolve_address(url=request.url) # resolves dns
certificate = self.__get_or_refresh_certificate(address) # cache by resolved ip
self.adapter.trust_certificate(certificate=certificate)
self.__ensure_certificate_cached_and_uptodate(address) # cache by resolved ip
url = _replace_with_resolved_address(url=request.url, resolved_address=address)
request.url = url # replace the hostname with the resolved IP address
try:
@ -139,13 +154,13 @@ class P2PSession(Session):
self.adapter.log.debug(f"Request failed due to {e}, retrying...")
return self.__retry_send(address, request, *args, **kwargs)
def __get_or_refresh_certificate(self, address: Address) -> Certificate:
if self.cache.should_cache_now(address):
def __ensure_certificate_cached_and_uptodate(self, address: Address) -> Certificate:
if self.certificate_cache.should_cache_now(address):
return self._refresh_certificate(address)
certificate = self.cache.get(address)
certificate = self.certificate_cache.get(address)
return certificate
def _refresh_certificate(self, address: Address) -> Certificate:
certificate = _fetch_server_cert(address)
self.cache.set(address, certificate)
self.certificate_cache.set(address, certificate)
return certificate

View File

@ -41,8 +41,7 @@ def cache():
@pytest.fixture
def adapter(cache):
_adapter = SelfSignedCertificateAdapter()
_adapter.cert_cache = cache
_adapter = SelfSignedCertificateAdapter(certificate_cache=cache)
return _adapter
@ -81,7 +80,6 @@ def test_cache_cert(cache):
def test_send_request(session, mocker):
mocker.patch.object(SelfSignedCertificateAdapter, "trust_certificate")
mocked_refresh = mocker.patch.object(
session, "_refresh_certificate", return_value=MOCK_CERT
)