core/homeassistant/components/tradfri/base_class.py

131 lines
4.2 KiB
Python
Raw Normal View History

"""Base class for IKEA TRADFRI."""
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
import logging
from typing import Any
from pytradfri.command import Command
from pytradfri.device import Device
from pytradfri.device.air_purifier import AirPurifier
from pytradfri.device.air_purifier_control import AirPurifierControl
from pytradfri.device.blind import Blind
from pytradfri.device.blind_control import BlindControl
from pytradfri.device.light import Light
from pytradfri.device.light_control import LightControl
from pytradfri.device.signal_repeater_control import SignalRepeaterControl
from pytradfri.device.socket import Socket
from pytradfri.device.socket_control import SocketControl
from pytradfri.error import PytradfriError
from homeassistant.core import callback
from homeassistant.helpers.entity import DeviceInfo, Entity
from .const import DOMAIN
_LOGGER = logging.getLogger(__name__)


def handle_error(
    func: Callable[[Command | list[Command]], Any]
) -> Callable[[Command | list[Command]], Any]:
    """Handle tradfri api call error.

    Wrap an API callable so that ``PytradfriError`` raised while awaiting it
    is logged instead of propagated.

    :param func: callable taking a command (or list of commands); its return
        value is awaited.
    :return: a coroutine function accepting the same argument. Awaiting it
        always yields ``None``; the wrapped call's result is discarded.
    """

    @wraps(func)
    async def wrapper(command: Command | list[Command]) -> None:
        """Decorate api call."""
        try:
            await func(command)
        except PytradfriError as err:
            # Best effort: a failed command is reported in the log rather
            # than raised to the caller.
            _LOGGER.error("Unable to execute command %s: %s", command, err)

    return wrapper
class TradfriBaseClass(Entity):
    """Base class for IKEA TRADFRI.

    All devices and groups should ultimately inherit from this class.

    The entity does not poll (``_attr_should_poll = False``). Its state is
    refreshed by an observe command started in ``async_added_to_hass``.
    """

    _attr_should_poll = False

    def __init__(
        self,
        device: Device,
        api: Callable[[Command | list[Command]], Any],
        gateway_id: str,
    ) -> None:
        """Initialize a device.

        :param device: pytradfri ``Device`` represented by this entity.
        :param api: callable that executes a command (or list of commands);
            its return value is awaited.
        :param gateway_id: id of the gateway; used as ``via_device`` by
            device entities.
        """
        # ``self._api`` logs and swallows PytradfriError (see handle_error).
        self._api = handle_error(api)
        # Observation must see PytradfriError so it can mark the entity
        # unavailable and retry, so the unwrapped callable is kept too.
        self._observe_api = api
        self._attr_name = device.name
        self._attr_available = device.reachable
        self._device: Device = device
        self._device_control: (
            BlindControl
            | LightControl
            | SocketControl
            | SignalRepeaterControl
            | AirPurifierControl
            | None
        ) = None
        self._device_data: Socket | Light | Blind | AirPurifier | None = None
        self._gateway_id = gateway_id

    async def _async_run_observe(self, cmd: Command) -> None:
        """Run observe in a coroutine.

        If executing the observe command raises ``PytradfriError``, the
        entity is marked unavailable, a warning is logged and observation is
        restarted.
        """
        try:
            # Deliberately not ``self._api``: that wrapper already catches
            # PytradfriError, which would make the handler below unreachable.
            await self._observe_api(cmd)
        except PytradfriError as err:
            self._attr_available = False
            self.async_write_ha_state()
            _LOGGER.warning("Observation failed, trying again", exc_info=err)
            # NOTE(review): the retry is immediate and unbounded; consider a
            # backoff if the gateway can fail fast repeatedly.
            self._async_start_observe()

    @callback
    def _async_start_observe(self, exc: Exception | None = None) -> None:
        """Start observation of device.

        This is also registered as the observe ``err_callback``. When it is
        called with ``exc``, the entity is first marked unavailable and the
        error is logged, and then observation is started again.
        """
        if exc:
            self._attr_available = False
            self.async_write_ha_state()
            _LOGGER.warning("Observation failed for %s", self._attr_name, exc_info=exc)

        # duration=0 is passed straight to pytradfri's observe();
        # presumably "no time limit" — confirm against pytradfri.
        cmd = self._device.observe(
            callback=self._observe_update,
            err_callback=self._async_start_observe,
            duration=0,
        )
        self.hass.async_create_task(self._async_run_observe(cmd))

    async def async_added_to_hass(self) -> None:
        """Start observing the device once the entity is added to hass."""
        self._async_start_observe()

    @callback
    def _observe_update(self, device: Device) -> None:
        """Receive new state data for this device."""
        self._refresh(device)

    def _refresh(self, device: Device, write_ha: bool = True) -> None:
        """Refresh the device data.

        :param device: updated pytradfri device object.
        :param write_ha: when true, push the new state to Home Assistant.
        """
        self._device = device
        self._attr_name = device.name
        self._attr_available = device.reachable
        if write_ha:
            self.async_write_ha_state()
class TradfriBaseDevice(TradfriBaseClass):
    """Common base for entities that represent a single TRADFRI device.

    Every device entity (as opposed to a group) should derive from this
    class, so that it reports device metadata linked to its gateway.
    """

    @property
    def device_info(self) -> DeviceInfo:
        """Build the device description, linked to the gateway via via_device."""
        tradfri_device = self._device
        meta = tradfri_device.device_info
        return DeviceInfo(
            identifiers={(DOMAIN, tradfri_device.id)},
            manufacturer=meta.manufacturer,
            model=meta.model_number,
            name=self._attr_name,
            sw_version=meta.firmware_version,
            via_device=(DOMAIN, self._gateway_id),
        )