From ed726db97408616eb1105f8cb8e8820dabab8b6a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 1 Feb 2024 12:34:23 -0600 Subject: [PATCH] Fix race in loading service descriptions (#109316) --- homeassistant/helpers/service.py | 5 ++ tests/helpers/test_service.py | 79 ++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 5a9786eb0fa..30516e3a099 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -608,6 +608,11 @@ async def async_get_all_descriptions( # Files we loaded for missing descriptions loaded: dict[str, JSON_TYPE] = {} + # We try to avoid making a copy in the event the cache is good, + # but now we must make a copy in case new services get added + # while we are loading the missing ones so we do not + # add the new ones to the cache without their descriptions + services = {domain: service.copy() for domain, service in services.items()} if domains_with_missing_services: ints_or_excs = await async_get_integrations(hass, domains_with_missing_services) diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 07e68e081b3..90f9b65aaba 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -1,4 +1,5 @@ """Test service helpers.""" +import asyncio from collections.abc import Iterable from copy import deepcopy from typing import Any @@ -782,6 +783,84 @@ async def test_async_get_all_descriptions_dynamically_created_services( } +async def test_async_get_all_descriptions_new_service_added_while_loading( + hass: HomeAssistant, +) -> None: + """Test async_get_all_descriptions when a new service is added while loading translations.""" + group = hass.components.group + group_config = {group.DOMAIN: {}} + await async_setup_component(hass, group.DOMAIN, group_config) + descriptions = await service.async_get_all_descriptions(hass) + + assert len(descriptions) == 1 + + assert "description" in descriptions["group"]["reload"] + assert "fields" in descriptions["group"]["reload"] + + logger = hass.components.logger + logger_domain = logger.DOMAIN + logger_config = {logger_domain: {}} + + translations_called = asyncio.Event() + translations_wait = asyncio.Event() + + async def async_get_translations( + hass: HomeAssistant, + language: str, + category: str, + integrations: Iterable[str] | None = None, + config_flow: bool | None = None, + ) -> dict[str, Any]: + """Return all backend translations.""" + translations_called.set() + await translations_wait.wait() + translation_key_prefix = f"component.{logger_domain}.services.set_default_level" + return { + f"{translation_key_prefix}.name": "Translated name", + f"{translation_key_prefix}.description": "Translated description", + f"{translation_key_prefix}.fields.level.name": "Field name", + f"{translation_key_prefix}.fields.level.description": "Field description", + f"{translation_key_prefix}.fields.level.example": "Field example", + } + + with patch( + "homeassistant.helpers.service.translation.async_get_translations", + side_effect=async_get_translations, + ): + await async_setup_component(hass, logger_domain, logger_config) + task = asyncio.create_task(service.async_get_all_descriptions(hass)) + await translations_called.wait() + # Now register a new service while translations are being loaded + hass.services.async_register(logger_domain, "new_service", lambda x: None, None) + service.async_set_service_schema( + hass, logger_domain, "new_service", {"description": "new service"} + ) + translations_wait.set() + descriptions = await task + + # Two domains should be present + assert len(descriptions) == 2 + + logger_descriptions = descriptions[logger_domain] + + # The new service was loaded after the translations were loaded + # so it should not appear until the next time we fetch + assert "new_service" not in logger_descriptions + + set_default_level = logger_descriptions["set_default_level"] + + assert set_default_level["name"] == "Translated name" + assert set_default_level["description"] == "Translated description" + set_default_level_fields = set_default_level["fields"] + assert set_default_level_fields["level"]["name"] == "Field name" + assert set_default_level_fields["level"]["description"] == "Field description" + assert set_default_level_fields["level"]["example"] == "Field example" + + descriptions = await service.async_get_all_descriptions(hass) + assert "description" in descriptions[logger_domain]["new_service"] + assert descriptions[logger_domain]["new_service"]["description"] == "new service" + + async def test_register_with_mixed_case(hass: HomeAssistant) -> None: """Test registering a service with mixed case.