Refactor button code to allow for other button types for UniFi Protect (#71911)
Co-authored-by: J. Nick Koston <nick@koston.org>pull/71996/head
parent
2d1a612976
commit
3de31939d8
|
@ -17,9 +17,11 @@ from homeassistant.const import (
|
|||
CONF_USERNAME,
|
||||
CONF_VERIFY_SSL,
|
||||
EVENT_HOMEASSISTANT_STOP,
|
||||
Platform,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers.aiohttp_client import async_create_clientsession
|
||||
|
||||
from .const import (
|
||||
|
@ -27,6 +29,7 @@ from .const import (
|
|||
CONF_OVERRIDE_CHOST,
|
||||
DEFAULT_SCAN_INTERVAL,
|
||||
DEVICES_FOR_SUBSCRIBE,
|
||||
DEVICES_THAT_ADOPT,
|
||||
DOMAIN,
|
||||
MIN_REQUIRED_PROTECT_V,
|
||||
OUTDATED_LOG_MESSAGE,
|
||||
|
@ -41,6 +44,60 @@ _LOGGER = logging.getLogger(__name__)
|
|||
SCAN_INTERVAL = timedelta(seconds=DEFAULT_SCAN_INTERVAL)
|
||||
|
||||
|
||||
async def _async_migrate_data(
|
||||
hass: HomeAssistant, entry: ConfigEntry, protect: ProtectApiClient
|
||||
) -> None:
|
||||
|
||||
registry = er.async_get(hass)
|
||||
to_migrate = []
|
||||
for entity in er.async_entries_for_config_entry(registry, entry.entry_id):
|
||||
if entity.domain == Platform.BUTTON and "_" not in entity.unique_id:
|
||||
_LOGGER.debug("Button %s needs migration", entity.entity_id)
|
||||
to_migrate.append(entity)
|
||||
|
||||
if len(to_migrate) == 0:
|
||||
_LOGGER.debug("No entities need migration")
|
||||
return
|
||||
|
||||
_LOGGER.info("Migrating %s reboot button entities ", len(to_migrate))
|
||||
bootstrap = await protect.get_bootstrap()
|
||||
count = 0
|
||||
for button in to_migrate:
|
||||
device = None
|
||||
for model in DEVICES_THAT_ADOPT:
|
||||
attr = f"{model.value}s"
|
||||
device = getattr(bootstrap, attr).get(button.unique_id)
|
||||
if device is not None:
|
||||
break
|
||||
|
||||
if device is None:
|
||||
continue
|
||||
|
||||
new_unique_id = f"{device.id}_reboot"
|
||||
_LOGGER.debug(
|
||||
"Migrating entity %s (old unique_id: %s, new unique_id: %s)",
|
||||
button.entity_id,
|
||||
button.unique_id,
|
||||
new_unique_id,
|
||||
)
|
||||
try:
|
||||
registry.async_update_entity(button.entity_id, new_unique_id=new_unique_id)
|
||||
except ValueError:
|
||||
_LOGGER.warning(
|
||||
"Could not migrate entity %s (old unique_id: %s, new unique_id: %s)",
|
||||
button.entity_id,
|
||||
button.unique_id,
|
||||
new_unique_id,
|
||||
)
|
||||
else:
|
||||
count += 1
|
||||
|
||||
if count < len(to_migrate):
|
||||
_LOGGER.warning("Failed to migate %s reboot buttons", len(to_migrate) - count)
|
||||
else:
|
||||
_LOGGER.info("Migrated %s reboot button entities", count)
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up the UniFi Protect config entries."""
|
||||
|
||||
|
@ -75,6 +132,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
)
|
||||
return False
|
||||
|
||||
await _async_migrate_data(hass, entry, protect)
|
||||
if entry.unique_id is None:
|
||||
hass.config_entries.async_update_entry(entry, unique_id=nvr_info.mac)
|
||||
|
||||
|
|
|
@ -1,20 +1,47 @@
|
|||
"""Support for Ubiquiti's UniFi Protect NVR."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Final
|
||||
|
||||
from pyunifiprotect.data.base import ProtectAdoptableDeviceModel
|
||||
|
||||
from homeassistant.components.button import ButtonDeviceClass, ButtonEntity
|
||||
from homeassistant.components.button import (
|
||||
ButtonDeviceClass,
|
||||
ButtonEntity,
|
||||
ButtonEntityDescription,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DEVICES_THAT_ADOPT, DOMAIN
|
||||
from .const import DOMAIN
|
||||
from .data import ProtectData
|
||||
from .entity import ProtectDeviceEntity
|
||||
from .entity import ProtectDeviceEntity, async_all_device_entities
|
||||
from .models import ProtectSetableKeysMixin, T
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class ProtectButtonEntityDescription(
|
||||
ProtectSetableKeysMixin[T], ButtonEntityDescription
|
||||
):
|
||||
"""Describes UniFi Protect Button entity."""
|
||||
|
||||
ufp_press: str | None = None
|
||||
|
||||
|
||||
DEVICE_CLASS_CHIME_BUTTON: Final = "unifiprotect__chime_button"
|
||||
|
||||
|
||||
ALL_DEVICE_BUTTONS: tuple[ProtectButtonEntityDescription, ...] = (
|
||||
ProtectButtonEntityDescription(
|
||||
key="reboot",
|
||||
entity_registry_enabled_default=False,
|
||||
device_class=ButtonDeviceClass.RESTART,
|
||||
name="Reboot Device",
|
||||
ufp_press="reboot",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
|
@ -25,34 +52,30 @@ async def async_setup_entry(
|
|||
"""Discover devices on a UniFi Protect NVR."""
|
||||
data: ProtectData = hass.data[DOMAIN][entry.entry_id]
|
||||
|
||||
async_add_entities(
|
||||
[
|
||||
ProtectButton(
|
||||
data,
|
||||
device,
|
||||
)
|
||||
for device in data.get_by_types(DEVICES_THAT_ADOPT)
|
||||
]
|
||||
entities: list[ProtectDeviceEntity] = async_all_device_entities(
|
||||
data, ProtectButton, all_descs=ALL_DEVICE_BUTTONS
|
||||
)
|
||||
|
||||
async_add_entities(entities)
|
||||
|
||||
|
||||
class ProtectButton(ProtectDeviceEntity, ButtonEntity):
|
||||
"""A Ubiquiti UniFi Protect Reboot button."""
|
||||
|
||||
_attr_entity_registry_enabled_default = False
|
||||
_attr_device_class = ButtonDeviceClass.RESTART
|
||||
entity_description: ProtectButtonEntityDescription
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: ProtectData,
|
||||
device: ProtectAdoptableDeviceModel,
|
||||
description: ProtectButtonEntityDescription,
|
||||
) -> None:
|
||||
"""Initialize an UniFi camera."""
|
||||
super().__init__(data, device)
|
||||
self._attr_name = f"{self.device.name} Reboot Device"
|
||||
super().__init__(data, device, description)
|
||||
self._attr_name = f"{self.device.name} {self.entity_description.name}"
|
||||
|
||||
async def async_press(self) -> None:
|
||||
"""Press the button."""
|
||||
|
||||
_LOGGER.debug("Rebooting %s with id %s", self.device.model, self.device.id)
|
||||
await self.device.reboot()
|
||||
if self.entity_description.ufp_press is not None:
|
||||
await getattr(self.device, self.entity_description.ufp_press)()
|
||||
|
|
|
@ -49,7 +49,7 @@ async def test_button(
|
|||
|
||||
mock_entry.api.reboot_device = AsyncMock()
|
||||
|
||||
unique_id = f"{camera[0].id}"
|
||||
unique_id = f"{camera[0].id}_reboot"
|
||||
entity_id = camera[1]
|
||||
|
||||
entity_registry = er.async_get(hass)
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
"""Test the UniFi Protect setup flow."""
|
||||
# pylint: disable=protected-access
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from pyunifiprotect import NotAuthorized, NvrError
|
||||
from pyunifiprotect.data import NVR
|
||||
from pyunifiprotect.data import NVR, Light
|
||||
|
||||
from homeassistant.components.unifiprotect.const import CONF_DISABLE_RTSP, DOMAIN
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
|
||||
from . import _patch_discovery
|
||||
from .conftest import MockBootstrap, MockEntityFixture
|
||||
|
@ -175,3 +178,103 @@ async def test_setup_starts_discovery(
|
|||
assert mock_entry.entry.state == ConfigEntryState.LOADED
|
||||
await hass.async_block_till_done()
|
||||
assert len(hass.config_entries.flow.async_progress_by_handler(DOMAIN)) == 1
|
||||
|
||||
|
||||
async def test_migrate_reboot_button(
|
||||
hass: HomeAssistant, mock_entry: MockEntityFixture, mock_light: Light
|
||||
):
|
||||
"""Test migrating unique ID of reboot button."""
|
||||
|
||||
light1 = mock_light.copy()
|
||||
light1._api = mock_entry.api
|
||||
light1.name = "Test Light 1"
|
||||
light1.id = "lightid1"
|
||||
|
||||
light2 = mock_light.copy()
|
||||
light2._api = mock_entry.api
|
||||
light2.name = "Test Light 2"
|
||||
light2.id = "lightid2"
|
||||
mock_entry.api.bootstrap.lights = {
|
||||
light1.id: light1,
|
||||
light2.id: light2,
|
||||
}
|
||||
mock_entry.api.get_bootstrap = AsyncMock(return_value=mock_entry.api.bootstrap)
|
||||
|
||||
registry = er.async_get(hass)
|
||||
registry.async_get_or_create(
|
||||
Platform.BUTTON, Platform.BUTTON, light1.id, config_entry=mock_entry.entry
|
||||
)
|
||||
registry.async_get_or_create(
|
||||
Platform.BUTTON,
|
||||
Platform.BUTTON,
|
||||
f"{light2.id}_reboot",
|
||||
config_entry=mock_entry.entry,
|
||||
)
|
||||
|
||||
await hass.config_entries.async_setup(mock_entry.entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_entry.entry.state == ConfigEntryState.LOADED
|
||||
assert mock_entry.api.update.called
|
||||
assert mock_entry.entry.unique_id == mock_entry.api.bootstrap.nvr.mac
|
||||
|
||||
assert registry.async_get(f"{Platform.BUTTON}.test_light_1_reboot_device_2") is None
|
||||
light = registry.async_get(f"{Platform.BUTTON}.test_light_1_reboot_device")
|
||||
assert light is not None
|
||||
assert light.unique_id == f"{light1.id}_reboot"
|
||||
|
||||
assert registry.async_get(f"{Platform.BUTTON}.test_light_2_reboot_device_2") is None
|
||||
light = registry.async_get(f"{Platform.BUTTON}.test_light_2_reboot_device")
|
||||
assert light is not None
|
||||
assert light.unique_id == f"{light2.id}_reboot"
|
||||
|
||||
buttons = []
|
||||
for entity in er.async_entries_for_config_entry(
|
||||
registry, mock_entry.entry.entry_id
|
||||
):
|
||||
if entity.platform == Platform.BUTTON.value:
|
||||
buttons.append(entity)
|
||||
assert len(buttons) == 2
|
||||
|
||||
|
||||
async def test_migrate_reboot_button_fail(
|
||||
hass: HomeAssistant, mock_entry: MockEntityFixture, mock_light: Light
|
||||
):
|
||||
"""Test migrating unique ID of reboot button."""
|
||||
|
||||
light1 = mock_light.copy()
|
||||
light1._api = mock_entry.api
|
||||
light1.name = "Test Light 1"
|
||||
light1.id = "lightid1"
|
||||
|
||||
mock_entry.api.bootstrap.lights = {
|
||||
light1.id: light1,
|
||||
}
|
||||
mock_entry.api.get_bootstrap = AsyncMock(return_value=mock_entry.api.bootstrap)
|
||||
|
||||
registry = er.async_get(hass)
|
||||
registry.async_get_or_create(
|
||||
Platform.BUTTON,
|
||||
Platform.BUTTON,
|
||||
light1.id,
|
||||
config_entry=mock_entry.entry,
|
||||
suggested_object_id=light1.name,
|
||||
)
|
||||
registry.async_get_or_create(
|
||||
Platform.BUTTON,
|
||||
Platform.BUTTON,
|
||||
f"{light1.id}_reboot",
|
||||
config_entry=mock_entry.entry,
|
||||
suggested_object_id=light1.name,
|
||||
)
|
||||
|
||||
await hass.config_entries.async_setup(mock_entry.entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_entry.entry.state == ConfigEntryState.LOADED
|
||||
assert mock_entry.api.update.called
|
||||
assert mock_entry.entry.unique_id == mock_entry.api.bootstrap.nvr.mac
|
||||
|
||||
light = registry.async_get(f"{Platform.BUTTON}.test_light_1")
|
||||
assert light is not None
|
||||
assert light.unique_id == f"{light1.id}"
|
||||
|
|
Loading…
Reference in New Issue