diff --git a/homeassistant/components/unifiprotect/__init__.py b/homeassistant/components/unifiprotect/__init__.py index 97fdc6eac20..c28f2639e00 100644 --- a/homeassistant/components/unifiprotect/__init__.py +++ b/homeassistant/components/unifiprotect/__init__.py @@ -17,9 +17,11 @@ from homeassistant.const import ( CONF_USERNAME, CONF_VERIFY_SSL, EVENT_HOMEASSISTANT_STOP, + Platform, ) from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady +from homeassistant.helpers import entity_registry as er from homeassistant.helpers.aiohttp_client import async_create_clientsession from .const import ( @@ -27,6 +29,7 @@ from .const import ( CONF_OVERRIDE_CHOST, DEFAULT_SCAN_INTERVAL, DEVICES_FOR_SUBSCRIBE, + DEVICES_THAT_ADOPT, DOMAIN, MIN_REQUIRED_PROTECT_V, OUTDATED_LOG_MESSAGE, @@ -41,6 +44,60 @@ _LOGGER = logging.getLogger(__name__) SCAN_INTERVAL = timedelta(seconds=DEFAULT_SCAN_INTERVAL) +async def _async_migrate_data( + hass: HomeAssistant, entry: ConfigEntry, protect: ProtectApiClient +) -> None: + + registry = er.async_get(hass) + to_migrate = [] + for entity in er.async_entries_for_config_entry(registry, entry.entry_id): + if entity.domain == Platform.BUTTON and "_" not in entity.unique_id: + _LOGGER.debug("Button %s needs migration", entity.entity_id) + to_migrate.append(entity) + + if len(to_migrate) == 0: + _LOGGER.debug("No entities need migration") + return + + _LOGGER.info("Migrating %s reboot button entities ", len(to_migrate)) + bootstrap = await protect.get_bootstrap() + count = 0 + for button in to_migrate: + device = None + for model in DEVICES_THAT_ADOPT: + attr = f"{model.value}s" + device = getattr(bootstrap, attr).get(button.unique_id) + if device is not None: + break + + if device is None: + continue + + new_unique_id = f"{device.id}_reboot" + _LOGGER.debug( + "Migrating entity %s (old unique_id: %s, new unique_id: %s)", + button.entity_id, + button.unique_id, + new_unique_id, + ) + try: + registry.async_update_entity(button.entity_id, new_unique_id=new_unique_id) + except ValueError: + _LOGGER.warning( + "Could not migrate entity %s (old unique_id: %s, new unique_id: %s)", + button.entity_id, + button.unique_id, + new_unique_id, + ) + else: + count += 1 + + if count < len(to_migrate): + _LOGGER.warning("Failed to migate %s reboot buttons", len(to_migrate) - count) + else: + _LOGGER.info("Migrated %s reboot button entities", count) + + async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up the UniFi Protect config entries.""" @@ -75,6 +132,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ) return False + await _async_migrate_data(hass, entry, protect) if entry.unique_id is None: hass.config_entries.async_update_entry(entry, unique_id=nvr_info.mac) diff --git a/homeassistant/components/unifiprotect/button.py b/homeassistant/components/unifiprotect/button.py index 3940c85d21a..731b1eaf86a 100644 --- a/homeassistant/components/unifiprotect/button.py +++ b/homeassistant/components/unifiprotect/button.py @@ -1,20 +1,47 @@ """Support for Ubiquiti's UniFi Protect NVR.""" from __future__ import annotations -import logging +from dataclasses import dataclass +from typing import Final from pyunifiprotect.data.base import ProtectAdoptableDeviceModel -from homeassistant.components.button import ButtonDeviceClass, ButtonEntity +from homeassistant.components.button import ( + ButtonDeviceClass, + ButtonEntity, + ButtonEntityDescription, +) from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback -from .const import DEVICES_THAT_ADOPT, DOMAIN +from .const import DOMAIN from .data import ProtectData -from .entity import ProtectDeviceEntity +from .entity import ProtectDeviceEntity, async_all_device_entities +from .models import ProtectSetableKeysMixin, T -_LOGGER = logging.getLogger(__name__) + +@dataclass +class ProtectButtonEntityDescription( + ProtectSetableKeysMixin[T], ButtonEntityDescription +): + """Describes UniFi Protect Button entity.""" + + ufp_press: str | None = None + + +DEVICE_CLASS_CHIME_BUTTON: Final = "unifiprotect__chime_button" + + +ALL_DEVICE_BUTTONS: tuple[ProtectButtonEntityDescription, ...] = ( + ProtectButtonEntityDescription( + key="reboot", + entity_registry_enabled_default=False, + device_class=ButtonDeviceClass.RESTART, + name="Reboot Device", + ufp_press="reboot", + ), +) async def async_setup_entry( @@ -25,34 +52,30 @@ async def async_setup_entry( """Discover devices on a UniFi Protect NVR.""" data: ProtectData = hass.data[DOMAIN][entry.entry_id] - async_add_entities( - [ - ProtectButton( - data, - device, - ) - for device in data.get_by_types(DEVICES_THAT_ADOPT) - ] + entities: list[ProtectDeviceEntity] = async_all_device_entities( + data, ProtectButton, all_descs=ALL_DEVICE_BUTTONS ) + async_add_entities(entities) + class ProtectButton(ProtectDeviceEntity, ButtonEntity): """A Ubiquiti UniFi Protect Reboot button.""" - _attr_entity_registry_enabled_default = False - _attr_device_class = ButtonDeviceClass.RESTART + entity_description: ProtectButtonEntityDescription def __init__( self, data: ProtectData, device: ProtectAdoptableDeviceModel, + description: ProtectButtonEntityDescription, ) -> None: """Initialize an UniFi camera.""" - super().__init__(data, device) - self._attr_name = f"{self.device.name} Reboot Device" + super().__init__(data, device, description) + self._attr_name = f"{self.device.name} {self.entity_description.name}" async def async_press(self) -> None: """Press the button.""" - _LOGGER.debug("Rebooting %s with id %s", self.device.model, self.device.id) - await self.device.reboot() + if self.entity_description.ufp_press is not None: + await getattr(self.device, self.entity_description.ufp_press)() diff --git a/tests/components/unifiprotect/test_button.py b/tests/components/unifiprotect/test_button.py index 0064781c6ce..64677dd1d77 100644 --- a/tests/components/unifiprotect/test_button.py +++ b/tests/components/unifiprotect/test_button.py @@ -49,7 +49,7 @@ async def test_button( mock_entry.api.reboot_device = AsyncMock() - unique_id = f"{camera[0].id}" + unique_id = f"{camera[0].id}_reboot" entity_id = camera[1] entity_registry = er.async_get(hass) diff --git a/tests/components/unifiprotect/test_init.py b/tests/components/unifiprotect/test_init.py index 77bf900d87e..53588984e25 100644 --- a/tests/components/unifiprotect/test_init.py +++ b/tests/components/unifiprotect/test_init.py @@ -1,14 +1,17 @@ """Test the UniFi Protect setup flow.""" +# pylint: disable=protected-access from __future__ import annotations from unittest.mock import AsyncMock, patch from pyunifiprotect import NotAuthorized, NvrError -from pyunifiprotect.data import NVR +from pyunifiprotect.data import NVR, Light from homeassistant.components.unifiprotect.const import CONF_DISABLE_RTSP, DOMAIN from homeassistant.config_entries import ConfigEntry, ConfigEntryState +from homeassistant.const import Platform from homeassistant.core import HomeAssistant +from homeassistant.helpers import entity_registry as er from . import _patch_discovery from .conftest import MockBootstrap, MockEntityFixture @@ -175,3 +178,103 @@ async def test_setup_starts_discovery( assert mock_entry.entry.state == ConfigEntryState.LOADED await hass.async_block_till_done() assert len(hass.config_entries.flow.async_progress_by_handler(DOMAIN)) == 1 + + +async def test_migrate_reboot_button( + hass: HomeAssistant, mock_entry: MockEntityFixture, mock_light: Light +): + """Test migrating unique ID of reboot button.""" + + light1 = mock_light.copy() + light1._api = mock_entry.api + light1.name = "Test Light 1" + light1.id = "lightid1" + + light2 = mock_light.copy() + light2._api = mock_entry.api + light2.name = "Test Light 2" + light2.id = "lightid2" + mock_entry.api.bootstrap.lights = { + light1.id: light1, + light2.id: light2, + } + mock_entry.api.get_bootstrap = AsyncMock(return_value=mock_entry.api.bootstrap) + + registry = er.async_get(hass) + registry.async_get_or_create( + Platform.BUTTON, Platform.BUTTON, light1.id, config_entry=mock_entry.entry + ) + registry.async_get_or_create( + Platform.BUTTON, + Platform.BUTTON, + f"{light2.id}_reboot", + config_entry=mock_entry.entry, + ) + + await hass.config_entries.async_setup(mock_entry.entry.entry_id) + await hass.async_block_till_done() + + assert mock_entry.entry.state == ConfigEntryState.LOADED + assert mock_entry.api.update.called + assert mock_entry.entry.unique_id == mock_entry.api.bootstrap.nvr.mac + + assert registry.async_get(f"{Platform.BUTTON}.test_light_1_reboot_device_2") is None + light = registry.async_get(f"{Platform.BUTTON}.test_light_1_reboot_device") + assert light is not None + assert light.unique_id == f"{light1.id}_reboot" + + assert registry.async_get(f"{Platform.BUTTON}.test_light_2_reboot_device_2") is None + light = registry.async_get(f"{Platform.BUTTON}.test_light_2_reboot_device") + assert light is not None + assert light.unique_id == f"{light2.id}_reboot" + + buttons = [] + for entity in er.async_entries_for_config_entry( + registry, mock_entry.entry.entry_id + ): + if entity.platform == Platform.BUTTON.value: + buttons.append(entity) + assert len(buttons) == 2 + + +async def test_migrate_reboot_button_fail( + hass: HomeAssistant, mock_entry: MockEntityFixture, mock_light: Light +): + """Test migrating unique ID of reboot button.""" + + light1 = mock_light.copy() + light1._api = mock_entry.api + light1.name = "Test Light 1" + light1.id = "lightid1" + + mock_entry.api.bootstrap.lights = { + light1.id: light1, + } + mock_entry.api.get_bootstrap = AsyncMock(return_value=mock_entry.api.bootstrap) + + registry = er.async_get(hass) + registry.async_get_or_create( + Platform.BUTTON, + Platform.BUTTON, + light1.id, + config_entry=mock_entry.entry, + suggested_object_id=light1.name, + ) + registry.async_get_or_create( + Platform.BUTTON, + Platform.BUTTON, + f"{light1.id}_reboot", + config_entry=mock_entry.entry, + suggested_object_id=light1.name, + ) + + await hass.config_entries.async_setup(mock_entry.entry.entry_id) + await hass.async_block_till_done() + + assert mock_entry.entry.state == ConfigEntryState.LOADED + assert mock_entry.api.update.called + assert mock_entry.entry.unique_id == mock_entry.api.bootstrap.nvr.mac + + light = registry.async_get(f"{Platform.BUTTON}.test_light_1") + assert light is not None + assert light.unique_id == f"{light1.id}"