Make tradfri base_class.py strictly typed (#56341)

* Make base_class.py strictly typed.
pull/56375/head
jan iversen 2021-09-20 14:33:50 +02:00 committed by GitHub
parent 5c717cbb1d
commit 93e9a67d7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 31 additions and 15 deletions

View File

@ -6,25 +6,31 @@ import logging
from typing import Any, Callable from typing import Any, Callable
from pytradfri.command import Command from pytradfri.command import Command
from pytradfri.device import Device
from pytradfri.device.blind import Blind from pytradfri.device.blind import Blind
from pytradfri.device.blind_control import BlindControl
from pytradfri.device.light import Light from pytradfri.device.light import Light
from pytradfri.device.light_control import LightControl
from pytradfri.device.signal_repeater_control import SignalRepeaterControl
from pytradfri.device.socket import Socket from pytradfri.device.socket import Socket
from pytradfri.device.socket_control import SocketControl from pytradfri.device.socket_control import SocketControl
from pytradfri.error import PytradfriError from pytradfri.error import PytradfriError
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import DeviceInfo, Entity
from .const import DOMAIN from .const import DOMAIN
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def handle_error(func): def handle_error(
func: Callable[[Command | list[Command]], Any]
) -> Callable[[str], Any]:
"""Handle tradfri api call error.""" """Handle tradfri api call error."""
@wraps(func) @wraps(func)
async def wrapper(command): async def wrapper(command: Command | list[Command]) -> None:
"""Decorate api call.""" """Decorate api call."""
try: try:
await func(command) await func(command)
@ -43,18 +49,23 @@ class TradfriBaseClass(Entity):
_attr_should_poll = False _attr_should_poll = False
def __init__( def __init__(
self, device: Command, api: Callable[[str], Any], gateway_id: str self,
device: Device,
api: Callable[[Command | list[Command]], Any],
gateway_id: str,
) -> None: ) -> None:
"""Initialize a device.""" """Initialize a device."""
self._api = handle_error(api) self._api = handle_error(api)
self._device: Command | None = None self._device: Device = device
self._device_control: SocketControl | None = None self._device_control: BlindControl | LightControl | SocketControl | SignalRepeaterControl | None = (
None
)
self._device_data: Socket | Light | Blind | None = None self._device_data: Socket | Light | Blind | None = None
self._gateway_id = gateway_id self._gateway_id = gateway_id
self._refresh(device) self._refresh(device)
@callback @callback
def _async_start_observe(self, exc=None): def _async_start_observe(self, exc: Exception | None = None) -> None:
"""Start observation of device.""" """Start observation of device."""
if exc: if exc:
self.async_write_ha_state() self.async_write_ha_state()
@ -71,17 +82,17 @@ class TradfriBaseClass(Entity):
_LOGGER.warning("Observation failed, trying again", exc_info=err) _LOGGER.warning("Observation failed, trying again", exc_info=err)
self._async_start_observe() self._async_start_observe()
async def async_added_to_hass(self): async def async_added_to_hass(self) -> None:
"""Start thread when added to hass.""" """Start thread when added to hass."""
self._async_start_observe() self._async_start_observe()
@callback @callback
def _observe_update(self, device): def _observe_update(self, device: Device) -> None:
"""Receive new state data for this device.""" """Receive new state data for this device."""
self._refresh(device) self._refresh(device)
self.async_write_ha_state() self.async_write_ha_state()
def _refresh(self, device: Command) -> None: def _refresh(self, device: Device) -> None:
"""Refresh the device data.""" """Refresh the device data."""
self._device = device self._device = device
self._attr_name = device.name self._attr_name = device.name
@ -94,10 +105,9 @@ class TradfriBaseDevice(TradfriBaseClass):
""" """
@property @property
def device_info(self): def device_info(self) -> DeviceInfo:
"""Return the device info.""" """Return the device info."""
info = self._device.device_info info = self._device.device_info
return { return {
"identifiers": {(DOMAIN, self._device.id)}, "identifiers": {(DOMAIN, self._device.id)},
"manufacturer": info.manufacturer, "manufacturer": info.manufacturer,
@ -107,7 +117,7 @@ class TradfriBaseDevice(TradfriBaseClass):
"via_device": (DOMAIN, self._gateway_id), "via_device": (DOMAIN, self._gateway_id),
} }
def _refresh(self, device: Command) -> None: def _refresh(self, device: Device) -> None:
"""Refresh the device data.""" """Refresh the device data."""
super()._refresh(device) super()._refresh(device)
self._attr_available = device.reachable self._attr_available = device.reachable

View File

@ -45,7 +45,10 @@ class TradfriSensor(TradfriBaseDevice, SensorEntity):
_attr_native_unit_of_measurement = PERCENTAGE _attr_native_unit_of_measurement = PERCENTAGE
def __init__( def __init__(
self, device: Command, api: Callable[[str], Any], gateway_id: str self,
device: Command,
api: Callable[[Command | list[Command]], Any],
gateway_id: str,
) -> None: ) -> None:
"""Initialize the device.""" """Initialize the device."""
super().__init__(device, api, gateway_id) super().__init__(device, api, gateway_id)

View File

@ -36,7 +36,10 @@ class TradfriSwitch(TradfriBaseDevice, SwitchEntity):
"""The platform class required by Home Assistant.""" """The platform class required by Home Assistant."""
def __init__( def __init__(
self, device: Command, api: Callable[[str], Any], gateway_id: str self,
device: Command,
api: Callable[[Command | list[Command]], Any],
gateway_id: str,
) -> None: ) -> None:
"""Initialize a switch.""" """Initialize a switch."""
super().__init__(device, api, gateway_id) super().__init__(device, api, gateway_id)