Make tradfri base_class.py strictly typed (#56341)
* Make base_class.py strictly typed.pull/56375/head
parent
5c717cbb1d
commit
93e9a67d7d
|
@ -6,25 +6,31 @@ import logging
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
from pytradfri.command import Command
|
from pytradfri.command import Command
|
||||||
|
from pytradfri.device import Device
|
||||||
from pytradfri.device.blind import Blind
|
from pytradfri.device.blind import Blind
|
||||||
|
from pytradfri.device.blind_control import BlindControl
|
||||||
from pytradfri.device.light import Light
|
from pytradfri.device.light import Light
|
||||||
|
from pytradfri.device.light_control import LightControl
|
||||||
|
from pytradfri.device.signal_repeater_control import SignalRepeaterControl
|
||||||
from pytradfri.device.socket import Socket
|
from pytradfri.device.socket import Socket
|
||||||
from pytradfri.device.socket_control import SocketControl
|
from pytradfri.device.socket_control import SocketControl
|
||||||
from pytradfri.error import PytradfriError
|
from pytradfri.error import PytradfriError
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.helpers.entity import Entity
|
from homeassistant.helpers.entity import DeviceInfo, Entity
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def handle_error(func):
|
def handle_error(
|
||||||
|
func: Callable[[Command | list[Command]], Any]
|
||||||
|
) -> Callable[[str], Any]:
|
||||||
"""Handle tradfri api call error."""
|
"""Handle tradfri api call error."""
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
async def wrapper(command):
|
async def wrapper(command: Command | list[Command]) -> None:
|
||||||
"""Decorate api call."""
|
"""Decorate api call."""
|
||||||
try:
|
try:
|
||||||
await func(command)
|
await func(command)
|
||||||
|
@ -43,18 +49,23 @@ class TradfriBaseClass(Entity):
|
||||||
_attr_should_poll = False
|
_attr_should_poll = False
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, device: Command, api: Callable[[str], Any], gateway_id: str
|
self,
|
||||||
|
device: Device,
|
||||||
|
api: Callable[[Command | list[Command]], Any],
|
||||||
|
gateway_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a device."""
|
"""Initialize a device."""
|
||||||
self._api = handle_error(api)
|
self._api = handle_error(api)
|
||||||
self._device: Command | None = None
|
self._device: Device = device
|
||||||
self._device_control: SocketControl | None = None
|
self._device_control: BlindControl | LightControl | SocketControl | SignalRepeaterControl | None = (
|
||||||
|
None
|
||||||
|
)
|
||||||
self._device_data: Socket | Light | Blind | None = None
|
self._device_data: Socket | Light | Blind | None = None
|
||||||
self._gateway_id = gateway_id
|
self._gateway_id = gateway_id
|
||||||
self._refresh(device)
|
self._refresh(device)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_start_observe(self, exc=None):
|
def _async_start_observe(self, exc: Exception | None = None) -> None:
|
||||||
"""Start observation of device."""
|
"""Start observation of device."""
|
||||||
if exc:
|
if exc:
|
||||||
self.async_write_ha_state()
|
self.async_write_ha_state()
|
||||||
|
@ -71,17 +82,17 @@ class TradfriBaseClass(Entity):
|
||||||
_LOGGER.warning("Observation failed, trying again", exc_info=err)
|
_LOGGER.warning("Observation failed, trying again", exc_info=err)
|
||||||
self._async_start_observe()
|
self._async_start_observe()
|
||||||
|
|
||||||
async def async_added_to_hass(self):
|
async def async_added_to_hass(self) -> None:
|
||||||
"""Start thread when added to hass."""
|
"""Start thread when added to hass."""
|
||||||
self._async_start_observe()
|
self._async_start_observe()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _observe_update(self, device):
|
def _observe_update(self, device: Device) -> None:
|
||||||
"""Receive new state data for this device."""
|
"""Receive new state data for this device."""
|
||||||
self._refresh(device)
|
self._refresh(device)
|
||||||
self.async_write_ha_state()
|
self.async_write_ha_state()
|
||||||
|
|
||||||
def _refresh(self, device: Command) -> None:
|
def _refresh(self, device: Device) -> None:
|
||||||
"""Refresh the device data."""
|
"""Refresh the device data."""
|
||||||
self._device = device
|
self._device = device
|
||||||
self._attr_name = device.name
|
self._attr_name = device.name
|
||||||
|
@ -94,10 +105,9 @@ class TradfriBaseDevice(TradfriBaseClass):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device_info(self):
|
def device_info(self) -> DeviceInfo:
|
||||||
"""Return the device info."""
|
"""Return the device info."""
|
||||||
info = self._device.device_info
|
info = self._device.device_info
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"identifiers": {(DOMAIN, self._device.id)},
|
"identifiers": {(DOMAIN, self._device.id)},
|
||||||
"manufacturer": info.manufacturer,
|
"manufacturer": info.manufacturer,
|
||||||
|
@ -107,7 +117,7 @@ class TradfriBaseDevice(TradfriBaseClass):
|
||||||
"via_device": (DOMAIN, self._gateway_id),
|
"via_device": (DOMAIN, self._gateway_id),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _refresh(self, device: Command) -> None:
|
def _refresh(self, device: Device) -> None:
|
||||||
"""Refresh the device data."""
|
"""Refresh the device data."""
|
||||||
super()._refresh(device)
|
super()._refresh(device)
|
||||||
self._attr_available = device.reachable
|
self._attr_available = device.reachable
|
||||||
|
|
|
@ -45,7 +45,10 @@ class TradfriSensor(TradfriBaseDevice, SensorEntity):
|
||||||
_attr_native_unit_of_measurement = PERCENTAGE
|
_attr_native_unit_of_measurement = PERCENTAGE
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, device: Command, api: Callable[[str], Any], gateway_id: str
|
self,
|
||||||
|
device: Command,
|
||||||
|
api: Callable[[Command | list[Command]], Any],
|
||||||
|
gateway_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the device."""
|
"""Initialize the device."""
|
||||||
super().__init__(device, api, gateway_id)
|
super().__init__(device, api, gateway_id)
|
||||||
|
|
|
@ -36,7 +36,10 @@ class TradfriSwitch(TradfriBaseDevice, SwitchEntity):
|
||||||
"""The platform class required by Home Assistant."""
|
"""The platform class required by Home Assistant."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, device: Command, api: Callable[[str], Any], gateway_id: str
|
self,
|
||||||
|
device: Command,
|
||||||
|
api: Callable[[Command | list[Command]], Any],
|
||||||
|
gateway_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a switch."""
|
"""Initialize a switch."""
|
||||||
super().__init__(device, api, gateway_id)
|
super().__init__(device, api, gateway_id)
|
||||||
|
|
Loading…
Reference in New Issue