Always do thread safety checks when writing state for custom components (#116044)
parent
6f2a2ba46e
commit
79b4889812
|
@ -521,6 +521,7 @@ class Entity(
|
||||||
# While not purely typed, it makes typehinting more useful for us
|
# While not purely typed, it makes typehinting more useful for us
|
||||||
# and removes the need for constant None checks or asserts.
|
# and removes the need for constant None checks or asserts.
|
||||||
_state_info: StateInfo = None # type: ignore[assignment]
|
_state_info: StateInfo = None # type: ignore[assignment]
|
||||||
|
_is_custom_component: bool = False
|
||||||
|
|
||||||
__capabilities_updated_at: deque[float]
|
__capabilities_updated_at: deque[float]
|
||||||
__capabilities_updated_at_reported: bool = False
|
__capabilities_updated_at_reported: bool = False
|
||||||
|
@ -967,8 +968,8 @@ class Entity(
|
||||||
self._async_write_ha_state()
|
self._async_write_ha_state()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_write_ha_state(self) -> None:
|
def _async_verify_state_writable(self) -> None:
|
||||||
"""Write the state to the state machine."""
|
"""Verify the entity is in a writable state."""
|
||||||
if self.hass is None:
|
if self.hass is None:
|
||||||
raise RuntimeError(f"Attribute hass is None for {self}")
|
raise RuntimeError(f"Attribute hass is None for {self}")
|
||||||
if self.hass.config.debug:
|
if self.hass.config.debug:
|
||||||
|
@ -995,6 +996,18 @@ class Entity(
|
||||||
f"No entity id specified for entity {self.name}"
|
f"No entity id specified for entity {self.name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_write_ha_state_from_call_soon_threadsafe(self) -> None:
|
||||||
|
"""Write the state to the state machine from the event loop thread."""
|
||||||
|
self._async_verify_state_writable()
|
||||||
|
self._async_write_ha_state()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_write_ha_state(self) -> None:
|
||||||
|
"""Write the state to the state machine."""
|
||||||
|
self._async_verify_state_writable()
|
||||||
|
if self._is_custom_component or self.hass.config.debug:
|
||||||
|
self.hass.verify_event_loop_thread("async_write_ha_state")
|
||||||
self._async_write_ha_state()
|
self._async_write_ha_state()
|
||||||
|
|
||||||
def _stringify_state(self, available: bool) -> str:
|
def _stringify_state(self, available: bool) -> str:
|
||||||
|
@ -1221,7 +1234,9 @@ class Entity(
|
||||||
f"Entity {self.entity_id} schedule update ha state",
|
f"Entity {self.entity_id} schedule update ha state",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.hass.loop.call_soon_threadsafe(self.async_write_ha_state)
|
self.hass.loop.call_soon_threadsafe(
|
||||||
|
self._async_write_ha_state_from_call_soon_threadsafe
|
||||||
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_schedule_update_ha_state(self, force_refresh: bool = False) -> None:
|
def async_schedule_update_ha_state(self, force_refresh: bool = False) -> None:
|
||||||
|
@ -1426,10 +1441,12 @@ class Entity(
|
||||||
|
|
||||||
Not to be extended by integrations.
|
Not to be extended by integrations.
|
||||||
"""
|
"""
|
||||||
|
is_custom_component = "custom_components" in type(self).__module__
|
||||||
entity_info: EntityInfo = {
|
entity_info: EntityInfo = {
|
||||||
"domain": self.platform.platform_name,
|
"domain": self.platform.platform_name,
|
||||||
"custom_component": "custom_components" in type(self).__module__,
|
"custom_component": is_custom_component,
|
||||||
}
|
}
|
||||||
|
self._is_custom_component = is_custom_component
|
||||||
|
|
||||||
if self.platform.config_entry:
|
if self.platform.config_entry:
|
||||||
entity_info["config_entry"] = self.platform.config_entry.entry_id
|
entity_info["config_entry"] = self.platform.config_entry.entry_id
|
||||||
|
|
|
@ -2615,3 +2615,29 @@ async def test_async_write_ha_state_thread_safety(hass: HomeAssistant) -> None:
|
||||||
):
|
):
|
||||||
await hass.async_add_executor_job(ent2.async_write_ha_state)
|
await hass.async_add_executor_job(ent2.async_write_ha_state)
|
||||||
assert not hass.states.get(ent2.entity_id)
|
assert not hass.states.get(ent2.entity_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_write_ha_state_thread_safety_custom_component(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test async_write_ha_state thread safe for custom components."""
|
||||||
|
|
||||||
|
ent = entity.Entity()
|
||||||
|
ent._is_custom_component = True
|
||||||
|
ent.entity_id = "test.any"
|
||||||
|
ent.hass = hass
|
||||||
|
ent.platform = MockEntityPlatform(hass, domain="test")
|
||||||
|
ent.async_write_ha_state()
|
||||||
|
assert hass.states.get(ent.entity_id)
|
||||||
|
|
||||||
|
ent2 = entity.Entity()
|
||||||
|
ent2._is_custom_component = True
|
||||||
|
ent2.entity_id = "test.any2"
|
||||||
|
ent2.hass = hass
|
||||||
|
ent2.platform = MockEntityPlatform(hass, domain="test")
|
||||||
|
with pytest.raises(
|
||||||
|
RuntimeError,
|
||||||
|
match="Detected code that calls async_write_ha_state from a thread.",
|
||||||
|
):
|
||||||
|
await hass.async_add_executor_job(ent2.async_write_ha_state)
|
||||||
|
assert not hass.states.get(ent2.entity_id)
|
||||||
|
|
Loading…
Reference in New Issue