Make tradfri base_class.py strictly typed (#56341)
* Make base_class.py strictly typed.pull/56375/head
parent
5c717cbb1d
commit
93e9a67d7d
|
@ -6,25 +6,31 @@ import logging
|
|||
from typing import Any, Callable
|
||||
|
||||
from pytradfri.command import Command
|
||||
from pytradfri.device import Device
|
||||
from pytradfri.device.blind import Blind
|
||||
from pytradfri.device.blind_control import BlindControl
|
||||
from pytradfri.device.light import Light
|
||||
from pytradfri.device.light_control import LightControl
|
||||
from pytradfri.device.signal_repeater_control import SignalRepeaterControl
|
||||
from pytradfri.device.socket import Socket
|
||||
from pytradfri.device.socket_control import SocketControl
|
||||
from pytradfri.error import PytradfriError
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.entity import DeviceInfo, Entity
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def handle_error(func):
|
||||
def handle_error(
|
||||
func: Callable[[Command | list[Command]], Any]
|
||||
) -> Callable[[str], Any]:
|
||||
"""Handle tradfri api call error."""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(command):
|
||||
async def wrapper(command: Command | list[Command]) -> None:
|
||||
"""Decorate api call."""
|
||||
try:
|
||||
await func(command)
|
||||
|
@ -43,18 +49,23 @@ class TradfriBaseClass(Entity):
|
|||
_attr_should_poll = False
|
||||
|
||||
def __init__(
|
||||
self, device: Command, api: Callable[[str], Any], gateway_id: str
|
||||
self,
|
||||
device: Device,
|
||||
api: Callable[[Command | list[Command]], Any],
|
||||
gateway_id: str,
|
||||
) -> None:
|
||||
"""Initialize a device."""
|
||||
self._api = handle_error(api)
|
||||
self._device: Command | None = None
|
||||
self._device_control: SocketControl | None = None
|
||||
self._device: Device = device
|
||||
self._device_control: BlindControl | LightControl | SocketControl | SignalRepeaterControl | None = (
|
||||
None
|
||||
)
|
||||
self._device_data: Socket | Light | Blind | None = None
|
||||
self._gateway_id = gateway_id
|
||||
self._refresh(device)
|
||||
|
||||
@callback
|
||||
def _async_start_observe(self, exc=None):
|
||||
def _async_start_observe(self, exc: Exception | None = None) -> None:
|
||||
"""Start observation of device."""
|
||||
if exc:
|
||||
self.async_write_ha_state()
|
||||
|
@ -71,17 +82,17 @@ class TradfriBaseClass(Entity):
|
|||
_LOGGER.warning("Observation failed, trying again", exc_info=err)
|
||||
self._async_start_observe()
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Start thread when added to hass."""
|
||||
self._async_start_observe()
|
||||
|
||||
@callback
|
||||
def _observe_update(self, device):
|
||||
def _observe_update(self, device: Device) -> None:
|
||||
"""Receive new state data for this device."""
|
||||
self._refresh(device)
|
||||
self.async_write_ha_state()
|
||||
|
||||
def _refresh(self, device: Command) -> None:
|
||||
def _refresh(self, device: Device) -> None:
|
||||
"""Refresh the device data."""
|
||||
self._device = device
|
||||
self._attr_name = device.name
|
||||
|
@ -94,10 +105,9 @@ class TradfriBaseDevice(TradfriBaseClass):
|
|||
"""
|
||||
|
||||
@property
|
||||
def device_info(self):
|
||||
def device_info(self) -> DeviceInfo:
|
||||
"""Return the device info."""
|
||||
info = self._device.device_info
|
||||
|
||||
return {
|
||||
"identifiers": {(DOMAIN, self._device.id)},
|
||||
"manufacturer": info.manufacturer,
|
||||
|
@ -107,7 +117,7 @@ class TradfriBaseDevice(TradfriBaseClass):
|
|||
"via_device": (DOMAIN, self._gateway_id),
|
||||
}
|
||||
|
||||
def _refresh(self, device: Command) -> None:
|
||||
def _refresh(self, device: Device) -> None:
|
||||
"""Refresh the device data."""
|
||||
super()._refresh(device)
|
||||
self._attr_available = device.reachable
|
||||
|
|
|
@ -45,7 +45,10 @@ class TradfriSensor(TradfriBaseDevice, SensorEntity):
|
|||
_attr_native_unit_of_measurement = PERCENTAGE
|
||||
|
||||
def __init__(
|
||||
self, device: Command, api: Callable[[str], Any], gateway_id: str
|
||||
self,
|
||||
device: Command,
|
||||
api: Callable[[Command | list[Command]], Any],
|
||||
gateway_id: str,
|
||||
) -> None:
|
||||
"""Initialize the device."""
|
||||
super().__init__(device, api, gateway_id)
|
||||
|
|
|
@ -36,7 +36,10 @@ class TradfriSwitch(TradfriBaseDevice, SwitchEntity):
|
|||
"""The platform class required by Home Assistant."""
|
||||
|
||||
def __init__(
|
||||
self, device: Command, api: Callable[[str], Any], gateway_id: str
|
||||
self,
|
||||
device: Command,
|
||||
api: Callable[[Command | list[Command]], Any],
|
||||
gateway_id: str,
|
||||
) -> None:
|
||||
"""Initialize a switch."""
|
||||
super().__init__(device, api, gateway_id)
|
||||
|
|
Loading…
Reference in New Issue