Avoid creating tasks for checking integrations platforms (#110795)

* Avoid creating tasks for checking integrations platforms

This is a followup to #110743 to avoid creating a task to check
if the integration platform exists. We created tasks because
we needed to await async_get_integrations but since its always
called from EVENT_COMPONENT_LOADED firing, we can use the
async_get_loaded_integration version which does not need
to be awaited. This eliminates one task for every loaded
component

* there is no more race risk

* reduce

* coro or callback

* reduce

* tweak

* race safe

* fix type

* fixes

* use built-in helper to make it smaller

* use built-in helper to make it smaller

* use built-in helper to make it smaller

* add coverage to ensure exceptions are logged

* improve readability a bit

* platforms
pull/110825/head
J. Nick Koston 2024-02-17 18:07:18 -06:00 committed by GitHub
parent a656e14b20
commit 33ff6b5b6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 152 additions and 91 deletions

View File

@ -4,14 +4,21 @@ from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from functools import partial
import logging
from types import ModuleType
from typing import Any
from homeassistant.const import EVENT_COMPONENT_LOADED
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.loader import Integration, async_get_integrations, bind_hass
from homeassistant.core import Event, HassJob, HomeAssistant, callback
from homeassistant.loader import (
Integration,
async_get_integrations,
async_get_loaded_integration,
bind_hass,
)
from homeassistant.setup import ATTR_COMPONENT
from homeassistant.util.logging import catch_log_exception
_LOGGER = logging.getLogger(__name__)
DATA_INTEGRATION_PLATFORMS = "integration_platforms"
@ -22,31 +29,12 @@ class IntegrationPlatform:
"""An integration platform."""
platform_name: str
process_platform: Callable[[HomeAssistant, str, Any], Awaitable[None]]
process_job: HassJob
seen_components: set[str]
async def _async_process_single_integration_platform_component(
hass: HomeAssistant,
component_name: str,
platform: ModuleType,
integration_platform: IntegrationPlatform,
) -> None:
"""Process a single integration platform."""
if component_name in integration_platform.seen_components:
return
integration_platform.seen_components.add(component_name)
try:
await integration_platform.process_platform(hass, component_name, platform)
except Exception: # pylint: disable=broad-except
_LOGGER.exception(
"Error processing platform %s.%s",
component_name,
integration_platform.platform_name,
)
def _get_platform_from_integration(
@callback
def _get_platform(
integration: Integration | Exception, component_name: str, platform_name: str
) -> ModuleType | None:
"""Get a platform from an integration."""
@ -71,36 +59,32 @@ def _get_platform_from_integration(
return None
async def _async_process_integration_platform_for_component(
hass: HomeAssistant, component_name: str
@callback
def _async_process_integration_platforms_for_component(
hass: HomeAssistant, integration_platforms: list[IntegrationPlatform], event: Event
) -> None:
"""Process integration platforms for a component."""
integration_platforms: list[IntegrationPlatform] = hass.data[
DATA_INTEGRATION_PLATFORMS
]
integrations = await async_get_integrations(hass, (component_name,))
tasks = [
asyncio.create_task(
_async_process_single_integration_platform_component(
hass,
component_name,
platform,
integration_platform,
),
name=f"process integration platform {integration_platform.platform_name} for {component_name}",
)
for integration_platform in integration_platforms
if component_name not in integration_platform.seen_components
and (
platform := _get_platform_from_integration(
integrations[component_name],
component_name,
integration_platform.platform_name,
component_name: str = event.data[ATTR_COMPONENT]
if "." in component_name:
return
integration = async_get_loaded_integration(hass, component_name)
for integration_platform in integration_platforms:
if component_name in integration_platform.seen_components or not (
platform := _get_platform(
integration, component_name, integration_platform.platform_name
)
):
continue
integration_platform.seen_components.add(component_name)
hass.async_run_hass_job(
integration_platform.process_job, hass, component_name, platform
)
]
if tasks:
await asyncio.gather(*tasks)
def _format_err(name: str, platform_name: str, *args: Any) -> str:
"""Format error message."""
return f"Exception in {name} when processing platform '{platform_name}': {args}"
@bind_hass
@ -108,52 +92,44 @@ async def async_process_integration_platforms(
hass: HomeAssistant,
platform_name: str,
# Any = platform.
process_platform: Callable[[HomeAssistant, str, Any], Awaitable[None]],
process_platform: Callable[[HomeAssistant, str, Any], Awaitable[None] | None],
) -> None:
"""Process a specific platform for all current and future loaded integrations."""
if DATA_INTEGRATION_PLATFORMS not in hass.data:
hass.data[DATA_INTEGRATION_PLATFORMS] = []
async def _async_component_loaded(event: Event) -> None:
"""Handle a new component loaded."""
await _async_process_integration_platform_for_component(
hass, event.data[ATTR_COMPONENT]
)
@callback
def _async_component_loaded_filter(event: Event) -> bool:
"""Handle integration platforms loaded."""
return "." not in event.data[ATTR_COMPONENT]
integration_platforms: list[IntegrationPlatform] = []
hass.data[DATA_INTEGRATION_PLATFORMS] = integration_platforms
hass.bus.async_listen(
EVENT_COMPONENT_LOADED,
_async_component_loaded,
event_filter=_async_component_loaded_filter,
partial(
_async_process_integration_platforms_for_component,
hass,
integration_platforms,
),
)
else:
integration_platforms = hass.data[DATA_INTEGRATION_PLATFORMS]
integration_platforms: list[IntegrationPlatform] = hass.data[
DATA_INTEGRATION_PLATFORMS
]
integration_platform = IntegrationPlatform(platform_name, process_platform, set())
top_level_components = {comp for comp in hass.config.components if "." not in comp}
process_job = HassJob(
catch_log_exception(
process_platform,
partial(_format_err, str(process_platform), platform_name),
),
f"process_platform {platform_name}",
)
integration_platform = IntegrationPlatform(
platform_name, process_job, top_level_components
)
integration_platforms.append(integration_platform)
if top_level_components := [
comp for comp in hass.config.components if "." not in comp
if not top_level_components:
return
integrations = await async_get_integrations(hass, top_level_components)
if futures := [
future
for comp in top_level_components
if (platform := _get_platform(integrations[comp], comp, platform_name))
and (future := hass.async_run_hass_job(process_job, hass, comp, platform))
]:
integrations = await async_get_integrations(hass, top_level_components)
tasks = [
asyncio.create_task(
_async_process_single_integration_platform_component(
hass, comp, platform, integration_platform
),
name=f"process integration platform {platform_name} for {comp}",
)
for comp in top_level_components
if comp not in integration_platform.seen_components
and (
platform := _get_platform_from_integration(
integrations[comp], comp, platform_name
)
)
]
if tasks:
await asyncio.gather(*tasks)
await asyncio.gather(*futures)

View File

@ -2015,6 +2015,7 @@ async def test_entry_setup_no_config(hass: HomeAssistant) -> None:
assert not hass.config_entries.async_entries("cast")
@pytest.mark.no_fail_on_log_exception
async def test_invalid_cast_platform(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:

View File

@ -463,6 +463,7 @@ async def test_delete_issue(
}
@pytest.mark.no_fail_on_log_exception
async def test_non_compliant_platform(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
) -> None:

View File

@ -1,9 +1,12 @@
"""Test integration platform helpers."""
from collections.abc import Callable
from types import ModuleType
from unittest.mock import Mock
import pytest
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.integration_platform import (
async_process_integration_platforms,
)
@ -42,6 +45,66 @@ async def test_process_integration_platforms(hass: HomeAssistant) -> None:
assert processed[1][0] == "event"
assert processed[1][1] == event_platform
hass.bus.async_fire(EVENT_COMPONENT_LOADED, {ATTR_COMPONENT: "event"})
await hass.async_block_till_done()
# Firing again should not check again
assert len(processed) == 2
@callback
def _process_platform_callback(
hass: HomeAssistant, domain: str, platform: ModuleType
) -> None:
"""Process platform."""
raise HomeAssistantError("Non-compliant platform")
async def _process_platform_coro(
hass: HomeAssistant, domain: str, platform: ModuleType
) -> None:
"""Process platform."""
raise HomeAssistantError("Non-compliant platform")
@pytest.mark.no_fail_on_log_exception
@pytest.mark.parametrize(
"process_platform", (_process_platform_callback, _process_platform_coro)
)
async def test_process_integration_platforms_non_compliant(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, process_platform: Callable
) -> None:
"""Test processing integrations using with a non-compliant platform."""
loaded_platform = Mock()
mock_platform(hass, "loaded_unique_880.platform_to_check", loaded_platform)
hass.config.components.add("loaded_unique_880")
event_platform = Mock()
mock_platform(hass, "event_unique_990.platform_to_check", event_platform)
processed = []
await async_process_integration_platforms(
hass, "platform_to_check", process_platform
)
assert len(processed) == 0
assert "Exception in " in caplog.text
assert "platform_to_check" in caplog.text
assert "Non-compliant platform" in caplog.text
assert "loaded_unique_880" in caplog.text
caplog.clear()
hass.bus.async_fire(EVENT_COMPONENT_LOADED, {ATTR_COMPONENT: "event_unique_990"})
await hass.async_block_till_done()
assert "Exception in " in caplog.text
assert "platform_to_check" in caplog.text
assert "Non-compliant platform" in caplog.text
assert "event_unique_990" in caplog.text
assert len(processed) == 0
async def test_broken_integration(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
@ -65,3 +128,23 @@ async def test_broken_integration(
assert len(processed) == 0
assert "Error importing integration loaded for platform_to_check" in caplog.text
async def test_process_integration_platforms_no_integrations(
hass: HomeAssistant,
) -> None:
"""Test processing integrations when no integrations are loaded."""
event_platform = Mock()
mock_platform(hass, "event.platform_to_check", event_platform)
processed = []
async def _process_platform(hass, domain, platform):
"""Process platform."""
processed.append((domain, platform))
await async_process_integration_platforms(
hass, "platform_to_check", _process_platform
)
assert len(processed) == 0