Refactor button code to allow for other button types for UniFi Protect (#71911)

Co-authored-by: J. Nick Koston <nick@koston.org>
pull/71996/head
Christopher Bailey 2022-05-16 23:51:13 -04:00 committed by GitHub
parent 2d1a612976
commit 3de31939d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 205 additions and 21 deletions

View File

@ -17,9 +17,11 @@ from homeassistant.const import (
CONF_USERNAME, CONF_USERNAME,
CONF_VERIFY_SSL, CONF_VERIFY_SSL,
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_STOP,
Platform,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.aiohttp_client import async_create_clientsession from homeassistant.helpers.aiohttp_client import async_create_clientsession
from .const import ( from .const import (
@ -27,6 +29,7 @@ from .const import (
CONF_OVERRIDE_CHOST, CONF_OVERRIDE_CHOST,
DEFAULT_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL,
DEVICES_FOR_SUBSCRIBE, DEVICES_FOR_SUBSCRIBE,
DEVICES_THAT_ADOPT,
DOMAIN, DOMAIN,
MIN_REQUIRED_PROTECT_V, MIN_REQUIRED_PROTECT_V,
OUTDATED_LOG_MESSAGE, OUTDATED_LOG_MESSAGE,
@ -41,6 +44,60 @@ _LOGGER = logging.getLogger(__name__)
SCAN_INTERVAL = timedelta(seconds=DEFAULT_SCAN_INTERVAL) SCAN_INTERVAL = timedelta(seconds=DEFAULT_SCAN_INTERVAL)
async def _async_migrate_data(
hass: HomeAssistant, entry: ConfigEntry, protect: ProtectApiClient
) -> None:
registry = er.async_get(hass)
to_migrate = []
for entity in er.async_entries_for_config_entry(registry, entry.entry_id):
if entity.domain == Platform.BUTTON and "_" not in entity.unique_id:
_LOGGER.debug("Button %s needs migration", entity.entity_id)
to_migrate.append(entity)
if len(to_migrate) == 0:
_LOGGER.debug("No entities need migration")
return
_LOGGER.info("Migrating %s reboot button entities ", len(to_migrate))
bootstrap = await protect.get_bootstrap()
count = 0
for button in to_migrate:
device = None
for model in DEVICES_THAT_ADOPT:
attr = f"{model.value}s"
device = getattr(bootstrap, attr).get(button.unique_id)
if device is not None:
break
if device is None:
continue
new_unique_id = f"{device.id}_reboot"
_LOGGER.debug(
"Migrating entity %s (old unique_id: %s, new unique_id: %s)",
button.entity_id,
button.unique_id,
new_unique_id,
)
try:
registry.async_update_entity(button.entity_id, new_unique_id=new_unique_id)
except ValueError:
_LOGGER.warning(
"Could not migrate entity %s (old unique_id: %s, new unique_id: %s)",
button.entity_id,
button.unique_id,
new_unique_id,
)
else:
count += 1
if count < len(to_migrate):
_LOGGER.warning("Failed to migate %s reboot buttons", len(to_migrate) - count)
else:
_LOGGER.info("Migrated %s reboot button entities", count)
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up the UniFi Protect config entries.""" """Set up the UniFi Protect config entries."""
@ -75,6 +132,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
) )
return False return False
await _async_migrate_data(hass, entry, protect)
if entry.unique_id is None: if entry.unique_id is None:
hass.config_entries.async_update_entry(entry, unique_id=nvr_info.mac) hass.config_entries.async_update_entry(entry, unique_id=nvr_info.mac)

View File

@ -1,20 +1,47 @@
"""Support for Ubiquiti's UniFi Protect NVR.""" """Support for Ubiquiti's UniFi Protect NVR."""
from __future__ import annotations from __future__ import annotations
import logging from dataclasses import dataclass
from typing import Final
from pyunifiprotect.data.base import ProtectAdoptableDeviceModel from pyunifiprotect.data.base import ProtectAdoptableDeviceModel
from homeassistant.components.button import ButtonDeviceClass, ButtonEntity from homeassistant.components.button import (
ButtonDeviceClass,
ButtonEntity,
ButtonEntityDescription,
)
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DEVICES_THAT_ADOPT, DOMAIN from .const import DOMAIN
from .data import ProtectData from .data import ProtectData
from .entity import ProtectDeviceEntity from .entity import ProtectDeviceEntity, async_all_device_entities
from .models import ProtectSetableKeysMixin, T
_LOGGER = logging.getLogger(__name__)
@dataclass
class ProtectButtonEntityDescription(
ProtectSetableKeysMixin[T], ButtonEntityDescription
):
"""Describes UniFi Protect Button entity."""
ufp_press: str | None = None
DEVICE_CLASS_CHIME_BUTTON: Final = "unifiprotect__chime_button"
ALL_DEVICE_BUTTONS: tuple[ProtectButtonEntityDescription, ...] = (
ProtectButtonEntityDescription(
key="reboot",
entity_registry_enabled_default=False,
device_class=ButtonDeviceClass.RESTART,
name="Reboot Device",
ufp_press="reboot",
),
)
async def async_setup_entry( async def async_setup_entry(
@ -25,34 +52,30 @@ async def async_setup_entry(
"""Discover devices on a UniFi Protect NVR.""" """Discover devices on a UniFi Protect NVR."""
data: ProtectData = hass.data[DOMAIN][entry.entry_id] data: ProtectData = hass.data[DOMAIN][entry.entry_id]
async_add_entities( entities: list[ProtectDeviceEntity] = async_all_device_entities(
[ data, ProtectButton, all_descs=ALL_DEVICE_BUTTONS
ProtectButton(
data,
device,
)
for device in data.get_by_types(DEVICES_THAT_ADOPT)
]
) )
async_add_entities(entities)
class ProtectButton(ProtectDeviceEntity, ButtonEntity): class ProtectButton(ProtectDeviceEntity, ButtonEntity):
"""A Ubiquiti UniFi Protect Reboot button.""" """A Ubiquiti UniFi Protect Reboot button."""
_attr_entity_registry_enabled_default = False entity_description: ProtectButtonEntityDescription
_attr_device_class = ButtonDeviceClass.RESTART
def __init__( def __init__(
self, self,
data: ProtectData, data: ProtectData,
device: ProtectAdoptableDeviceModel, device: ProtectAdoptableDeviceModel,
description: ProtectButtonEntityDescription,
) -> None: ) -> None:
"""Initialize an UniFi camera.""" """Initialize an UniFi camera."""
super().__init__(data, device) super().__init__(data, device, description)
self._attr_name = f"{self.device.name} Reboot Device" self._attr_name = f"{self.device.name} {self.entity_description.name}"
async def async_press(self) -> None: async def async_press(self) -> None:
"""Press the button.""" """Press the button."""
_LOGGER.debug("Rebooting %s with id %s", self.device.model, self.device.id) if self.entity_description.ufp_press is not None:
await self.device.reboot() await getattr(self.device, self.entity_description.ufp_press)()

View File

@ -49,7 +49,7 @@ async def test_button(
mock_entry.api.reboot_device = AsyncMock() mock_entry.api.reboot_device = AsyncMock()
unique_id = f"{camera[0].id}" unique_id = f"{camera[0].id}_reboot"
entity_id = camera[1] entity_id = camera[1]
entity_registry = er.async_get(hass) entity_registry = er.async_get(hass)

View File

@ -1,14 +1,17 @@
"""Test the UniFi Protect setup flow.""" """Test the UniFi Protect setup flow."""
# pylint: disable=protected-access
from __future__ import annotations from __future__ import annotations
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
from pyunifiprotect import NotAuthorized, NvrError from pyunifiprotect import NotAuthorized, NvrError
from pyunifiprotect.data import NVR from pyunifiprotect.data import NVR, Light
from homeassistant.components.unifiprotect.const import CONF_DISABLE_RTSP, DOMAIN from homeassistant.components.unifiprotect.const import CONF_DISABLE_RTSP, DOMAIN
from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from . import _patch_discovery from . import _patch_discovery
from .conftest import MockBootstrap, MockEntityFixture from .conftest import MockBootstrap, MockEntityFixture
@ -175,3 +178,103 @@ async def test_setup_starts_discovery(
assert mock_entry.entry.state == ConfigEntryState.LOADED assert mock_entry.entry.state == ConfigEntryState.LOADED
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(hass.config_entries.flow.async_progress_by_handler(DOMAIN)) == 1 assert len(hass.config_entries.flow.async_progress_by_handler(DOMAIN)) == 1
async def test_migrate_reboot_button(
hass: HomeAssistant, mock_entry: MockEntityFixture, mock_light: Light
):
"""Test migrating unique ID of reboot button."""
light1 = mock_light.copy()
light1._api = mock_entry.api
light1.name = "Test Light 1"
light1.id = "lightid1"
light2 = mock_light.copy()
light2._api = mock_entry.api
light2.name = "Test Light 2"
light2.id = "lightid2"
mock_entry.api.bootstrap.lights = {
light1.id: light1,
light2.id: light2,
}
mock_entry.api.get_bootstrap = AsyncMock(return_value=mock_entry.api.bootstrap)
registry = er.async_get(hass)
registry.async_get_or_create(
Platform.BUTTON, Platform.BUTTON, light1.id, config_entry=mock_entry.entry
)
registry.async_get_or_create(
Platform.BUTTON,
Platform.BUTTON,
f"{light2.id}_reboot",
config_entry=mock_entry.entry,
)
await hass.config_entries.async_setup(mock_entry.entry.entry_id)
await hass.async_block_till_done()
assert mock_entry.entry.state == ConfigEntryState.LOADED
assert mock_entry.api.update.called
assert mock_entry.entry.unique_id == mock_entry.api.bootstrap.nvr.mac
assert registry.async_get(f"{Platform.BUTTON}.test_light_1_reboot_device_2") is None
light = registry.async_get(f"{Platform.BUTTON}.test_light_1_reboot_device")
assert light is not None
assert light.unique_id == f"{light1.id}_reboot"
assert registry.async_get(f"{Platform.BUTTON}.test_light_2_reboot_device_2") is None
light = registry.async_get(f"{Platform.BUTTON}.test_light_2_reboot_device")
assert light is not None
assert light.unique_id == f"{light2.id}_reboot"
buttons = []
for entity in er.async_entries_for_config_entry(
registry, mock_entry.entry.entry_id
):
if entity.platform == Platform.BUTTON.value:
buttons.append(entity)
assert len(buttons) == 2
async def test_migrate_reboot_button_fail(
hass: HomeAssistant, mock_entry: MockEntityFixture, mock_light: Light
):
"""Test migrating unique ID of reboot button."""
light1 = mock_light.copy()
light1._api = mock_entry.api
light1.name = "Test Light 1"
light1.id = "lightid1"
mock_entry.api.bootstrap.lights = {
light1.id: light1,
}
mock_entry.api.get_bootstrap = AsyncMock(return_value=mock_entry.api.bootstrap)
registry = er.async_get(hass)
registry.async_get_or_create(
Platform.BUTTON,
Platform.BUTTON,
light1.id,
config_entry=mock_entry.entry,
suggested_object_id=light1.name,
)
registry.async_get_or_create(
Platform.BUTTON,
Platform.BUTTON,
f"{light1.id}_reboot",
config_entry=mock_entry.entry,
suggested_object_id=light1.name,
)
await hass.config_entries.async_setup(mock_entry.entry.entry_id)
await hass.async_block_till_done()
assert mock_entry.entry.state == ConfigEntryState.LOADED
assert mock_entry.api.update.called
assert mock_entry.entry.unique_id == mock_entry.api.bootstrap.nvr.mac
light = registry.async_get(f"{Platform.BUTTON}.test_light_1")
assert light is not None
assert light.unique_id == f"{light1.id}"