Add support for family to aiohttp session helper (#102702)
parent
a691bd26cf
commit
f91583a0fc
|
@ -7,7 +7,7 @@ from contextlib import suppress
|
||||||
from ssl import SSLContext
|
from ssl import SSLContext
|
||||||
import sys
|
import sys
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
@ -29,9 +29,8 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
|
|
||||||
DATA_CONNECTOR = "aiohttp_connector"
|
DATA_CONNECTOR = "aiohttp_connector"
|
||||||
DATA_CONNECTOR_NOTVERIFY = "aiohttp_connector_notverify"
|
|
||||||
DATA_CLIENTSESSION = "aiohttp_clientsession"
|
DATA_CLIENTSESSION = "aiohttp_clientsession"
|
||||||
DATA_CLIENTSESSION_NOTVERIFY = "aiohttp_clientsession_notverify"
|
|
||||||
SERVER_SOFTWARE = "{0}/{1} aiohttp/{2} Python/{3[0]}.{3[1]}".format(
|
SERVER_SOFTWARE = "{0}/{1} aiohttp/{2} Python/{3[0]}.{3[1]}".format(
|
||||||
APPLICATION_NAME, __version__, aiohttp.__version__, sys.version_info
|
APPLICATION_NAME, __version__, aiohttp.__version__, sys.version_info
|
||||||
)
|
)
|
||||||
|
@ -88,22 +87,31 @@ class HassClientResponse(aiohttp.ClientResponse):
|
||||||
@callback
|
@callback
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def async_get_clientsession(
|
def async_get_clientsession(
|
||||||
hass: HomeAssistant, verify_ssl: bool = True
|
hass: HomeAssistant, verify_ssl: bool = True, family: int = 0
|
||||||
) -> aiohttp.ClientSession:
|
) -> aiohttp.ClientSession:
|
||||||
"""Return default aiohttp ClientSession.
|
"""Return default aiohttp ClientSession.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
"""
|
"""
|
||||||
key = DATA_CLIENTSESSION if verify_ssl else DATA_CLIENTSESSION_NOTVERIFY
|
session_key = _make_key(verify_ssl, family)
|
||||||
|
if DATA_CLIENTSESSION not in hass.data:
|
||||||
|
sessions: dict[tuple[bool, int], aiohttp.ClientSession] = {}
|
||||||
|
hass.data[DATA_CLIENTSESSION] = sessions
|
||||||
|
else:
|
||||||
|
sessions = hass.data[DATA_CLIENTSESSION]
|
||||||
|
|
||||||
if key not in hass.data:
|
if session_key not in sessions:
|
||||||
hass.data[key] = _async_create_clientsession(
|
session = _async_create_clientsession(
|
||||||
hass,
|
hass,
|
||||||
verify_ssl,
|
verify_ssl,
|
||||||
auto_cleanup_method=_async_register_default_clientsession_shutdown,
|
auto_cleanup_method=_async_register_default_clientsession_shutdown,
|
||||||
|
family=family,
|
||||||
)
|
)
|
||||||
|
sessions[session_key] = session
|
||||||
|
else:
|
||||||
|
session = sessions[session_key]
|
||||||
|
|
||||||
return cast(aiohttp.ClientSession, hass.data[key])
|
return session
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -112,6 +120,7 @@ def async_create_clientsession(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
verify_ssl: bool = True,
|
verify_ssl: bool = True,
|
||||||
auto_cleanup: bool = True,
|
auto_cleanup: bool = True,
|
||||||
|
family: int = 0,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> aiohttp.ClientSession:
|
) -> aiohttp.ClientSession:
|
||||||
"""Create a new ClientSession with kwargs, i.e. for cookies.
|
"""Create a new ClientSession with kwargs, i.e. for cookies.
|
||||||
|
@ -131,6 +140,7 @@ def async_create_clientsession(
|
||||||
hass,
|
hass,
|
||||||
verify_ssl,
|
verify_ssl,
|
||||||
auto_cleanup_method=auto_cleanup_method,
|
auto_cleanup_method=auto_cleanup_method,
|
||||||
|
family=family,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -143,11 +153,12 @@ def _async_create_clientsession(
|
||||||
verify_ssl: bool = True,
|
verify_ssl: bool = True,
|
||||||
auto_cleanup_method: Callable[[HomeAssistant, aiohttp.ClientSession], None]
|
auto_cleanup_method: Callable[[HomeAssistant, aiohttp.ClientSession], None]
|
||||||
| None = None,
|
| None = None,
|
||||||
|
family: int = 0,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> aiohttp.ClientSession:
|
) -> aiohttp.ClientSession:
|
||||||
"""Create a new ClientSession with kwargs, i.e. for cookies."""
|
"""Create a new ClientSession with kwargs, i.e. for cookies."""
|
||||||
clientsession = aiohttp.ClientSession(
|
clientsession = aiohttp.ClientSession(
|
||||||
connector=_async_get_connector(hass, verify_ssl),
|
connector=_async_get_connector(hass, verify_ssl, family),
|
||||||
json_serialize=json_dumps,
|
json_serialize=json_dumps,
|
||||||
response_class=HassClientResponse,
|
response_class=HassClientResponse,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
@ -275,18 +286,29 @@ def _async_register_default_clientsession_shutdown(
|
||||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, _async_close_websession)
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, _async_close_websession)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _make_key(verify_ssl: bool = True, family: int = 0) -> tuple[bool, int]:
|
||||||
|
"""Make a key for connector or session pool."""
|
||||||
|
return (verify_ssl, family)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_get_connector(
|
def _async_get_connector(
|
||||||
hass: HomeAssistant, verify_ssl: bool = True
|
hass: HomeAssistant, verify_ssl: bool = True, family: int = 0
|
||||||
) -> aiohttp.BaseConnector:
|
) -> aiohttp.BaseConnector:
|
||||||
"""Return the connector pool for aiohttp.
|
"""Return the connector pool for aiohttp.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
"""
|
"""
|
||||||
key = DATA_CONNECTOR if verify_ssl else DATA_CONNECTOR_NOTVERIFY
|
connector_key = _make_key(verify_ssl, family)
|
||||||
|
if DATA_CONNECTOR not in hass.data:
|
||||||
|
connectors: dict[tuple[bool, int], aiohttp.BaseConnector] = {}
|
||||||
|
hass.data[DATA_CONNECTOR] = connectors
|
||||||
|
else:
|
||||||
|
connectors = hass.data[DATA_CONNECTOR]
|
||||||
|
|
||||||
if key in hass.data:
|
if connector_key in connectors:
|
||||||
return cast(aiohttp.BaseConnector, hass.data[key])
|
return connectors[connector_key]
|
||||||
|
|
||||||
if verify_ssl:
|
if verify_ssl:
|
||||||
ssl_context: bool | SSLContext = ssl_util.get_default_context()
|
ssl_context: bool | SSLContext = ssl_util.get_default_context()
|
||||||
|
@ -294,12 +316,13 @@ def _async_get_connector(
|
||||||
ssl_context = ssl_util.get_default_no_verify_context()
|
ssl_context = ssl_util.get_default_no_verify_context()
|
||||||
|
|
||||||
connector = aiohttp.TCPConnector(
|
connector = aiohttp.TCPConnector(
|
||||||
|
family=family,
|
||||||
enable_cleanup_closed=ENABLE_CLEANUP_CLOSED,
|
enable_cleanup_closed=ENABLE_CLEANUP_CLOSED,
|
||||||
ssl=ssl_context,
|
ssl=ssl_context,
|
||||||
limit=MAXIMUM_CONNECTIONS,
|
limit=MAXIMUM_CONNECTIONS,
|
||||||
limit_per_host=MAXIMUM_CONNECTIONS_PER_HOST,
|
limit_per_host=MAXIMUM_CONNECTIONS_PER_HOST,
|
||||||
)
|
)
|
||||||
hass.data[key] = connector
|
connectors[connector_key] = connector
|
||||||
|
|
||||||
async def _async_close_connector(event: Event) -> None:
|
async def _async_close_connector(event: Event) -> None:
|
||||||
"""Close connector pool."""
|
"""Close connector pool."""
|
||||||
|
|
|
@ -16,7 +16,7 @@ from homeassistant.const import (
|
||||||
Platform,
|
Platform,
|
||||||
)
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers.aiohttp_client import DATA_CLIENTSESSION
|
from homeassistant.helpers.aiohttp_client import DATA_CLIENTSESSION, _make_key
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.typing import ClientSessionGenerator
|
from tests.typing import ClientSessionGenerator
|
||||||
|
@ -483,7 +483,7 @@ async def test_media_image_proxy(
|
||||||
def detach(self):
|
def detach(self):
|
||||||
"""Test websession detach."""
|
"""Test websession detach."""
|
||||||
|
|
||||||
hass.data[DATA_CLIENTSESSION] = MockWebsession()
|
hass.data[DATA_CLIENTSESSION] = {_make_key(): MockWebsession()}
|
||||||
|
|
||||||
state = hass.states.get(TEST_ENTITY_ID)
|
state = hass.states.get(TEST_ENTITY_ID)
|
||||||
assert state.state == STATE_PLAYING
|
assert state.state == STATE_PLAYING
|
||||||
|
|
|
@ -52,26 +52,53 @@ def camera_client_fixture(hass, hass_client):
|
||||||
async def test_get_clientsession_with_ssl(hass: HomeAssistant) -> None:
|
async def test_get_clientsession_with_ssl(hass: HomeAssistant) -> None:
|
||||||
"""Test init clientsession with ssl."""
|
"""Test init clientsession with ssl."""
|
||||||
client.async_get_clientsession(hass)
|
client.async_get_clientsession(hass)
|
||||||
|
verify_ssl = True
|
||||||
|
family = 0
|
||||||
|
|
||||||
assert isinstance(hass.data[client.DATA_CLIENTSESSION], aiohttp.ClientSession)
|
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)]
|
||||||
assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector)
|
assert isinstance(client_session, aiohttp.ClientSession)
|
||||||
|
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
|
||||||
|
assert isinstance(connector, aiohttp.TCPConnector)
|
||||||
|
|
||||||
|
|
||||||
async def test_get_clientsession_without_ssl(hass: HomeAssistant) -> None:
|
async def test_get_clientsession_without_ssl(hass: HomeAssistant) -> None:
|
||||||
"""Test init clientsession without ssl."""
|
"""Test init clientsession without ssl."""
|
||||||
client.async_get_clientsession(hass, verify_ssl=False)
|
client.async_get_clientsession(hass, verify_ssl=False)
|
||||||
|
verify_ssl = False
|
||||||
|
family = 0
|
||||||
|
|
||||||
assert isinstance(
|
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)]
|
||||||
hass.data[client.DATA_CLIENTSESSION_NOTVERIFY], aiohttp.ClientSession
|
assert isinstance(client_session, aiohttp.ClientSession)
|
||||||
)
|
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
|
||||||
assert isinstance(hass.data[client.DATA_CONNECTOR_NOTVERIFY], aiohttp.TCPConnector)
|
assert isinstance(connector, aiohttp.TCPConnector)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("verify_ssl", "expected_family"),
|
||||||
|
[(True, 0), (False, 0), (True, 4), (False, 4), (True, 6), (False, 6)],
|
||||||
|
)
|
||||||
|
async def test_get_clientsession(
|
||||||
|
hass: HomeAssistant, verify_ssl: bool, expected_family: int
|
||||||
|
) -> None:
|
||||||
|
"""Test init clientsession combinations."""
|
||||||
|
client.async_get_clientsession(hass, verify_ssl=verify_ssl, family=expected_family)
|
||||||
|
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, expected_family)]
|
||||||
|
assert isinstance(client_session, aiohttp.ClientSession)
|
||||||
|
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, expected_family)]
|
||||||
|
assert isinstance(connector, aiohttp.TCPConnector)
|
||||||
|
|
||||||
|
|
||||||
async def test_create_clientsession_with_ssl_and_cookies(hass: HomeAssistant) -> None:
|
async def test_create_clientsession_with_ssl_and_cookies(hass: HomeAssistant) -> None:
|
||||||
"""Test create clientsession with ssl."""
|
"""Test create clientsession with ssl."""
|
||||||
session = client.async_create_clientsession(hass, cookies={"bla": True})
|
session = client.async_create_clientsession(hass, cookies={"bla": True})
|
||||||
assert isinstance(session, aiohttp.ClientSession)
|
assert isinstance(session, aiohttp.ClientSession)
|
||||||
assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector)
|
|
||||||
|
verify_ssl = True
|
||||||
|
family = 0
|
||||||
|
|
||||||
|
assert client.DATA_CLIENTSESSION not in hass.data
|
||||||
|
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
|
||||||
|
assert isinstance(connector, aiohttp.TCPConnector)
|
||||||
|
|
||||||
|
|
||||||
async def test_create_clientsession_without_ssl_and_cookies(
|
async def test_create_clientsession_without_ssl_and_cookies(
|
||||||
|
@ -80,46 +107,53 @@ async def test_create_clientsession_without_ssl_and_cookies(
|
||||||
"""Test create clientsession without ssl."""
|
"""Test create clientsession without ssl."""
|
||||||
session = client.async_create_clientsession(hass, False, cookies={"bla": True})
|
session = client.async_create_clientsession(hass, False, cookies={"bla": True})
|
||||||
assert isinstance(session, aiohttp.ClientSession)
|
assert isinstance(session, aiohttp.ClientSession)
|
||||||
assert isinstance(hass.data[client.DATA_CONNECTOR_NOTVERIFY], aiohttp.TCPConnector)
|
|
||||||
|
verify_ssl = False
|
||||||
|
family = 0
|
||||||
|
|
||||||
|
assert client.DATA_CLIENTSESSION not in hass.data
|
||||||
|
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
|
||||||
|
assert isinstance(connector, aiohttp.TCPConnector)
|
||||||
|
|
||||||
|
|
||||||
async def test_get_clientsession_cleanup(hass: HomeAssistant) -> None:
|
@pytest.mark.parametrize(
|
||||||
"""Test init clientsession with ssl."""
|
("verify_ssl", "expected_family"),
|
||||||
client.async_get_clientsession(hass)
|
[(True, 0), (False, 0), (True, 4), (False, 4), (True, 6), (False, 6)],
|
||||||
|
)
|
||||||
|
async def test_get_clientsession_cleanup(
|
||||||
|
hass: HomeAssistant, verify_ssl: bool, expected_family: int
|
||||||
|
) -> None:
|
||||||
|
"""Test init clientsession cleanup."""
|
||||||
|
client.async_get_clientsession(hass, verify_ssl=verify_ssl, family=expected_family)
|
||||||
|
|
||||||
assert isinstance(hass.data[client.DATA_CLIENTSESSION], aiohttp.ClientSession)
|
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, expected_family)]
|
||||||
assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector)
|
assert isinstance(client_session, aiohttp.ClientSession)
|
||||||
|
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, expected_family)]
|
||||||
|
assert isinstance(connector, aiohttp.TCPConnector)
|
||||||
|
|
||||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE)
|
hass.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert hass.data[client.DATA_CLIENTSESSION].closed
|
assert client_session.closed
|
||||||
assert hass.data[client.DATA_CONNECTOR].closed
|
assert connector.closed
|
||||||
|
|
||||||
|
|
||||||
async def test_get_clientsession_cleanup_without_ssl(hass: HomeAssistant) -> None:
|
|
||||||
"""Test init clientsession with ssl."""
|
|
||||||
client.async_get_clientsession(hass, verify_ssl=False)
|
|
||||||
|
|
||||||
assert isinstance(
|
|
||||||
hass.data[client.DATA_CLIENTSESSION_NOTVERIFY], aiohttp.ClientSession
|
|
||||||
)
|
|
||||||
assert isinstance(hass.data[client.DATA_CONNECTOR_NOTVERIFY], aiohttp.TCPConnector)
|
|
||||||
|
|
||||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE)
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
assert hass.data[client.DATA_CLIENTSESSION_NOTVERIFY].closed
|
|
||||||
assert hass.data[client.DATA_CONNECTOR_NOTVERIFY].closed
|
|
||||||
|
|
||||||
|
|
||||||
async def test_get_clientsession_patched_close(hass: HomeAssistant) -> None:
|
async def test_get_clientsession_patched_close(hass: HomeAssistant) -> None:
|
||||||
"""Test closing clientsession does not work."""
|
"""Test closing clientsession does not work."""
|
||||||
|
|
||||||
|
verify_ssl = True
|
||||||
|
family = 0
|
||||||
|
|
||||||
with patch("aiohttp.ClientSession.close") as mock_close:
|
with patch("aiohttp.ClientSession.close") as mock_close:
|
||||||
session = client.async_get_clientsession(hass)
|
session = client.async_get_clientsession(hass)
|
||||||
|
|
||||||
assert isinstance(hass.data[client.DATA_CLIENTSESSION], aiohttp.ClientSession)
|
assert isinstance(
|
||||||
assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector)
|
hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)],
|
||||||
|
aiohttp.ClientSession,
|
||||||
|
)
|
||||||
|
assert isinstance(
|
||||||
|
hass.data[client.DATA_CONNECTOR][(verify_ssl, family)], aiohttp.TCPConnector
|
||||||
|
)
|
||||||
|
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|
Loading…
Reference in New Issue