diff --git a/homeassistant/components/config/device_registry.py b/homeassistant/components/config/device_registry.py index de1f38f3e57..a43a863444a 100644 --- a/homeassistant/components/config/device_registry.py +++ b/homeassistant/components/config/device_registry.py @@ -21,6 +21,8 @@ SCHEMA_WS_UPDATE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( vol.Required("device_id"): str, vol.Optional("area_id"): vol.Any(str, None), vol.Optional("name_by_user"): vol.Any(str, None), + # We only allow setting disabled_by user via API. + vol.Optional("disabled_by"): vol.Any("user", None), } ) @@ -77,4 +79,5 @@ def _entry_dict(entry): "via_device_id": entry.via_device_id, "area_id": entry.area_id, "name_by_user": entry.name_by_user, + "disabled_by": entry.disabled_by, } diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 388db62ebae..cc8f9a17827 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -37,6 +37,9 @@ IDX_IDENTIFIERS = "identifiers" REGISTERED_DEVICE = "registered" DELETED_DEVICE = "deleted" +DISABLED_INTEGRATION = "integration" +DISABLED_USER = "user" + @attr.s(slots=True, frozen=True) class DeletedDeviceEntry: @@ -76,6 +79,21 @@ class DeviceEntry: id: str = attr.ib(factory=uuid_util.random_uuid_hex) # This value is not stored, just used to keep track of events to fire. is_new: bool = attr.ib(default=False) + disabled_by: Optional[str] = attr.ib( + default=None, + validator=attr.validators.in_( + ( + DISABLED_INTEGRATION, + DISABLED_USER, + None, + ) + ), + ) + + @property + def disabled(self) -> bool: + """Return if entry is disabled.""" + return self.disabled_by is not None def format_mac(mac: str) -> str: @@ -215,6 +233,8 @@ class DeviceRegistry: sw_version=_UNDEF, entry_type=_UNDEF, via_device=None, + # To disable a device if it gets created + disabled_by=_UNDEF, ): """Get device. Create if it doesn't exist.""" if not identifiers and not connections: @@ -267,6 +287,7 @@ class DeviceRegistry: name=name, sw_version=sw_version, entry_type=entry_type, + disabled_by=disabled_by, ) @callback @@ -283,6 +304,7 @@ class DeviceRegistry: sw_version=_UNDEF, via_device_id=_UNDEF, remove_config_entry_id=_UNDEF, + disabled_by=_UNDEF, ): """Update properties of a device.""" return self._async_update_device( @@ -296,6 +318,7 @@ class DeviceRegistry: sw_version=sw_version, via_device_id=via_device_id, remove_config_entry_id=remove_config_entry_id, + disabled_by=disabled_by, ) @callback @@ -316,6 +339,7 @@ class DeviceRegistry: via_device_id=_UNDEF, area_id=_UNDEF, name_by_user=_UNDEF, + disabled_by=_UNDEF, ): """Update device attributes.""" old = self.devices[device_id] @@ -362,6 +386,7 @@ class DeviceRegistry: ("sw_version", sw_version), ("entry_type", entry_type), ("via_device_id", via_device_id), + ("disabled_by", disabled_by), ): if value is not _UNDEF and value != getattr(old, attr_name): changes[attr_name] = value @@ -440,6 +465,8 @@ class DeviceRegistry: # Introduced in 0.87 area_id=device.get("area_id"), name_by_user=device.get("name_by_user"), + # Introduced in 0.119 + disabled_by=device.get("disabled_by"), ) # Introduced in 0.111 for device in data.get("deleted_devices", []): @@ -478,6 +505,7 @@ class DeviceRegistry: "via_device_id": entry.via_device_id, "area_id": entry.area_id, "name_by_user": entry.name_by_user, + "disabled_by": entry.disabled_by, } for entry in self.devices.values() ] diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 872d87e732f..143f3a99137 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -53,9 +53,10 @@ SAVE_DELAY = 10 _LOGGER = logging.getLogger(__name__) _UNDEF = object() DISABLED_CONFIG_ENTRY = "config_entry" +DISABLED_DEVICE = "device" DISABLED_HASS = "hass" -DISABLED_USER = "user" DISABLED_INTEGRATION = "integration" +DISABLED_USER = "user" STORAGE_VERSION = 1 STORAGE_KEY = "core.entity_registry" @@ -89,10 +90,11 @@ class RegistryEntry: default=None, validator=attr.validators.in_( ( - DISABLED_HASS, - DISABLED_USER, - DISABLED_INTEGRATION, DISABLED_CONFIG_ENTRY, + DISABLED_DEVICE, + DISABLED_HASS, + DISABLED_INTEGRATION, + DISABLED_USER, None, ) ), @@ -127,7 +129,7 @@ class EntityRegistry: self._index: Dict[Tuple[str, str, str], str] = {} self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self.hass.bus.async_listen( - EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_removed + EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_modified ) @callback @@ -286,18 +288,34 @@ class EntityRegistry: ) self.async_schedule_save() - @callback - def async_device_removed(self, event: Event) -> None: - """Handle the removal of a device. + async def async_device_modified(self, event: Event) -> None: + """Handle the removal or update of a device. Remove entities from the registry that are associated to a device when the device is removed. + + Disable entities in the registry that are associated to a device when + the device is disabled. """ - if event.data["action"] != "remove": + if event.data["action"] == "remove": + entities = async_entries_for_device(self, event.data["device_id"]) + for entity in entities: + self.async_remove(entity.entity_id) return + + if event.data["action"] != "update": + return + + device_registry = await self.hass.helpers.device_registry.async_get_registry() + device = device_registry.async_get(event.data["device_id"]) + if not device.disabled: + return + entities = async_entries_for_device(self, event.data["device_id"]) for entity in entities: - self.async_remove(entity.entity_id) + self.async_update_entity( # type: ignore + entity.entity_id, disabled_by=DISABLED_DEVICE + ) @callback def async_update_entity( diff --git a/tests/components/config/test_device_registry.py b/tests/components/config/test_device_registry.py index 1e1cbccf60a..b2273d640de 100644 --- a/tests/components/config/test_device_registry.py +++ b/tests/components/config/test_device_registry.py @@ -56,6 +56,7 @@ async def test_list_devices(hass, client, registry): "via_device_id": None, "area_id": None, "name_by_user": None, + "disabled_by": None, }, { "config_entries": ["1234"], @@ -69,6 +70,7 @@ async def test_list_devices(hass, client, registry): "via_device_id": dev1, "area_id": None, "name_by_user": None, + "disabled_by": None, }, ] @@ -92,6 +94,7 @@ async def test_update_device(hass, client, registry): "device_id": device.id, "area_id": "12345A", "name_by_user": "Test Friendly Name", + "disabled_by": "user", "type": "config/device_registry/update", } ) @@ -101,4 +104,5 @@ async def test_update_device(hass, client, registry): assert msg["result"]["id"] == device.id assert msg["result"]["area_id"] == "12345A" assert msg["result"]["name_by_user"] == "Test Friendly Name" + assert msg["result"]["disabled_by"] == "user" assert len(registry.devices) == 1 diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index 85ff693f261..7fa787e023e 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -152,6 +152,7 @@ async def test_loading_from_storage(hass, hass_storage): "entry_type": "service", "area_id": "12345A", "name_by_user": "Test Friendly Name", + "disabled_by": "user", } ], "deleted_devices": [ @@ -180,6 +181,7 @@ async def test_loading_from_storage(hass, hass_storage): assert entry.area_id == "12345A" assert entry.name_by_user == "Test Friendly Name" assert entry.entry_type == "service" + assert entry.disabled_by == "user" assert isinstance(entry.config_entries, set) assert isinstance(entry.connections, set) assert isinstance(entry.identifiers, set) @@ -445,6 +447,7 @@ async def test_loading_saving_data(hass, registry): manufacturer="manufacturer", model="light", via_device=("hue", "0123"), + disabled_by="user", ) orig_light2 = registry.async_get_or_create( @@ -581,6 +584,7 @@ async def test_update(registry): name_by_user="Test Friendly Name", new_identifiers=new_identifiers, via_device_id="98765B", + disabled_by="user", ) assert mock_save.call_count == 1 @@ -591,6 +595,7 @@ async def test_update(registry): assert updated_entry.name_by_user == "Test Friendly Name" assert updated_entry.identifiers == new_identifiers assert updated_entry.via_device_id == "98765B" + assert updated_entry.disabled_by == "user" assert registry.async_get_device({("hue", "456")}, {}) is None assert registry.async_get_device({("bla", "123")}, {}) is None diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 336329396cc..f42661ec915 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -9,7 +9,12 @@ from homeassistant.helpers import entity_registry import tests.async_mock from tests.async_mock import patch -from tests.common import MockConfigEntry, flush_store, mock_registry +from tests.common import ( + MockConfigEntry, + flush_store, + mock_device_registry, + mock_registry, +) YAML__OPEN_PATH = "homeassistant.util.yaml.loader.open" @@ -677,3 +682,57 @@ async def test_async_get_device_class_lookup(hass): ("sensor", "battery"): "sensor.vacuum_battery", }, } + + +async def test_remove_device_removes_entities(hass, registry): + """Test that we remove entities tied to a device.""" + device_registry = mock_device_registry(hass) + config_entry = MockConfigEntry(domain="light") + + device_entry = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={("mac", "12:34:56:AB:CD:EF")}, + ) + + entry = registry.async_get_or_create( + "light", + "hue", + "5678", + config_entry=config_entry, + device_id=device_entry.id, + ) + + assert registry.async_is_registered(entry.entity_id) + + device_registry.async_remove_device(device_entry.id) + await hass.async_block_till_done() + + assert not registry.async_is_registered(entry.entity_id) + + +async def test_disable_device_disables_entities(hass, registry): + """Test that we remove entities tied to a device.""" + device_registry = mock_device_registry(hass) + config_entry = MockConfigEntry(domain="light") + + device_entry = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={("mac", "12:34:56:AB:CD:EF")}, + ) + + entry = registry.async_get_or_create( + "light", + "hue", + "5678", + config_entry=config_entry, + device_id=device_entry.id, + ) + + assert not entry.disabled + + device_registry.async_update_device(device_entry.id, disabled_by="user") + await hass.async_block_till_done() + + entry = registry.async_get(entry.entity_id) + assert entry.disabled + assert entry.disabled_by == "device"