From 33ff6b5b6ee3d92f4bb8deb9594d67748ea23d7c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 17 Feb 2024 18:07:18 -0600 Subject: [PATCH] Avoid creating tasks for checking integrations platforms (#110795) * Avoid creating tasks for checking integrations platforms This is a followup to #110743 to avoid creating a task to check if the integration platform exists. We created tasks because we needed to await async_get_integrations but since its always called from EVENT_COMPONENT_LOADED firing, we can use the async_get_loaded_integration version which does not need to be awaited. This eliminates one task for every loaded component * there is no more race risk * reduce * coro or callback * reduce * tweak * race safe * fix type * fixes * use built-in helper to make it smaller * use built-in helper to make it smaller * use built-in helper to make it smaller * add coverage to ensure exceptions are logged * improve readability a bit * platforms --- homeassistant/helpers/integration_platform.py | 156 ++++++++---------- tests/components/cast/test_media_player.py | 1 + tests/components/repairs/test_init.py | 1 + tests/helpers/test_integration_platform.py | 85 +++++++++- 4 files changed, 152 insertions(+), 91 deletions(-) diff --git a/homeassistant/helpers/integration_platform.py b/homeassistant/helpers/integration_platform.py index 16d94edfb8b..2d9ca8afacf 100644 --- a/homeassistant/helpers/integration_platform.py +++ b/homeassistant/helpers/integration_platform.py @@ -4,14 +4,21 @@ from __future__ import annotations import asyncio from collections.abc import Awaitable, Callable from dataclasses import dataclass +from functools import partial import logging from types import ModuleType from typing import Any from homeassistant.const import EVENT_COMPONENT_LOADED -from homeassistant.core import Event, HomeAssistant, callback -from homeassistant.loader import Integration, async_get_integrations, bind_hass +from homeassistant.core import Event, HassJob, HomeAssistant, callback +from homeassistant.loader import ( + Integration, + async_get_integrations, + async_get_loaded_integration, + bind_hass, +) from homeassistant.setup import ATTR_COMPONENT +from homeassistant.util.logging import catch_log_exception _LOGGER = logging.getLogger(__name__) DATA_INTEGRATION_PLATFORMS = "integration_platforms" @@ -22,31 +29,12 @@ class IntegrationPlatform: """An integration platform.""" platform_name: str - process_platform: Callable[[HomeAssistant, str, Any], Awaitable[None]] + process_job: HassJob seen_components: set[str] -async def _async_process_single_integration_platform_component( - hass: HomeAssistant, - component_name: str, - platform: ModuleType, - integration_platform: IntegrationPlatform, -) -> None: - """Process a single integration platform.""" - if component_name in integration_platform.seen_components: - return - integration_platform.seen_components.add(component_name) - try: - await integration_platform.process_platform(hass, component_name, platform) - except Exception: # pylint: disable=broad-except - _LOGGER.exception( - "Error processing platform %s.%s", - component_name, - integration_platform.platform_name, - ) - - -def _get_platform_from_integration( +@callback +def _get_platform( integration: Integration | Exception, component_name: str, platform_name: str ) -> ModuleType | None: """Get a platform from an integration.""" @@ -71,36 +59,32 @@ def _get_platform_from_integration( return None -async def _async_process_integration_platform_for_component( - hass: HomeAssistant, component_name: str +@callback +def _async_process_integration_platforms_for_component( + hass: HomeAssistant, integration_platforms: list[IntegrationPlatform], event: Event ) -> None: """Process integration platforms for a component.""" - integration_platforms: list[IntegrationPlatform] = hass.data[ - DATA_INTEGRATION_PLATFORMS - ] - integrations = await async_get_integrations(hass, (component_name,)) - tasks = [ - asyncio.create_task( - _async_process_single_integration_platform_component( - hass, - component_name, - platform, - integration_platform, - ), - name=f"process integration platform {integration_platform.platform_name} for {component_name}", - ) - for integration_platform in integration_platforms - if component_name not in integration_platform.seen_components - and ( - platform := _get_platform_from_integration( - integrations[component_name], - component_name, - integration_platform.platform_name, + component_name: str = event.data[ATTR_COMPONENT] + if "." in component_name: + return + + integration = async_get_loaded_integration(hass, component_name) + for integration_platform in integration_platforms: + if component_name in integration_platform.seen_components or not ( + platform := _get_platform( + integration, component_name, integration_platform.platform_name ) + ): + continue + integration_platform.seen_components.add(component_name) + hass.async_run_hass_job( + integration_platform.process_job, hass, component_name, platform ) - ] - if tasks: - await asyncio.gather(*tasks) + + +def _format_err(name: str, platform_name: str, *args: Any) -> str: + """Format error message.""" + return f"Exception in {name} when processing platform '{platform_name}': {args}" @bind_hass @@ -108,52 +92,44 @@ async def async_process_integration_platforms( hass: HomeAssistant, platform_name: str, # Any = platform. - process_platform: Callable[[HomeAssistant, str, Any], Awaitable[None]], + process_platform: Callable[[HomeAssistant, str, Any], Awaitable[None] | None], ) -> None: """Process a specific platform for all current and future loaded integrations.""" if DATA_INTEGRATION_PLATFORMS not in hass.data: - hass.data[DATA_INTEGRATION_PLATFORMS] = [] - - async def _async_component_loaded(event: Event) -> None: - """Handle a new component loaded.""" - await _async_process_integration_platform_for_component( - hass, event.data[ATTR_COMPONENT] - ) - - @callback - def _async_component_loaded_filter(event: Event) -> bool: - """Handle integration platforms loaded.""" - return "." not in event.data[ATTR_COMPONENT] - + integration_platforms: list[IntegrationPlatform] = [] + hass.data[DATA_INTEGRATION_PLATFORMS] = integration_platforms hass.bus.async_listen( EVENT_COMPONENT_LOADED, - _async_component_loaded, - event_filter=_async_component_loaded_filter, + partial( + _async_process_integration_platforms_for_component, + hass, + integration_platforms, + ), ) + else: + integration_platforms = hass.data[DATA_INTEGRATION_PLATFORMS] - integration_platforms: list[IntegrationPlatform] = hass.data[ - DATA_INTEGRATION_PLATFORMS - ] - integration_platform = IntegrationPlatform(platform_name, process_platform, set()) + top_level_components = {comp for comp in hass.config.components if "." not in comp} + process_job = HassJob( + catch_log_exception( + process_platform, + partial(_format_err, str(process_platform), platform_name), + ), + f"process_platform {platform_name}", + ) + integration_platform = IntegrationPlatform( + platform_name, process_job, top_level_components + ) integration_platforms.append(integration_platform) - if top_level_components := [ - comp for comp in hass.config.components if "." not in comp + + if not top_level_components: + return + + integrations = await async_get_integrations(hass, top_level_components) + if futures := [ + future + for comp in top_level_components + if (platform := _get_platform(integrations[comp], comp, platform_name)) + and (future := hass.async_run_hass_job(process_job, hass, comp, platform)) ]: - integrations = await async_get_integrations(hass, top_level_components) - tasks = [ - asyncio.create_task( - _async_process_single_integration_platform_component( - hass, comp, platform, integration_platform - ), - name=f"process integration platform {platform_name} for {comp}", - ) - for comp in top_level_components - if comp not in integration_platform.seen_components - and ( - platform := _get_platform_from_integration( - integrations[comp], comp, platform_name - ) - ) - ] - if tasks: - await asyncio.gather(*tasks) + await asyncio.gather(*futures) diff --git a/tests/components/cast/test_media_player.py b/tests/components/cast/test_media_player.py index df3eb866710..66d23043935 100644 --- a/tests/components/cast/test_media_player.py +++ b/tests/components/cast/test_media_player.py @@ -2015,6 +2015,7 @@ async def test_entry_setup_no_config(hass: HomeAssistant) -> None: assert not hass.config_entries.async_entries("cast") +@pytest.mark.no_fail_on_log_exception async def test_invalid_cast_platform( hass: HomeAssistant, caplog: pytest.LogCaptureFixture ) -> None: diff --git a/tests/components/repairs/test_init.py b/tests/components/repairs/test_init.py index ce787ad00b8..977bd9b5e55 100644 --- a/tests/components/repairs/test_init.py +++ b/tests/components/repairs/test_init.py @@ -463,6 +463,7 @@ async def test_delete_issue( } +@pytest.mark.no_fail_on_log_exception async def test_non_compliant_platform( hass: HomeAssistant, hass_ws_client: WebSocketGenerator ) -> None: diff --git a/tests/helpers/test_integration_platform.py b/tests/helpers/test_integration_platform.py index ed6edcc3690..29bda99c9c6 100644 --- a/tests/helpers/test_integration_platform.py +++ b/tests/helpers/test_integration_platform.py @@ -1,9 +1,12 @@ """Test integration platform helpers.""" +from collections.abc import Callable +from types import ModuleType from unittest.mock import Mock import pytest -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.integration_platform import ( async_process_integration_platforms, ) @@ -42,6 +45,66 @@ async def test_process_integration_platforms(hass: HomeAssistant) -> None: assert processed[1][0] == "event" assert processed[1][1] == event_platform + hass.bus.async_fire(EVENT_COMPONENT_LOADED, {ATTR_COMPONENT: "event"}) + await hass.async_block_till_done() + + # Firing again should not check again + assert len(processed) == 2 + + +@callback +def _process_platform_callback( + hass: HomeAssistant, domain: str, platform: ModuleType +) -> None: + """Process platform.""" + raise HomeAssistantError("Non-compliant platform") + + +async def _process_platform_coro( + hass: HomeAssistant, domain: str, platform: ModuleType +) -> None: + """Process platform.""" + raise HomeAssistantError("Non-compliant platform") + + +@pytest.mark.no_fail_on_log_exception +@pytest.mark.parametrize( + "process_platform", (_process_platform_callback, _process_platform_coro) +) +async def test_process_integration_platforms_non_compliant( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture, process_platform: Callable +) -> None: + """Test processing integrations using with a non-compliant platform.""" + loaded_platform = Mock() + mock_platform(hass, "loaded_unique_880.platform_to_check", loaded_platform) + hass.config.components.add("loaded_unique_880") + + event_platform = Mock() + mock_platform(hass, "event_unique_990.platform_to_check", event_platform) + + processed = [] + + await async_process_integration_platforms( + hass, "platform_to_check", process_platform + ) + + assert len(processed) == 0 + assert "Exception in " in caplog.text + assert "platform_to_check" in caplog.text + assert "Non-compliant platform" in caplog.text + assert "loaded_unique_880" in caplog.text + caplog.clear() + + hass.bus.async_fire(EVENT_COMPONENT_LOADED, {ATTR_COMPONENT: "event_unique_990"}) + await hass.async_block_till_done() + + assert "Exception in " in caplog.text + assert "platform_to_check" in caplog.text + assert "Non-compliant platform" in caplog.text + assert "event_unique_990" in caplog.text + + assert len(processed) == 0 + async def test_broken_integration( hass: HomeAssistant, caplog: pytest.LogCaptureFixture @@ -65,3 +128,23 @@ async def test_broken_integration( assert len(processed) == 0 assert "Error importing integration loaded for platform_to_check" in caplog.text + + +async def test_process_integration_platforms_no_integrations( + hass: HomeAssistant, +) -> None: + """Test processing integrations when no integrations are loaded.""" + event_platform = Mock() + mock_platform(hass, "event.platform_to_check", event_platform) + + processed = [] + + async def _process_platform(hass, domain, platform): + """Process platform.""" + processed.append((domain, platform)) + + await async_process_integration_platforms( + hass, "platform_to_check", _process_platform + ) + + assert len(processed) == 0