Strictly type modbus base_platform.py (#56343)

pull/56375/head
jan iversen 2021-09-20 14:59:30 +02:00 committed by GitHub
parent bb6f97c4d3
commit a84e86ff13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 19 deletions

View File

@ -2,10 +2,10 @@
from __future__ import annotations from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
from datetime import timedelta from datetime import datetime, timedelta
import logging import logging
import struct import struct
from typing import Any, Callable from typing import Any, Callable, cast
from homeassistant.const import ( from homeassistant.const import (
CONF_ADDRESS, CONF_ADDRESS,
@ -75,7 +75,7 @@ class BasePlatform(Entity):
self._slave = entry.get(CONF_SLAVE, 0) self._slave = entry.get(CONF_SLAVE, 0)
self._address = int(entry[CONF_ADDRESS]) self._address = int(entry[CONF_ADDRESS])
self._input_type = entry[CONF_INPUT_TYPE] self._input_type = entry[CONF_INPUT_TYPE]
self._value = None self._value: str | None = None
self._scan_interval = int(entry[CONF_SCAN_INTERVAL]) self._scan_interval = int(entry[CONF_SCAN_INTERVAL])
self._call_active = False self._call_active = False
self._cancel_timer: Callable[[], None] | None = None self._cancel_timer: Callable[[], None] | None = None
@ -90,7 +90,7 @@ class BasePlatform(Entity):
self._lazy_errors = self._lazy_error_count self._lazy_errors = self._lazy_error_count
@abstractmethod @abstractmethod
async def async_update(self, now=None): async def async_update(self, now: datetime | None = None) -> None:
"""Virtual function to be overwritten.""" """Virtual function to be overwritten."""
@callback @callback
@ -107,7 +107,7 @@ class BasePlatform(Entity):
self.async_write_ha_state() self.async_write_ha_state()
@callback @callback
def async_hold(self, update=True) -> None: def async_hold(self, update: bool = True) -> None:
"""Remote stop entity.""" """Remote stop entity."""
if self._cancel_call: if self._cancel_call:
self._cancel_call() self._cancel_call()
@ -119,7 +119,7 @@ class BasePlatform(Entity):
self._attr_available = False self._attr_available = False
self.async_write_ha_state() self.async_write_ha_state()
async def async_base_added_to_hass(self): async def async_base_added_to_hass(self) -> None:
"""Handle entity which will be added.""" """Handle entity which will be added."""
self.async_run() self.async_run()
self.async_on_remove( self.async_on_remove(
@ -138,13 +138,13 @@ class BaseStructPlatform(BasePlatform, RestoreEntity):
super().__init__(hub, config) super().__init__(hub, config)
self._swap = config[CONF_SWAP] self._swap = config[CONF_SWAP]
self._data_type = config[CONF_DATA_TYPE] self._data_type = config[CONF_DATA_TYPE]
self._structure = config.get(CONF_STRUCTURE) self._structure: str = config[CONF_STRUCTURE]
self._precision = config[CONF_PRECISION] self._precision = config[CONF_PRECISION]
self._scale = config[CONF_SCALE] self._scale = config[CONF_SCALE]
self._offset = config[CONF_OFFSET] self._offset = config[CONF_OFFSET]
self._count = config[CONF_COUNT] self._count = config[CONF_COUNT]
def _swap_registers(self, registers): def _swap_registers(self, registers: list[int]) -> list[int]:
"""Do swap as needed.""" """Do swap as needed."""
if self._swap in (CONF_SWAP_BYTE, CONF_SWAP_WORD_BYTE): if self._swap in (CONF_SWAP_BYTE, CONF_SWAP_WORD_BYTE):
# convert [12][34] --> [21][43] # convert [12][34] --> [21][43]
@ -159,7 +159,7 @@ class BaseStructPlatform(BasePlatform, RestoreEntity):
registers.reverse() registers.reverse()
return registers return registers
def unpack_structure_result(self, registers): def unpack_structure_result(self, registers: list[int]) -> str:
"""Convert registers to proper result.""" """Convert registers to proper result."""
registers = self._swap_registers(registers) registers = self._swap_registers(registers)
@ -187,14 +187,14 @@ class BaseStructPlatform(BasePlatform, RestoreEntity):
return ",".join(map(str, v_result)) return ",".join(map(str, v_result))
# Apply scale and precision to floats and ints # Apply scale and precision to floats and ints
val = self._scale * val[0] + self._offset val_result: float | int = self._scale * val[0] + self._offset
# We could convert int to float, and the code would still work; however # We could convert int to float, and the code would still work; however
# we lose some precision, and unit tests will fail. Therefore, we do # we lose some precision, and unit tests will fail. Therefore, we do
# the conversion only when it's absolutely necessary. # the conversion only when it's absolutely necessary.
if isinstance(val, int) and self._precision == 0: if isinstance(val_result, int) and self._precision == 0:
return str(val) return str(val_result)
return f"{float(val):.{self._precision}f}" return f"{float(val_result):.{self._precision}f}"
class BaseSwitch(BasePlatform, ToggleEntity, RestoreEntity): class BaseSwitch(BasePlatform, ToggleEntity, RestoreEntity):
@ -225,7 +225,7 @@ class BaseSwitch(BasePlatform, ToggleEntity, RestoreEntity):
CALL_TYPE_WRITE_REGISTERS, CALL_TYPE_WRITE_REGISTERS,
), ),
} }
self._write_type = convert[config[CONF_WRITE_TYPE]][1] self._write_type = cast(str, convert[config[CONF_WRITE_TYPE]][1])
self.command_on = config[CONF_COMMAND_ON] self.command_on = config[CONF_COMMAND_ON]
self._command_off = config[CONF_COMMAND_OFF] self._command_off = config[CONF_COMMAND_OFF]
if CONF_VERIFY in config: if CONF_VERIFY in config:
@ -244,14 +244,14 @@ class BaseSwitch(BasePlatform, ToggleEntity, RestoreEntity):
else: else:
self._verify_active = False self._verify_active = False
async def async_added_to_hass(self): async def async_added_to_hass(self) -> None:
"""Handle entity which will be added.""" """Handle entity which will be added."""
await self.async_base_added_to_hass() await self.async_base_added_to_hass()
state = await self.async_get_last_state() state = await self.async_get_last_state()
if state: if state:
self._attr_is_on = state.state == STATE_ON self._attr_is_on = state.state == STATE_ON
async def async_turn(self, command): async def async_turn(self, command: int) -> None:
"""Evaluate switch result.""" """Evaluate switch result."""
result = await self._hub.async_pymodbus_call( result = await self._hub.async_pymodbus_call(
self._slave, self._address, command, self._write_type self._slave, self._address, command, self._write_type
@ -272,11 +272,11 @@ class BaseSwitch(BasePlatform, ToggleEntity, RestoreEntity):
else: else:
await self.async_update() await self.async_update()
async def async_turn_off(self, **kwargs): async def async_turn_off(self, **kwargs: Any) -> None:
"""Set switch off.""" """Set switch off."""
await self.async_turn(self._command_off) await self.async_turn(self._command_off)
async def async_update(self, now=None): async def async_update(self, now: datetime | None = None) -> None:
"""Update the entity state.""" """Update the entity state."""
# remark "now" is a dummy parameter to avoid problems with # remark "now" is a dummy parameter to avoid problems with
# async_track_time_interval # async_track_time_interval

View File

@ -8,6 +8,7 @@ import logging
from pymodbus.client.sync import ModbusSerialClient, ModbusTcpClient, ModbusUdpClient from pymodbus.client.sync import ModbusSerialClient, ModbusTcpClient, ModbusUdpClient
from pymodbus.constants import Defaults from pymodbus.constants import Defaults
from pymodbus.exceptions import ModbusException from pymodbus.exceptions import ModbusException
from pymodbus.pdu import ModbusResponse
from pymodbus.transaction import ModbusRtuFramer from pymodbus.transaction import ModbusRtuFramer
from homeassistant.const import ( from homeassistant.const import (
@ -356,7 +357,13 @@ class ModbusHub:
self._in_error = False self._in_error = False
return result return result
async def async_pymodbus_call(self, unit, address, value, use_call): async def async_pymodbus_call(
self,
unit: str | int | None,
address: int,
value: str | int,
use_call: str | None,
) -> ModbusResponse | None:
"""Convert async to sync pymodbus call.""" """Convert async to sync pymodbus call."""
if self._config_delay: if self._config_delay:
return None return None