Add reboot button to Shelly devices (#60417)

pull/60557/head
Michael 2021-11-29 19:49:49 +01:00 committed by GitHub
parent 814a742518
commit 83acfda757
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 131 additions and 55 deletions

View File

@ -1,9 +1,11 @@
"""Button for Shelly."""
from __future__ import annotations
from typing import cast
from collections.abc import Callable
from dataclasses import dataclass
from typing import Final, cast
from homeassistant.components.button import ButtonEntity
from homeassistant.components.button import ButtonEntity, ButtonEntityDescription
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ENTITY_CATEGORY_CONFIG
from homeassistant.core import HomeAssistant
@ -17,6 +19,44 @@ from .const import BLOCK, DATA_CONFIG_ENTRY, DOMAIN, RPC
from .utils import get_block_device_name, get_device_entry_gen, get_rpc_device_name
@dataclass
class ShellyButtonDescriptionMixin:
"""Mixin to describe a Button entity."""
press_action: Callable
@dataclass
class ShellyButtonDescription(ButtonEntityDescription, ShellyButtonDescriptionMixin):
"""Class to describe a Button entity."""
BUTTONS: Final = [
ShellyButtonDescription(
key="ota_update",
name="OTA Update",
icon="mdi:package-up",
entity_category=ENTITY_CATEGORY_CONFIG,
press_action=lambda wrapper: wrapper.async_trigger_ota_update(),
),
ShellyButtonDescription(
key="ota_update_beta",
name="OTA Update Beta",
icon="mdi:flask-outline",
entity_registry_enabled_default=False,
entity_category=ENTITY_CATEGORY_CONFIG,
press_action=lambda wrapper: wrapper.async_trigger_ota_update(beta=True),
),
ShellyButtonDescription(
key="reboot",
name="Reboot",
icon="mdi:restart",
entity_category=ENTITY_CATEGORY_CONFIG,
press_action=lambda wrapper: wrapper.device.trigger_reboot(),
),
]
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
@ -36,66 +76,34 @@ async def async_setup_entry(
wrapper = cast(BlockDeviceWrapper, block_wrapper)
if wrapper is not None:
async_add_entities(
[
ShellyOtaUpdateStableButton(wrapper, config_entry),
ShellyOtaUpdateBetaButton(wrapper, config_entry),
]
)
async_add_entities([ShellyButton(wrapper, button) for button in BUTTONS])
class ShellyOtaUpdateBaseButton(ButtonEntity):
class ShellyButton(ButtonEntity):
"""Defines a Shelly OTA update base button."""
_attr_entity_category = ENTITY_CATEGORY_CONFIG
entity_description: ShellyButtonDescription
def __init__(
self,
wrapper: RpcDeviceWrapper | BlockDeviceWrapper,
entry: ConfigEntry,
name: str,
beta_channel: bool,
icon: str,
description: ShellyButtonDescription,
) -> None:
"""Initialize Shelly OTA update button."""
self._attr_device_info = DeviceInfo(
connections={(CONNECTION_NETWORK_MAC, wrapper.mac)}
)
self.entity_description = description
self.wrapper = wrapper
if isinstance(wrapper, RpcDeviceWrapper):
device_name = get_rpc_device_name(wrapper.device)
else:
device_name = get_block_device_name(wrapper.device)
self._attr_name = f"{device_name} {name}"
self._attr_name = f"{device_name} {description.name}"
self._attr_unique_id = slugify(self._attr_name)
self._attr_icon = icon
self.beta_channel = beta_channel
self.entry = entry
self.wrapper = wrapper
self._attr_device_info = DeviceInfo(
connections={(CONNECTION_NETWORK_MAC, wrapper.mac)}
)
async def async_press(self) -> None:
"""Triggers the OTA update service."""
await self.wrapper.async_trigger_ota_update(beta=self.beta_channel)
class ShellyOtaUpdateStableButton(ShellyOtaUpdateBaseButton):
"""Defines a Shelly OTA update stable channel button."""
def __init__(
self, wrapper: RpcDeviceWrapper | BlockDeviceWrapper, entry: ConfigEntry
) -> None:
"""Initialize Shelly OTA update button."""
super().__init__(wrapper, entry, "OTA Update", False, "mdi:package-up")
class ShellyOtaUpdateBetaButton(ShellyOtaUpdateBaseButton):
"""Defines a Shelly OTA update beta channel button."""
def __init__(
self, wrapper: RpcDeviceWrapper | BlockDeviceWrapper, entry: ConfigEntry
) -> None:
"""Initialize Shelly OTA update button."""
super().__init__(wrapper, entry, "OTA Update Beta", True, "mdi:flask-outline")
self._attr_entity_registry_enabled_default = False
await self.entity_description.press_action(self.wrapper)

View File

@ -138,6 +138,7 @@ async def coap_wrapper(hass):
firmware_version="some fw string",
update=AsyncMock(),
trigger_ota_update=AsyncMock(),
trigger_reboot=AsyncMock(),
initialized=True,
)
@ -173,6 +174,7 @@ async def rpc_wrapper(hass):
firmware_version="some fw string",
update=AsyncMock(),
trigger_ota_update=AsyncMock(),
trigger_reboot=AsyncMock(),
initialized=True,
shutdown=AsyncMock(),
)

View File

@ -1,6 +1,7 @@
"""Tests for Shelly button platform."""
from homeassistant.components.button import DOMAIN as BUTTON_DOMAIN
from homeassistant.components.button.const import SERVICE_PRESS
from homeassistant.components.shelly.const import DOMAIN
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_registry import async_get
@ -10,6 +11,14 @@ async def test_block_button(hass: HomeAssistant, coap_wrapper):
"""Test block device OTA button."""
assert coap_wrapper
entity_registry = async_get(hass)
entity_registry.async_get_or_create(
BUTTON_DOMAIN,
DOMAIN,
"test_name_ota_update_beta",
suggested_object_id="test_name_ota_update_beta",
disabled_by=None,
)
hass.async_create_task(
hass.config_entries.async_forward_entry_setup(coap_wrapper.entry, BUTTON_DOMAIN)
)
@ -27,21 +36,54 @@ async def test_block_button(hass: HomeAssistant, coap_wrapper):
blocking=True,
)
await hass.async_block_till_done()
coap_wrapper.device.trigger_ota_update.assert_called_once_with(beta=False)
assert coap_wrapper.device.trigger_ota_update.call_count == 1
coap_wrapper.device.trigger_ota_update.assert_called_with(beta=False)
# beta channel button
entity_registry = async_get(hass)
entry = entity_registry.async_get("button.test_name_ota_update_beta")
state = hass.states.get("button.test_name_ota_update_beta")
assert entry
assert state is None
assert state
assert state.state == STATE_UNKNOWN
await hass.services.async_call(
BUTTON_DOMAIN,
SERVICE_PRESS,
{ATTR_ENTITY_ID: "button.test_name_ota_update_beta"},
blocking=True,
)
await hass.async_block_till_done()
assert coap_wrapper.device.trigger_ota_update.call_count == 2
coap_wrapper.device.trigger_ota_update.assert_called_with(beta=True)
# reboot button
state = hass.states.get("button.test_name_reboot")
assert state
assert state.state == STATE_UNKNOWN
await hass.services.async_call(
BUTTON_DOMAIN,
SERVICE_PRESS,
{ATTR_ENTITY_ID: "button.test_name_reboot"},
blocking=True,
)
await hass.async_block_till_done()
assert coap_wrapper.device.trigger_reboot.call_count == 1
async def test_rpc_button(hass: HomeAssistant, rpc_wrapper):
"""Test rpc device OTA button."""
assert rpc_wrapper
entity_registry = async_get(hass)
entity_registry.async_get_or_create(
BUTTON_DOMAIN,
DOMAIN,
"test_name_ota_update_beta",
suggested_object_id="test_name_ota_update_beta",
disabled_by=None,
)
hass.async_create_task(
hass.config_entries.async_forward_entry_setup(rpc_wrapper.entry, BUTTON_DOMAIN)
)
@ -59,12 +101,36 @@ async def test_rpc_button(hass: HomeAssistant, rpc_wrapper):
blocking=True,
)
await hass.async_block_till_done()
rpc_wrapper.device.trigger_ota_update.assert_called_once_with(beta=False)
assert rpc_wrapper.device.trigger_ota_update.call_count == 1
rpc_wrapper.device.trigger_ota_update.assert_called_with(beta=False)
# beta channel button
entity_registry = async_get(hass)
entry = entity_registry.async_get("button.test_name_ota_update_beta")
state = hass.states.get("button.test_name_ota_update_beta")
assert entry
assert state is None
assert state
assert state.state == STATE_UNKNOWN
await hass.services.async_call(
BUTTON_DOMAIN,
SERVICE_PRESS,
{ATTR_ENTITY_ID: "button.test_name_ota_update_beta"},
blocking=True,
)
await hass.async_block_till_done()
assert rpc_wrapper.device.trigger_ota_update.call_count == 2
rpc_wrapper.device.trigger_ota_update.assert_called_with(beta=True)
# reboot button
state = hass.states.get("button.test_name_reboot")
assert state
assert state.state == STATE_UNKNOWN
await hass.services.async_call(
BUTTON_DOMAIN,
SERVICE_PRESS,
{ATTR_ENTITY_ID: "button.test_name_reboot"},
blocking=True,
)
await hass.async_block_till_done()
assert rpc_wrapper.device.trigger_reboot.call_count == 1