Prevent polling from recreating an entity after removal (#67750)
parent
814c96834e
commit
f4ec7e0902
|
@ -6,6 +6,7 @@ import asyncio
|
|||
from collections.abc import Awaitable, Iterable, Mapping, MutableMapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum, auto
|
||||
import functools as ft
|
||||
import logging
|
||||
import math
|
||||
|
@ -207,6 +208,19 @@ class EntityCategory(StrEnum):
|
|||
SYSTEM = "system"
|
||||
|
||||
|
||||
class EntityPlatformState(Enum):
|
||||
"""The platform state of an entity."""
|
||||
|
||||
# Not Added: Not yet added to a platform, polling updates are written to the state machine
|
||||
NOT_ADDED = auto()
|
||||
|
||||
# Added: Added to a platform, polling updates are written to the state machine
|
||||
ADDED = auto()
|
||||
|
||||
# Removed: Removed from a platform, polling updates are not written to the state machine
|
||||
REMOVED = auto()
|
||||
|
||||
|
||||
def convert_to_entity_category(
|
||||
value: EntityCategory | str | None, raise_report: bool = True
|
||||
) -> EntityCategory | None:
|
||||
|
@ -294,7 +308,7 @@ class Entity(ABC):
|
|||
_context_set: datetime | None = None
|
||||
|
||||
# If entity is added to an entity platform
|
||||
_added = False
|
||||
_platform_state = EntityPlatformState.NOT_ADDED
|
||||
|
||||
# Entity Properties
|
||||
_attr_assumed_state: bool = False
|
||||
|
@ -553,6 +567,10 @@ class Entity(ABC):
|
|||
@callback
|
||||
def _async_write_ha_state(self) -> None:
|
||||
"""Write the state to the state machine."""
|
||||
if self._platform_state == EntityPlatformState.REMOVED:
|
||||
# Polling returned after the entity has already been removed
|
||||
return
|
||||
|
||||
if self.registry_entry and self.registry_entry.disabled_by:
|
||||
if not self._disabled_reported:
|
||||
self._disabled_reported = True
|
||||
|
@ -758,7 +776,7 @@ class Entity(ABC):
|
|||
parallel_updates: asyncio.Semaphore | None,
|
||||
) -> None:
|
||||
"""Start adding an entity to a platform."""
|
||||
if self._added:
|
||||
if self._platform_state == EntityPlatformState.ADDED:
|
||||
raise HomeAssistantError(
|
||||
f"Entity {self.entity_id} cannot be added a second time to an entity platform"
|
||||
)
|
||||
|
@ -766,7 +784,7 @@ class Entity(ABC):
|
|||
self.hass = hass
|
||||
self.platform = platform
|
||||
self.parallel_updates = parallel_updates
|
||||
self._added = True
|
||||
self._platform_state = EntityPlatformState.ADDED
|
||||
|
||||
@callback
|
||||
def add_to_platform_abort(self) -> None:
|
||||
|
@ -774,7 +792,7 @@ class Entity(ABC):
|
|||
self.hass = None # type: ignore[assignment]
|
||||
self.platform = None
|
||||
self.parallel_updates = None
|
||||
self._added = False
|
||||
self._platform_state = EntityPlatformState.NOT_ADDED
|
||||
|
||||
async def add_to_platform_finish(self) -> None:
|
||||
"""Finish adding an entity to a platform."""
|
||||
|
@ -792,12 +810,12 @@ class Entity(ABC):
|
|||
If the entity doesn't have a non disabled entry in the entity registry,
|
||||
or if force_remove=True, its state will be removed.
|
||||
"""
|
||||
if self.platform and not self._added:
|
||||
if self.platform and self._platform_state != EntityPlatformState.ADDED:
|
||||
raise HomeAssistantError(
|
||||
f"Entity {self.entity_id} async_remove called twice"
|
||||
)
|
||||
|
||||
self._added = False
|
||||
self._platform_state = EntityPlatformState.REMOVED
|
||||
|
||||
if self._on_remove is not None:
|
||||
while self._on_remove:
|
||||
|
|
|
@ -545,6 +545,22 @@ async def test_async_remove_runs_callbacks(hass):
|
|||
assert len(result) == 1
|
||||
|
||||
|
||||
async def test_async_remove_ignores_in_flight_polling(hass):
|
||||
"""Test in flight polling is ignored after removing."""
|
||||
result = []
|
||||
|
||||
ent = entity.Entity()
|
||||
ent.hass = hass
|
||||
ent.entity_id = "test.test"
|
||||
ent.async_on_remove(lambda: result.append(1))
|
||||
ent.async_write_ha_state()
|
||||
assert hass.states.get("test.test").state == STATE_UNKNOWN
|
||||
await ent.async_remove()
|
||||
assert len(result) == 1
|
||||
assert hass.states.get("test.test") is None
|
||||
ent.async_write_ha_state()
|
||||
|
||||
|
||||
async def test_set_context(hass):
|
||||
"""Test setting context."""
|
||||
context = Context()
|
||||
|
|
|
@ -390,6 +390,30 @@ async def test_async_remove_with_platform(hass):
|
|||
assert len(hass.states.async_entity_ids()) == 0
|
||||
|
||||
|
||||
async def test_async_remove_with_platform_update_finishes(hass):
|
||||
"""Remove an entity when an update finishes after its been removed."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
entity1 = MockEntity(name="test_1")
|
||||
|
||||
async def _delayed_update(*args, **kwargs):
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
entity1.async_update = _delayed_update
|
||||
|
||||
# Add, remove, add, remove and make sure no updates
|
||||
# cause the entity to reappear after removal
|
||||
for i in range(2):
|
||||
await component.async_add_entities([entity1])
|
||||
assert len(hass.states.async_entity_ids()) == 1
|
||||
entity1.async_write_ha_state()
|
||||
assert hass.states.get(entity1.entity_id) is not None
|
||||
task = asyncio.create_task(entity1.async_update_ha_state(True))
|
||||
await entity1.async_remove()
|
||||
assert len(hass.states.async_entity_ids()) == 0
|
||||
await task
|
||||
assert len(hass.states.async_entity_ids()) == 0
|
||||
|
||||
|
||||
async def test_not_adding_duplicate_entities_with_unique_id(hass, caplog):
|
||||
"""Test for not adding duplicate entities."""
|
||||
caplog.set_level(logging.ERROR)
|
||||
|
|
Loading…
Reference in New Issue