diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 9e4791fdef6..819b813832d 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -2,7 +2,15 @@ from __future__ import annotations import asyncio -from collections.abc import Callable, Coroutine, Generator, Iterable, Mapping +from collections import UserDict +from collections.abc import ( + Callable, + Coroutine, + Generator, + Iterable, + Mapping, + ValuesView, +) from contextvars import ContextVar from copy import deepcopy from enum import Enum, StrEnum @@ -336,6 +344,13 @@ class ConfigEntry: self._tries = 0 self._setup_again_job: HassJob | None = None + def __repr__(self) -> str: + """Representation of ConfigEntry.""" + return ( + f"" + ) + async def async_setup( self, hass: HomeAssistant, @@ -1057,6 +1072,67 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): ) +class ConfigEntryItems(UserDict[str, ConfigEntry]): + """Container for config items, maps config_entry_id -> entry. + + Maintains two additional indexes: + - domain -> list[ConfigEntry] + - domain -> unique_id -> ConfigEntry + """ + + def __init__(self) -> None: + """Initialize the container.""" + super().__init__() + self._domain_index: dict[str, list[ConfigEntry]] = {} + self._domain_unique_id_index: dict[str, dict[str, ConfigEntry]] = {} + + def values(self) -> ValuesView[ConfigEntry]: + """Return the underlying values to avoid __iter__ overhead.""" + return self.data.values() + + def __setitem__(self, entry_id: str, entry: ConfigEntry) -> None: + """Add an item.""" + data = self.data + if entry_id in data: + # This is likely a bug in a test that is adding the same entry twice. + # In the future, once we have fixed the tests, this will raise HomeAssistantError. + _LOGGER.error("An entry with the id %s already exists", entry_id) + self._unindex_entry(entry_id) + data[entry_id] = entry + self._domain_index.setdefault(entry.domain, []).append(entry) + if entry.unique_id is not None: + self._domain_unique_id_index.setdefault(entry.domain, {})[ + entry.unique_id + ] = entry + + def _unindex_entry(self, entry_id: str) -> None: + """Unindex an entry.""" + entry = self.data[entry_id] + domain = entry.domain + self._domain_index[domain].remove(entry) + if not self._domain_index[domain]: + del self._domain_index[domain] + if (unique_id := entry.unique_id) is not None: + del self._domain_unique_id_index[domain][unique_id] + if not self._domain_unique_id_index[domain]: + del self._domain_unique_id_index[domain] + + def __delitem__(self, entry_id: str) -> None: + """Remove an item.""" + self._unindex_entry(entry_id) + super().__delitem__(entry_id) + + def get_entries_for_domain(self, domain: str) -> list[ConfigEntry]: + """Get entries for a domain.""" + return self._domain_index.get(domain, []) + + def get_entry_by_domain_and_unique_id( + self, domain: str, unique_id: str + ) -> ConfigEntry | None: + """Get entry by domain and unique id.""" + return self._domain_unique_id_index.get(domain, {}).get(unique_id) + + class ConfigEntries: """Manage the configuration entries. @@ -1069,8 +1145,7 @@ class ConfigEntries: self.flow = ConfigEntriesFlowManager(hass, self, hass_config) self.options = OptionsFlowManager(hass) self._hass_config = hass_config - self._entries: dict[str, ConfigEntry] = {} - self._domain_index: dict[str, list[ConfigEntry]] = {} + self._entries = ConfigEntryItems() self._store = storage.Store[dict[str, list[dict[str, Any]]]]( hass, STORAGE_VERSION, STORAGE_KEY ) @@ -1093,23 +1168,29 @@ class ConfigEntries: @callback def async_get_entry(self, entry_id: str) -> ConfigEntry | None: """Return entry with matching entry_id.""" - return self._entries.get(entry_id) + return self._entries.data.get(entry_id) @callback def async_entries(self, domain: str | None = None) -> list[ConfigEntry]: """Return all entries or entries for a specific domain.""" if domain is None: return list(self._entries.values()) - return list(self._domain_index.get(domain, [])) + return list(self._entries.get_entries_for_domain(domain)) + + @callback + def async_entry_for_domain_unique_id( + self, domain: str, unique_id: str + ) -> ConfigEntry | None: + """Return entry for a domain with a matching unique id.""" + return self._entries.get_entry_by_domain_and_unique_id(domain, unique_id) async def async_add(self, entry: ConfigEntry) -> None: """Add and setup an entry.""" - if entry.entry_id in self._entries: + if entry.entry_id in self._entries.data: raise HomeAssistantError( f"An entry with the id {entry.entry_id} already exists." ) self._entries[entry.entry_id] = entry - self._domain_index.setdefault(entry.domain, []).append(entry) self._async_dispatch(ConfigEntryChange.ADDED, entry) await self.async_setup(entry.entry_id) self._async_schedule_save() @@ -1127,9 +1208,6 @@ class ConfigEntries: await entry.async_remove(self.hass) del self._entries[entry.entry_id] - self._domain_index[entry.domain].remove(entry) - if not self._domain_index[entry.domain]: - del self._domain_index[entry.domain] self._async_schedule_save() dev_reg = device_registry.async_get(self.hass) @@ -1189,13 +1267,10 @@ class ConfigEntries: self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._async_shutdown) if config is None: - self._entries = {} - self._domain_index = {} + self._entries = ConfigEntryItems() return - entries = {} - domain_index: dict[str, list[ConfigEntry]] = {} - + entries: ConfigEntryItems = ConfigEntryItems() for entry in config["entries"]: pref_disable_new_entities = entry.get("pref_disable_new_entities") @@ -1230,9 +1305,7 @@ class ConfigEntries: pref_disable_polling=entry.get("pref_disable_polling"), ) entries[entry_id] = config_entry - domain_index.setdefault(domain, []).append(config_entry) - self._domain_index = domain_index self._entries = entries async def async_setup(self, entry_id: str) -> bool: @@ -1365,8 +1438,15 @@ class ConfigEntries: """ changed = False + if unique_id is not UNDEFINED and entry.unique_id != unique_id: + # Reindex the entry if the unique_id has changed + entry_id = entry.entry_id + del self._entries[entry_id] + entry.unique_id = unique_id + self._entries[entry_id] = entry + changed = True + for attr, value in ( - ("unique_id", unique_id), ("title", title), ("pref_disable_new_entities", pref_disable_new_entities), ("pref_disable_polling", pref_disable_polling), @@ -1579,38 +1659,41 @@ class ConfigFlow(data_entry_flow.FlowHandler): if self.unique_id is None: return - for entry in self._async_current_entries(include_ignore=True): - if entry.unique_id != self.unique_id: - continue - should_reload = False - if ( - updates is not None - and self.hass.config_entries.async_update_entry( - entry, data={**entry.data, **updates} - ) - and reload_on_update - and entry.state - in (ConfigEntryState.LOADED, ConfigEntryState.SETUP_RETRY) - ): - # Existing config entry present, and the - # entry data just changed - should_reload = True - elif ( - self.source in DISCOVERY_SOURCES - and entry.state is ConfigEntryState.SETUP_RETRY - ): - # Existing config entry present in retry state, and we - # just discovered the unique id so we know its online - should_reload = True - # Allow ignored entries to be configured on manual user step - if entry.source == SOURCE_IGNORE and self.source == SOURCE_USER: - continue - if should_reload: - self.hass.async_create_task( - self.hass.config_entries.async_reload(entry.entry_id), - f"config entry reload {entry.title} {entry.domain} {entry.entry_id}", - ) - raise data_entry_flow.AbortFlow(error) + if not ( + entry := self.hass.config_entries.async_entry_for_domain_unique_id( + self.handler, self.unique_id + ) + ): + return + + should_reload = False + if ( + updates is not None + and self.hass.config_entries.async_update_entry( + entry, data={**entry.data, **updates} + ) + and reload_on_update + and entry.state in (ConfigEntryState.LOADED, ConfigEntryState.SETUP_RETRY) + ): + # Existing config entry present, and the + # entry data just changed + should_reload = True + elif ( + self.source in DISCOVERY_SOURCES + and entry.state is ConfigEntryState.SETUP_RETRY + ): + # Existing config entry present in retry state, and we + # just discovered the unique id so we know its online + should_reload = True + # Allow ignored entries to be configured on manual user step + if entry.source == SOURCE_IGNORE and self.source == SOURCE_USER: + return + if should_reload: + self.hass.async_create_task( + self.hass.config_entries.async_reload(entry.entry_id), + f"config entry reload {entry.title} {entry.domain} {entry.entry_id}", + ) + raise data_entry_flow.AbortFlow(error) async def async_set_unique_id( self, unique_id: str | None = None, *, raise_on_progress: bool = True @@ -1639,11 +1722,9 @@ class ConfigFlow(data_entry_flow.FlowHandler): ): self.hass.config_entries.flow.async_abort(progress["flow_id"]) - for entry in self._async_current_entries(include_ignore=True): - if entry.unique_id == unique_id: - return entry - - return None + return self.hass.config_entries.async_entry_for_domain_unique_id( + self.handler, unique_id + ) @callback def _set_confirm_only( diff --git a/tests/common.py b/tests/common.py index 35171799728..8b5a16c7104 100644 --- a/tests/common.py +++ b/tests/common.py @@ -939,12 +939,10 @@ class MockConfigEntry(config_entries.ConfigEntry): def add_to_hass(self, hass: HomeAssistant) -> None: """Test helper to add entry to hass.""" hass.config_entries._entries[self.entry_id] = self - hass.config_entries._domain_index.setdefault(self.domain, []).append(self) def add_to_manager(self, manager: config_entries.ConfigEntries) -> None: """Test helper to add entry to entry manager.""" manager._entries[self.entry_id] = self - manager._domain_index.setdefault(self.domain, []).append(self) def patch_yaml_files(files_dict, endswith=True): diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index e9989b6839e..fd74a2e6286 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -3123,6 +3123,9 @@ async def test_updating_entry_with_and_without_changes( state=config_entries.ConfigEntryState.SETUP_ERROR, ) entry.add_to_manager(manager) + assert "abc123" in str(entry) + + assert manager.async_entry_for_domain_unique_id("test", "abc123") is entry assert manager.async_update_entry(entry) is False @@ -3138,6 +3141,10 @@ async def test_updating_entry_with_and_without_changes( assert manager.async_update_entry(entry, **change) is True assert manager.async_update_entry(entry, **change) is False + assert manager.async_entry_for_domain_unique_id("test", "abc123") is None + assert manager.async_entry_for_domain_unique_id("test", "abcd1234") is entry + assert "abcd1234" in str(entry) + async def test_entry_reload_calls_on_unload_listeners( hass: HomeAssistant, manager: config_entries.ConfigEntries @@ -4127,3 +4134,13 @@ async def test_preview_not_supported( ) assert result["preview"] is None + + +def test_raise_trying_to_add_same_config_entry_twice( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test we log an error if trying to add same config entry twice.""" + entry = MockConfigEntry(domain="test") + entry.add_to_hass(hass) + entry.add_to_hass(hass) + assert f"An entry with the id {entry.entry_id} already exists" in caplog.text