Add support for family to aiohttp session helper (#102702)

pull/98003/head^2
J. Nick Koston 2023-10-24 18:40:39 -05:00 committed by GitHub
parent a691bd26cf
commit f91583a0fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 106 additions and 49 deletions

View File

@ -7,7 +7,7 @@ from contextlib import suppress
from ssl import SSLContext from ssl import SSLContext
import sys import sys
from types import MappingProxyType from types import MappingProxyType
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any
import aiohttp import aiohttp
from aiohttp import web from aiohttp import web
@ -29,9 +29,8 @@ if TYPE_CHECKING:
DATA_CONNECTOR = "aiohttp_connector" DATA_CONNECTOR = "aiohttp_connector"
DATA_CONNECTOR_NOTVERIFY = "aiohttp_connector_notverify"
DATA_CLIENTSESSION = "aiohttp_clientsession" DATA_CLIENTSESSION = "aiohttp_clientsession"
DATA_CLIENTSESSION_NOTVERIFY = "aiohttp_clientsession_notverify"
SERVER_SOFTWARE = "{0}/{1} aiohttp/{2} Python/{3[0]}.{3[1]}".format( SERVER_SOFTWARE = "{0}/{1} aiohttp/{2} Python/{3[0]}.{3[1]}".format(
APPLICATION_NAME, __version__, aiohttp.__version__, sys.version_info APPLICATION_NAME, __version__, aiohttp.__version__, sys.version_info
) )
@ -88,22 +87,31 @@ class HassClientResponse(aiohttp.ClientResponse):
@callback @callback
@bind_hass @bind_hass
def async_get_clientsession( def async_get_clientsession(
hass: HomeAssistant, verify_ssl: bool = True hass: HomeAssistant, verify_ssl: bool = True, family: int = 0
) -> aiohttp.ClientSession: ) -> aiohttp.ClientSession:
"""Return default aiohttp ClientSession. """Return default aiohttp ClientSession.
This method must be run in the event loop. This method must be run in the event loop.
""" """
key = DATA_CLIENTSESSION if verify_ssl else DATA_CLIENTSESSION_NOTVERIFY session_key = _make_key(verify_ssl, family)
if DATA_CLIENTSESSION not in hass.data:
sessions: dict[tuple[bool, int], aiohttp.ClientSession] = {}
hass.data[DATA_CLIENTSESSION] = sessions
else:
sessions = hass.data[DATA_CLIENTSESSION]
if key not in hass.data: if session_key not in sessions:
hass.data[key] = _async_create_clientsession( session = _async_create_clientsession(
hass, hass,
verify_ssl, verify_ssl,
auto_cleanup_method=_async_register_default_clientsession_shutdown, auto_cleanup_method=_async_register_default_clientsession_shutdown,
family=family,
) )
sessions[session_key] = session
else:
session = sessions[session_key]
return cast(aiohttp.ClientSession, hass.data[key]) return session
@callback @callback
@ -112,6 +120,7 @@ def async_create_clientsession(
hass: HomeAssistant, hass: HomeAssistant,
verify_ssl: bool = True, verify_ssl: bool = True,
auto_cleanup: bool = True, auto_cleanup: bool = True,
family: int = 0,
**kwargs: Any, **kwargs: Any,
) -> aiohttp.ClientSession: ) -> aiohttp.ClientSession:
"""Create a new ClientSession with kwargs, i.e. for cookies. """Create a new ClientSession with kwargs, i.e. for cookies.
@ -131,6 +140,7 @@ def async_create_clientsession(
hass, hass,
verify_ssl, verify_ssl,
auto_cleanup_method=auto_cleanup_method, auto_cleanup_method=auto_cleanup_method,
family=family,
**kwargs, **kwargs,
) )
@ -143,11 +153,12 @@ def _async_create_clientsession(
verify_ssl: bool = True, verify_ssl: bool = True,
auto_cleanup_method: Callable[[HomeAssistant, aiohttp.ClientSession], None] auto_cleanup_method: Callable[[HomeAssistant, aiohttp.ClientSession], None]
| None = None, | None = None,
family: int = 0,
**kwargs: Any, **kwargs: Any,
) -> aiohttp.ClientSession: ) -> aiohttp.ClientSession:
"""Create a new ClientSession with kwargs, i.e. for cookies.""" """Create a new ClientSession with kwargs, i.e. for cookies."""
clientsession = aiohttp.ClientSession( clientsession = aiohttp.ClientSession(
connector=_async_get_connector(hass, verify_ssl), connector=_async_get_connector(hass, verify_ssl, family),
json_serialize=json_dumps, json_serialize=json_dumps,
response_class=HassClientResponse, response_class=HassClientResponse,
**kwargs, **kwargs,
@ -275,18 +286,29 @@ def _async_register_default_clientsession_shutdown(
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, _async_close_websession) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, _async_close_websession)
@callback
def _make_key(verify_ssl: bool = True, family: int = 0) -> tuple[bool, int]:
"""Make a key for connector or session pool."""
return (verify_ssl, family)
@callback @callback
def _async_get_connector( def _async_get_connector(
hass: HomeAssistant, verify_ssl: bool = True hass: HomeAssistant, verify_ssl: bool = True, family: int = 0
) -> aiohttp.BaseConnector: ) -> aiohttp.BaseConnector:
"""Return the connector pool for aiohttp. """Return the connector pool for aiohttp.
This method must be run in the event loop. This method must be run in the event loop.
""" """
key = DATA_CONNECTOR if verify_ssl else DATA_CONNECTOR_NOTVERIFY connector_key = _make_key(verify_ssl, family)
if DATA_CONNECTOR not in hass.data:
connectors: dict[tuple[bool, int], aiohttp.BaseConnector] = {}
hass.data[DATA_CONNECTOR] = connectors
else:
connectors = hass.data[DATA_CONNECTOR]
if key in hass.data: if connector_key in connectors:
return cast(aiohttp.BaseConnector, hass.data[key]) return connectors[connector_key]
if verify_ssl: if verify_ssl:
ssl_context: bool | SSLContext = ssl_util.get_default_context() ssl_context: bool | SSLContext = ssl_util.get_default_context()
@ -294,12 +316,13 @@ def _async_get_connector(
ssl_context = ssl_util.get_default_no_verify_context() ssl_context = ssl_util.get_default_no_verify_context()
connector = aiohttp.TCPConnector( connector = aiohttp.TCPConnector(
family=family,
enable_cleanup_closed=ENABLE_CLEANUP_CLOSED, enable_cleanup_closed=ENABLE_CLEANUP_CLOSED,
ssl=ssl_context, ssl=ssl_context,
limit=MAXIMUM_CONNECTIONS, limit=MAXIMUM_CONNECTIONS,
limit_per_host=MAXIMUM_CONNECTIONS_PER_HOST, limit_per_host=MAXIMUM_CONNECTIONS_PER_HOST,
) )
hass.data[key] = connector connectors[connector_key] = connector
async def _async_close_connector(event: Event) -> None: async def _async_close_connector(event: Event) -> None:
"""Close connector pool.""" """Close connector pool."""

View File

@ -16,7 +16,7 @@ from homeassistant.const import (
Platform, Platform,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.aiohttp_client import DATA_CLIENTSESSION from homeassistant.helpers.aiohttp_client import DATA_CLIENTSESSION, _make_key
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.typing import ClientSessionGenerator from tests.typing import ClientSessionGenerator
@ -483,7 +483,7 @@ async def test_media_image_proxy(
def detach(self): def detach(self):
"""Test websession detach.""" """Test websession detach."""
hass.data[DATA_CLIENTSESSION] = MockWebsession() hass.data[DATA_CLIENTSESSION] = {_make_key(): MockWebsession()}
state = hass.states.get(TEST_ENTITY_ID) state = hass.states.get(TEST_ENTITY_ID)
assert state.state == STATE_PLAYING assert state.state == STATE_PLAYING

View File

@ -52,26 +52,53 @@ def camera_client_fixture(hass, hass_client):
async def test_get_clientsession_with_ssl(hass: HomeAssistant) -> None: async def test_get_clientsession_with_ssl(hass: HomeAssistant) -> None:
"""Test init clientsession with ssl.""" """Test init clientsession with ssl."""
client.async_get_clientsession(hass) client.async_get_clientsession(hass)
verify_ssl = True
family = 0
assert isinstance(hass.data[client.DATA_CLIENTSESSION], aiohttp.ClientSession) client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)]
assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector) assert isinstance(client_session, aiohttp.ClientSession)
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
assert isinstance(connector, aiohttp.TCPConnector)
async def test_get_clientsession_without_ssl(hass: HomeAssistant) -> None: async def test_get_clientsession_without_ssl(hass: HomeAssistant) -> None:
"""Test init clientsession without ssl.""" """Test init clientsession without ssl."""
client.async_get_clientsession(hass, verify_ssl=False) client.async_get_clientsession(hass, verify_ssl=False)
verify_ssl = False
family = 0
assert isinstance( client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)]
hass.data[client.DATA_CLIENTSESSION_NOTVERIFY], aiohttp.ClientSession assert isinstance(client_session, aiohttp.ClientSession)
) connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
assert isinstance(hass.data[client.DATA_CONNECTOR_NOTVERIFY], aiohttp.TCPConnector) assert isinstance(connector, aiohttp.TCPConnector)
@pytest.mark.parametrize(
("verify_ssl", "expected_family"),
[(True, 0), (False, 0), (True, 4), (False, 4), (True, 6), (False, 6)],
)
async def test_get_clientsession(
hass: HomeAssistant, verify_ssl: bool, expected_family: int
) -> None:
"""Test init clientsession combinations."""
client.async_get_clientsession(hass, verify_ssl=verify_ssl, family=expected_family)
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, expected_family)]
assert isinstance(client_session, aiohttp.ClientSession)
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, expected_family)]
assert isinstance(connector, aiohttp.TCPConnector)
async def test_create_clientsession_with_ssl_and_cookies(hass: HomeAssistant) -> None: async def test_create_clientsession_with_ssl_and_cookies(hass: HomeAssistant) -> None:
"""Test create clientsession with ssl.""" """Test create clientsession with ssl."""
session = client.async_create_clientsession(hass, cookies={"bla": True}) session = client.async_create_clientsession(hass, cookies={"bla": True})
assert isinstance(session, aiohttp.ClientSession) assert isinstance(session, aiohttp.ClientSession)
assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector)
verify_ssl = True
family = 0
assert client.DATA_CLIENTSESSION not in hass.data
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
assert isinstance(connector, aiohttp.TCPConnector)
async def test_create_clientsession_without_ssl_and_cookies( async def test_create_clientsession_without_ssl_and_cookies(
@ -80,46 +107,53 @@ async def test_create_clientsession_without_ssl_and_cookies(
"""Test create clientsession without ssl.""" """Test create clientsession without ssl."""
session = client.async_create_clientsession(hass, False, cookies={"bla": True}) session = client.async_create_clientsession(hass, False, cookies={"bla": True})
assert isinstance(session, aiohttp.ClientSession) assert isinstance(session, aiohttp.ClientSession)
assert isinstance(hass.data[client.DATA_CONNECTOR_NOTVERIFY], aiohttp.TCPConnector)
verify_ssl = False
family = 0
assert client.DATA_CLIENTSESSION not in hass.data
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
assert isinstance(connector, aiohttp.TCPConnector)
async def test_get_clientsession_cleanup(hass: HomeAssistant) -> None: @pytest.mark.parametrize(
"""Test init clientsession with ssl.""" ("verify_ssl", "expected_family"),
client.async_get_clientsession(hass) [(True, 0), (False, 0), (True, 4), (False, 4), (True, 6), (False, 6)],
)
async def test_get_clientsession_cleanup(
hass: HomeAssistant, verify_ssl: bool, expected_family: int
) -> None:
"""Test init clientsession cleanup."""
client.async_get_clientsession(hass, verify_ssl=verify_ssl, family=expected_family)
assert isinstance(hass.data[client.DATA_CLIENTSESSION], aiohttp.ClientSession) client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, expected_family)]
assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector) assert isinstance(client_session, aiohttp.ClientSession)
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, expected_family)]
assert isinstance(connector, aiohttp.TCPConnector)
hass.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE) hass.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE)
await hass.async_block_till_done() await hass.async_block_till_done()
assert hass.data[client.DATA_CLIENTSESSION].closed assert client_session.closed
assert hass.data[client.DATA_CONNECTOR].closed assert connector.closed
async def test_get_clientsession_cleanup_without_ssl(hass: HomeAssistant) -> None:
"""Test init clientsession with ssl."""
client.async_get_clientsession(hass, verify_ssl=False)
assert isinstance(
hass.data[client.DATA_CLIENTSESSION_NOTVERIFY], aiohttp.ClientSession
)
assert isinstance(hass.data[client.DATA_CONNECTOR_NOTVERIFY], aiohttp.TCPConnector)
hass.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE)
await hass.async_block_till_done()
assert hass.data[client.DATA_CLIENTSESSION_NOTVERIFY].closed
assert hass.data[client.DATA_CONNECTOR_NOTVERIFY].closed
async def test_get_clientsession_patched_close(hass: HomeAssistant) -> None: async def test_get_clientsession_patched_close(hass: HomeAssistant) -> None:
"""Test closing clientsession does not work.""" """Test closing clientsession does not work."""
verify_ssl = True
family = 0
with patch("aiohttp.ClientSession.close") as mock_close: with patch("aiohttp.ClientSession.close") as mock_close:
session = client.async_get_clientsession(hass) session = client.async_get_clientsession(hass)
assert isinstance(hass.data[client.DATA_CLIENTSESSION], aiohttp.ClientSession) assert isinstance(
assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector) hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)],
aiohttp.ClientSession,
)
assert isinstance(
hass.data[client.DATA_CONNECTOR][(verify_ssl, family)], aiohttp.TCPConnector
)
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await session.close() await session.close()