Fix race in loading service descriptions (#109316)

pull/109325/head
J. Nick Koston 2024-02-01 12:34:23 -06:00 committed by GitHub
parent c61a2b46d4
commit ed726db974
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 84 additions and 0 deletions

View File

@ -608,6 +608,11 @@ async def async_get_all_descriptions(
# Files we loaded for missing descriptions
loaded: dict[str, JSON_TYPE] = {}
# We try to avoid making a copy in the event the cache is good,
# but now we must make a copy in case new services get added
# while we are loading the missing ones so we do not
# add the new ones to the cache without their descriptions
services = {domain: service.copy() for domain, service in services.items()}
if domains_with_missing_services:
ints_or_excs = await async_get_integrations(hass, domains_with_missing_services)

View File

@ -1,4 +1,5 @@
"""Test service helpers."""
import asyncio
from collections.abc import Iterable
from copy import deepcopy
from typing import Any
@ -782,6 +783,84 @@ async def test_async_get_all_descriptions_dynamically_created_services(
}
async def test_async_get_all_descriptions_new_service_added_while_loading(
hass: HomeAssistant,
) -> None:
"""Test async_get_all_descriptions when a new service is added while loading translations."""
group = hass.components.group
group_config = {group.DOMAIN: {}}
await async_setup_component(hass, group.DOMAIN, group_config)
descriptions = await service.async_get_all_descriptions(hass)
assert len(descriptions) == 1
assert "description" in descriptions["group"]["reload"]
assert "fields" in descriptions["group"]["reload"]
logger = hass.components.logger
logger_domain = logger.DOMAIN
logger_config = {logger_domain: {}}
translations_called = asyncio.Event()
translations_wait = asyncio.Event()
async def async_get_translations(
hass: HomeAssistant,
language: str,
category: str,
integrations: Iterable[str] | None = None,
config_flow: bool | None = None,
) -> dict[str, Any]:
"""Return all backend translations."""
translations_called.set()
await translations_wait.wait()
translation_key_prefix = f"component.{logger_domain}.services.set_default_level"
return {
f"{translation_key_prefix}.name": "Translated name",
f"{translation_key_prefix}.description": "Translated description",
f"{translation_key_prefix}.fields.level.name": "Field name",
f"{translation_key_prefix}.fields.level.description": "Field description",
f"{translation_key_prefix}.fields.level.example": "Field example",
}
with patch(
"homeassistant.helpers.service.translation.async_get_translations",
side_effect=async_get_translations,
):
await async_setup_component(hass, logger_domain, logger_config)
task = asyncio.create_task(service.async_get_all_descriptions(hass))
await translations_called.wait()
# Now register a new service while translations are being loaded
hass.services.async_register(logger_domain, "new_service", lambda x: None, None)
service.async_set_service_schema(
hass, logger_domain, "new_service", {"description": "new service"}
)
translations_wait.set()
descriptions = await task
# Two domains should be present
assert len(descriptions) == 2
logger_descriptions = descriptions[logger_domain]
# The new service was loaded after the translations were loaded
# so it should not appear until the next time we fetch
assert "new_service" not in logger_descriptions
set_default_level = logger_descriptions["set_default_level"]
assert set_default_level["name"] == "Translated name"
assert set_default_level["description"] == "Translated description"
set_default_level_fields = set_default_level["fields"]
assert set_default_level_fields["level"]["name"] == "Field name"
assert set_default_level_fields["level"]["description"] == "Field description"
assert set_default_level_fields["level"]["example"] == "Field example"
descriptions = await service.async_get_all_descriptions(hass)
assert "description" in descriptions[logger_domain]["new_service"]
assert descriptions[logger_domain]["new_service"]["description"] == "new service"
async def test_register_with_mixed_case(hass: HomeAssistant) -> None:
"""Test registering a service with mixed case.