From 93e9a67d7d47279b67e0966f9cffefec9dd9f90d Mon Sep 17 00:00:00 2001 From: jan iversen Date: Mon, 20 Sep 2021 14:33:50 +0200 Subject: [PATCH] Make tradfri base_class.py strictly typed (#56341) * Make base_class.py strictly typed. --- .../components/tradfri/base_class.py | 36 ++++++++++++------- homeassistant/components/tradfri/sensor.py | 5 ++- homeassistant/components/tradfri/switch.py | 5 ++- 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/homeassistant/components/tradfri/base_class.py b/homeassistant/components/tradfri/base_class.py index eb1884cfc1b..1e86be6c1a5 100644 --- a/homeassistant/components/tradfri/base_class.py +++ b/homeassistant/components/tradfri/base_class.py @@ -6,25 +6,31 @@ import logging from typing import Any, Callable from pytradfri.command import Command +from pytradfri.device import Device from pytradfri.device.blind import Blind +from pytradfri.device.blind_control import BlindControl from pytradfri.device.light import Light +from pytradfri.device.light_control import LightControl +from pytradfri.device.signal_repeater_control import SignalRepeaterControl from pytradfri.device.socket import Socket from pytradfri.device.socket_control import SocketControl from pytradfri.error import PytradfriError from homeassistant.core import callback -from homeassistant.helpers.entity import Entity +from homeassistant.helpers.entity import DeviceInfo, Entity from .const import DOMAIN _LOGGER = logging.getLogger(__name__) -def handle_error(func): +def handle_error( + func: Callable[[Command | list[Command]], Any] +) -> Callable[[str], Any]: """Handle tradfri api call error.""" @wraps(func) - async def wrapper(command): + async def wrapper(command: Command | list[Command]) -> None: """Decorate api call.""" try: await func(command) @@ -43,18 +49,23 @@ class TradfriBaseClass(Entity): _attr_should_poll = False def __init__( - self, device: Command, api: Callable[[str], Any], gateway_id: str + self, + device: Device, + api: Callable[[Command | list[Command]], Any], + gateway_id: str, ) -> None: """Initialize a device.""" self._api = handle_error(api) - self._device: Command | None = None - self._device_control: SocketControl | None = None + self._device: Device = device + self._device_control: BlindControl | LightControl | SocketControl | SignalRepeaterControl | None = ( + None + ) self._device_data: Socket | Light | Blind | None = None self._gateway_id = gateway_id self._refresh(device) @callback - def _async_start_observe(self, exc=None): + def _async_start_observe(self, exc: Exception | None = None) -> None: """Start observation of device.""" if exc: self.async_write_ha_state() @@ -71,17 +82,17 @@ class TradfriBaseClass(Entity): _LOGGER.warning("Observation failed, trying again", exc_info=err) self._async_start_observe() - async def async_added_to_hass(self): + async def async_added_to_hass(self) -> None: """Start thread when added to hass.""" self._async_start_observe() @callback - def _observe_update(self, device): + def _observe_update(self, device: Device) -> None: """Receive new state data for this device.""" self._refresh(device) self.async_write_ha_state() - def _refresh(self, device: Command) -> None: + def _refresh(self, device: Device) -> None: """Refresh the device data.""" self._device = device self._attr_name = device.name @@ -94,10 +105,9 @@ class TradfriBaseDevice(TradfriBaseClass): """ @property - def device_info(self): + def device_info(self) -> DeviceInfo: """Return the device info.""" info = self._device.device_info - return { "identifiers": {(DOMAIN, self._device.id)}, "manufacturer": info.manufacturer, @@ -107,7 +117,7 @@ class TradfriBaseDevice(TradfriBaseClass): "via_device": (DOMAIN, self._gateway_id), } - def _refresh(self, device: Command) -> None: + def _refresh(self, device: Device) -> None: """Refresh the device data.""" super()._refresh(device) self._attr_available = device.reachable diff --git a/homeassistant/components/tradfri/sensor.py b/homeassistant/components/tradfri/sensor.py index 1e7d771cb39..23b7ecc2fab 100644 --- a/homeassistant/components/tradfri/sensor.py +++ b/homeassistant/components/tradfri/sensor.py @@ -45,7 +45,10 @@ class TradfriSensor(TradfriBaseDevice, SensorEntity): _attr_native_unit_of_measurement = PERCENTAGE def __init__( - self, device: Command, api: Callable[[str], Any], gateway_id: str + self, + device: Command, + api: Callable[[Command | list[Command]], Any], + gateway_id: str, ) -> None: """Initialize the device.""" super().__init__(device, api, gateway_id) diff --git a/homeassistant/components/tradfri/switch.py b/homeassistant/components/tradfri/switch.py index 6dc934814f0..7366bf7a898 100644 --- a/homeassistant/components/tradfri/switch.py +++ b/homeassistant/components/tradfri/switch.py @@ -36,7 +36,10 @@ class TradfriSwitch(TradfriBaseDevice, SwitchEntity): """The platform class required by Home Assistant.""" def __init__( - self, device: Command, api: Callable[[str], Any], gateway_id: str + self, + device: Command, + api: Callable[[Command | list[Command]], Any], + gateway_id: str, ) -> None: """Initialize a switch.""" super().__init__(device, api, gateway_id)