Hold a lock to prevent concurrent setup of config entries (#116482)
parent
3c7cbf5794
commit
6cf1c5c1f2
|
@ -295,7 +295,7 @@ class ConfigEntry(Generic[_DataT]):
|
|||
update_listeners: list[UpdateListenerType]
|
||||
_async_cancel_retry_setup: Callable[[], Any] | None
|
||||
_on_unload: list[Callable[[], Coroutine[Any, Any, None] | None]] | None
|
||||
reload_lock: asyncio.Lock
|
||||
setup_lock: asyncio.Lock
|
||||
_reauth_lock: asyncio.Lock
|
||||
_reconfigure_lock: asyncio.Lock
|
||||
_tasks: set[asyncio.Future[Any]]
|
||||
|
@ -403,7 +403,7 @@ class ConfigEntry(Generic[_DataT]):
|
|||
_setter(self, "_on_unload", None)
|
||||
|
||||
# Reload lock to prevent conflicting reloads
|
||||
_setter(self, "reload_lock", asyncio.Lock())
|
||||
_setter(self, "setup_lock", asyncio.Lock())
|
||||
# Reauth lock to prevent concurrent reauth flows
|
||||
_setter(self, "_reauth_lock", asyncio.Lock())
|
||||
# Reconfigure lock to prevent concurrent reconfigure flows
|
||||
|
@ -702,19 +702,17 @@ class ConfigEntry(Generic[_DataT]):
|
|||
# has started so we do not block shutdown
|
||||
if not hass.is_stopping:
|
||||
hass.async_create_background_task(
|
||||
self._async_setup_retry(hass),
|
||||
self.async_setup_locked(hass),
|
||||
f"config entry retry {self.domain} {self.title}",
|
||||
eager_start=True,
|
||||
)
|
||||
|
||||
async def _async_setup_retry(self, hass: HomeAssistant) -> None:
|
||||
"""Retry setup.
|
||||
|
||||
We hold the reload lock during setup retry to ensure
|
||||
that nothing can reload the entry while we are retrying.
|
||||
"""
|
||||
async with self.reload_lock:
|
||||
await self.async_setup(hass)
|
||||
async def async_setup_locked(
|
||||
self, hass: HomeAssistant, integration: loader.Integration | None = None
|
||||
) -> None:
|
||||
"""Set up while holding the setup lock."""
|
||||
async with self.setup_lock:
|
||||
await self.async_setup(hass, integration=integration)
|
||||
|
||||
@callback
|
||||
def async_shutdown(self) -> None:
|
||||
|
@ -1794,7 +1792,15 @@ class ConfigEntries:
|
|||
# attempts.
|
||||
entry.async_cancel_retry_setup()
|
||||
|
||||
async with entry.reload_lock:
|
||||
if entry.domain not in self.hass.config.components:
|
||||
# If the component is not loaded, just load it as
|
||||
# the config entry will be loaded as well. We need
|
||||
# to do this before holding the lock to avoid a
|
||||
# deadlock.
|
||||
await async_setup_component(self.hass, entry.domain, self._hass_config)
|
||||
return entry.state is ConfigEntryState.LOADED
|
||||
|
||||
async with entry.setup_lock:
|
||||
unload_result = await self.async_unload(entry_id)
|
||||
|
||||
if not unload_result or entry.disabled_by:
|
||||
|
|
|
@ -449,7 +449,7 @@ async def _async_setup_component(
|
|||
await asyncio.gather(
|
||||
*(
|
||||
create_eager_task(
|
||||
entry.async_setup(hass, integration=integration),
|
||||
entry.async_setup_locked(hass, integration=integration),
|
||||
name=f"config entry setup {entry.title} {entry.domain} {entry.entry_id}",
|
||||
)
|
||||
for entry in entries
|
||||
|
|
|
@ -324,6 +324,7 @@ async def test_user_flow_already_configured_host_changed_reloads_entry(
|
|||
state=ConfigEntryState.LOADED,
|
||||
)
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
hass.config.components.add(DOMAIN)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
|
@ -640,6 +641,7 @@ async def test_zeroconf_flow_already_configured_host_changed_reloads_entry(
|
|||
state=ConfigEntryState.LOADED,
|
||||
)
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
hass.config.components.add(DOMAIN)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
|
@ -769,6 +771,7 @@ async def test_reauth_flow_success(
|
|||
state=ConfigEntryState.LOADED,
|
||||
)
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
hass.config.components.add(DOMAIN)
|
||||
|
||||
mock_config_entry.async_start_reauth(hass)
|
||||
await hass.async_block_till_done()
|
||||
|
|
|
@ -251,6 +251,7 @@ async def test_reload_entry(hass: HomeAssistant, client) -> None:
|
|||
domain="kitchen_sink", state=core_ce.ConfigEntryState.LOADED
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
hass.config.components.add("kitchen_sink")
|
||||
resp = await client.post(
|
||||
f"/api/config/config_entries/entry/{entry.entry_id}/reload"
|
||||
)
|
||||
|
@ -298,6 +299,7 @@ async def test_reload_entry_in_failed_state(
|
|||
"""Test reloading an entry via the API that has already failed to unload."""
|
||||
entry = MockConfigEntry(domain="demo", state=core_ce.ConfigEntryState.FAILED_UNLOAD)
|
||||
entry.add_to_hass(hass)
|
||||
hass.config.components.add("demo")
|
||||
resp = await client.post(
|
||||
f"/api/config/config_entries/entry/{entry.entry_id}/reload"
|
||||
)
|
||||
|
@ -326,6 +328,7 @@ async def test_reload_entry_in_setup_retry(
|
|||
entry = MockConfigEntry(domain="comp", state=core_ce.ConfigEntryState.SETUP_RETRY)
|
||||
entry.supports_unload = True
|
||||
entry.add_to_hass(hass)
|
||||
hass.config.components.add("comp")
|
||||
|
||||
with patch.dict(HANDLERS, {"comp": ConfigFlow, "test": ConfigFlow}):
|
||||
resp = await client.post(
|
||||
|
@ -1109,6 +1112,7 @@ async def test_update_prefrences(
|
|||
domain="kitchen_sink", state=core_ce.ConfigEntryState.LOADED
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
hass.config.components.add("kitchen_sink")
|
||||
|
||||
assert entry.pref_disable_new_entities is False
|
||||
assert entry.pref_disable_polling is False
|
||||
|
@ -1209,6 +1213,7 @@ async def test_disable_entry(
|
|||
)
|
||||
entry.add_to_hass(hass)
|
||||
assert entry.disabled_by is None
|
||||
hass.config.components.add("kitchen_sink")
|
||||
|
||||
# Disable
|
||||
await ws_client.send_json(
|
||||
|
|
|
@ -1873,6 +1873,7 @@ async def test_reload_entry_with_restored_subscriptions(
|
|||
# Setup the MQTT entry
|
||||
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"})
|
||||
entry.add_to_hass(hass)
|
||||
hass.config.components.add(mqtt.DOMAIN)
|
||||
mqtt_client_mock.connect.return_value = 0
|
||||
with patch("homeassistant.config.load_yaml_config_file", return_value={}):
|
||||
await entry.async_setup(hass)
|
||||
|
|
|
@ -279,6 +279,7 @@ async def test_form_valid_reauth(
|
|||
) -> None:
|
||||
"""Test that we can handle a valid reauth."""
|
||||
mock_config_entry.mock_state(hass, ConfigEntryState.LOADED)
|
||||
hass.config.components.add(DOMAIN)
|
||||
mock_config_entry.async_start_reauth(hass)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
@ -328,6 +329,7 @@ async def test_form_valid_reauth_with_mfa(
|
|||
},
|
||||
)
|
||||
mock_config_entry.mock_state(hass, ConfigEntryState.LOADED)
|
||||
hass.config.components.add(DOMAIN)
|
||||
mock_config_entry.async_start_reauth(hass)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
|
|
@ -825,7 +825,7 @@ async def test_as_dict(snapshot: SnapshotAssertion) -> None:
|
|||
"error_reason_translation_placeholders",
|
||||
"_async_cancel_retry_setup",
|
||||
"_on_unload",
|
||||
"reload_lock",
|
||||
"setup_lock",
|
||||
"_reauth_lock",
|
||||
"_tasks",
|
||||
"_background_tasks",
|
||||
|
@ -1632,7 +1632,6 @@ async def test_entry_reload_succeed(
|
|||
mock_platform(hass, "comp.config_flow", None)
|
||||
|
||||
assert await manager.async_reload(entry.entry_id)
|
||||
assert len(async_unload_entry.mock_calls) == 1
|
||||
assert len(async_setup.mock_calls) == 1
|
||||
assert len(async_setup_entry.mock_calls) == 1
|
||||
assert entry.state is config_entries.ConfigEntryState.LOADED
|
||||
|
@ -1707,6 +1706,8 @@ async def test_entry_reload_error(
|
|||
),
|
||||
)
|
||||
|
||||
hass.config.components.add("comp")
|
||||
|
||||
with pytest.raises(config_entries.OperationNotAllowed, match=str(state)):
|
||||
assert await manager.async_reload(entry.entry_id)
|
||||
|
||||
|
@ -1738,8 +1739,11 @@ async def test_entry_disable_succeed(
|
|||
),
|
||||
)
|
||||
mock_platform(hass, "comp.config_flow", None)
|
||||
hass.config.components.add("comp")
|
||||
|
||||
# Disable
|
||||
assert len(async_setup.mock_calls) == 0
|
||||
assert len(async_setup_entry.mock_calls) == 0
|
||||
assert await manager.async_set_disabled_by(
|
||||
entry.entry_id, config_entries.ConfigEntryDisabler.USER
|
||||
)
|
||||
|
@ -1751,7 +1755,7 @@ async def test_entry_disable_succeed(
|
|||
# Enable
|
||||
assert await manager.async_set_disabled_by(entry.entry_id, None)
|
||||
assert len(async_unload_entry.mock_calls) == 1
|
||||
assert len(async_setup.mock_calls) == 1
|
||||
assert len(async_setup.mock_calls) == 0
|
||||
assert len(async_setup_entry.mock_calls) == 1
|
||||
assert entry.state is config_entries.ConfigEntryState.LOADED
|
||||
|
||||
|
@ -1775,6 +1779,7 @@ async def test_entry_disable_without_reload_support(
|
|||
),
|
||||
)
|
||||
mock_platform(hass, "comp.config_flow", None)
|
||||
hass.config.components.add("comp")
|
||||
|
||||
# Disable
|
||||
assert not await manager.async_set_disabled_by(
|
||||
|
@ -1951,7 +1956,7 @@ async def test_reload_entry_entity_registry_works(
|
|||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(mock_unload_entry.mock_calls) == 2
|
||||
assert len(mock_unload_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_unique_id_persisted(
|
||||
|
@ -3392,6 +3397,7 @@ async def test_entry_reload_calls_on_unload_listeners(
|
|||
),
|
||||
)
|
||||
mock_platform(hass, "comp.config_flow", None)
|
||||
hass.config.components.add("comp")
|
||||
|
||||
mock_unload_callback = Mock()
|
||||
|
||||
|
@ -3944,8 +3950,9 @@ async def test_deprecated_disabled_by_str_set(
|
|||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test deprecated str set disabled_by enumizes and logs a warning."""
|
||||
entry = MockConfigEntry()
|
||||
entry = MockConfigEntry(domain="comp")
|
||||
entry.add_to_manager(manager)
|
||||
hass.config.components.add("comp")
|
||||
assert await manager.async_set_disabled_by(
|
||||
entry.entry_id, config_entries.ConfigEntryDisabler.USER.value
|
||||
)
|
||||
|
@ -3963,6 +3970,47 @@ async def test_entry_reload_concurrency(
|
|||
async_setup = AsyncMock(return_value=True)
|
||||
loaded = 1
|
||||
|
||||
async def _async_setup_entry(*args, **kwargs):
|
||||
await asyncio.sleep(0)
|
||||
nonlocal loaded
|
||||
loaded += 1
|
||||
return loaded == 1
|
||||
|
||||
async def _async_unload_entry(*args, **kwargs):
|
||||
await asyncio.sleep(0)
|
||||
nonlocal loaded
|
||||
loaded -= 1
|
||||
return loaded == 0
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
"comp",
|
||||
async_setup=async_setup,
|
||||
async_setup_entry=_async_setup_entry,
|
||||
async_unload_entry=_async_unload_entry,
|
||||
),
|
||||
)
|
||||
mock_platform(hass, "comp.config_flow", None)
|
||||
hass.config.components.add("comp")
|
||||
tasks = [
|
||||
asyncio.create_task(manager.async_reload(entry.entry_id)) for _ in range(15)
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
assert entry.state is config_entries.ConfigEntryState.LOADED
|
||||
assert loaded == 1
|
||||
|
||||
|
||||
async def test_entry_reload_concurrency_not_setup_setup(
|
||||
hass: HomeAssistant, manager: config_entries.ConfigEntries
|
||||
) -> None:
|
||||
"""Test multiple reload calls do not cause a reload race."""
|
||||
entry = MockConfigEntry(domain="comp", state=config_entries.ConfigEntryState.LOADED)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
async_setup = AsyncMock(return_value=True)
|
||||
loaded = 0
|
||||
|
||||
async def _async_setup_entry(*args, **kwargs):
|
||||
await asyncio.sleep(0)
|
||||
nonlocal loaded
|
||||
|
@ -4074,6 +4122,7 @@ async def test_disallow_entry_reload_with_setup_in_progress(
|
|||
domain="comp", state=config_entries.ConfigEntryState.SETUP_IN_PROGRESS
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
hass.config.components.add("comp")
|
||||
|
||||
with pytest.raises(
|
||||
config_entries.OperationNotAllowed,
|
||||
|
@ -5016,3 +5065,48 @@ async def test_updating_non_added_entry_raises(hass: HomeAssistant) -> None:
|
|||
|
||||
with pytest.raises(config_entries.UnknownEntry, match=entry.entry_id):
|
||||
hass.config_entries.async_update_entry(entry, unique_id="new_id")
|
||||
|
||||
|
||||
async def test_reload_during_setup(hass: HomeAssistant) -> None:
|
||||
"""Test reload during setup waits."""
|
||||
entry = MockConfigEntry(domain="comp", data={"value": "initial"})
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
setup_start_future = hass.loop.create_future()
|
||||
setup_finish_future = hass.loop.create_future()
|
||||
in_setup = False
|
||||
setup_calls = 0
|
||||
|
||||
async def mock_async_setup_entry(hass, entry):
|
||||
"""Mock setting up an entry."""
|
||||
nonlocal in_setup
|
||||
nonlocal setup_calls
|
||||
setup_calls += 1
|
||||
assert not in_setup
|
||||
in_setup = True
|
||||
setup_start_future.set_result(None)
|
||||
await setup_finish_future
|
||||
in_setup = False
|
||||
return True
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
"comp",
|
||||
async_setup_entry=mock_async_setup_entry,
|
||||
async_unload_entry=AsyncMock(return_value=True),
|
||||
),
|
||||
)
|
||||
mock_platform(hass, "comp.config_flow", None)
|
||||
|
||||
setup_task = hass.async_create_task(async_setup_component(hass, "comp", {}))
|
||||
|
||||
await setup_start_future # ensure we are in the setup
|
||||
reload_task = hass.async_create_task(
|
||||
hass.config_entries.async_reload(entry.entry_id)
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
setup_finish_future.set_result(None)
|
||||
await setup_task
|
||||
await reload_task
|
||||
assert setup_calls == 2
|
||||
|
|
Loading…
Reference in New Issue