Fix modbus blocking threads (#50619)
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>pull/50455/head
parent
990b7c371f
commit
ad7be91b6a
|
@ -101,7 +101,7 @@ from .const import (
|
|||
MODBUS_DOMAIN as DOMAIN,
|
||||
PLATFORMS,
|
||||
)
|
||||
from .modbus import modbus_setup
|
||||
from .modbus import async_modbus_setup
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -350,8 +350,8 @@ SERVICE_WRITE_COIL_SCHEMA = vol.Schema(
|
|||
)
|
||||
|
||||
|
||||
def setup(hass, config):
|
||||
async def async_setup(hass, config):
|
||||
"""Set up Modbus component."""
|
||||
return modbus_setup(
|
||||
return await async_modbus_setup(
|
||||
hass, config, SERVICE_WRITE_REGISTER_SCHEMA, SERVICE_WRITE_COIL_SCHEMA
|
||||
)
|
||||
|
|
|
@ -36,6 +36,7 @@ from .const import (
|
|||
MODBUS_DOMAIN,
|
||||
)
|
||||
|
||||
PARALLEL_UPDATES = 1
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -114,9 +115,7 @@ class ModbusBinarySensor(BinarySensorEntity):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Handle entity which will be added."""
|
||||
async_track_time_interval(
|
||||
self._hass, lambda arg: self.update(), self._scan_interval
|
||||
)
|
||||
async_track_time_interval(self._hass, self.async_update, self._scan_interval)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
|
@ -148,17 +147,21 @@ class ModbusBinarySensor(BinarySensorEntity):
|
|||
"""Return True if entity is available."""
|
||||
return self._available
|
||||
|
||||
def update(self):
|
||||
async def async_update(self, now=None):
|
||||
"""Update the state of the sensor."""
|
||||
# remark "now" is a dummy parameter to avoid problems with
|
||||
# async_track_time_interval
|
||||
if self._input_type == CALL_TYPE_COIL:
|
||||
result = self._hub.read_coils(self._slave, self._address, 1)
|
||||
result = await self._hub.async_read_coils(self._slave, self._address, 1)
|
||||
else:
|
||||
result = self._hub.read_discrete_inputs(self._slave, self._address, 1)
|
||||
result = await self._hub.async_read_discrete_inputs(
|
||||
self._slave, self._address, 1
|
||||
)
|
||||
if result is None:
|
||||
self._available = False
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
return
|
||||
|
||||
self._value = result.bits[0] & 1
|
||||
self._available = True
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
|
|
|
@ -46,6 +46,7 @@ from .const import (
|
|||
)
|
||||
from .modbus import ModbusHub
|
||||
|
||||
PARALLEL_UPDATES = 1
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -132,9 +133,7 @@ class ModbusThermostat(ClimateEntity):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Handle entity which will be added."""
|
||||
async_track_time_interval(
|
||||
self.hass, lambda arg: self.update(), self._scan_interval
|
||||
)
|
||||
async_track_time_interval(self.hass, self.async_update, self._scan_interval)
|
||||
|
||||
@property
|
||||
def should_poll(self):
|
||||
|
@ -160,7 +159,7 @@ class ModbusThermostat(ClimateEntity):
|
|||
"""Return the possible HVAC modes."""
|
||||
return [HVAC_MODE_AUTO]
|
||||
|
||||
def set_hvac_mode(self, hvac_mode: str) -> None:
|
||||
async def async_set_hvac_mode(self, hvac_mode: str) -> None:
|
||||
"""Set new target hvac mode."""
|
||||
# Home Assistant expects this method.
|
||||
# We'll keep it here to avoid getting exceptions.
|
||||
|
@ -200,7 +199,7 @@ class ModbusThermostat(ClimateEntity):
|
|||
"""Return the supported step of target temperature."""
|
||||
return self._temp_step
|
||||
|
||||
def set_temperature(self, **kwargs):
|
||||
async def async_set_temperature(self, **kwargs):
|
||||
"""Set new target temperature."""
|
||||
if ATTR_TEMPERATURE not in kwargs:
|
||||
return
|
||||
|
@ -209,35 +208,39 @@ class ModbusThermostat(ClimateEntity):
|
|||
)
|
||||
byte_string = struct.pack(self._structure, target_temperature)
|
||||
register_value = struct.unpack(">h", byte_string[0:2])[0]
|
||||
self._available = self._hub.write_registers(
|
||||
self._available = await self._hub.async_write_registers(
|
||||
self._slave,
|
||||
self._target_temperature_register,
|
||||
register_value,
|
||||
)
|
||||
self.update()
|
||||
self.async_update()
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
"""Return True if entity is available."""
|
||||
return self._available
|
||||
|
||||
def update(self):
|
||||
async def async_update(self, now=None):
|
||||
"""Update Target & Current Temperature."""
|
||||
self._target_temperature = self._read_register(
|
||||
# remark "now" is a dummy parameter to avoid problems with
|
||||
# async_track_time_interval
|
||||
self._target_temperature = await self._async_read_register(
|
||||
CALL_TYPE_REGISTER_HOLDING, self._target_temperature_register
|
||||
)
|
||||
self._current_temperature = self._read_register(
|
||||
self._current_temperature = await self._async_read_register(
|
||||
self._current_temperature_register_type, self._current_temperature_register
|
||||
)
|
||||
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
|
||||
def _read_register(self, register_type, register) -> float | None:
|
||||
async def _async_read_register(self, register_type, register) -> float | None:
|
||||
"""Read register using the Modbus hub slave."""
|
||||
if register_type == CALL_TYPE_REGISTER_INPUT:
|
||||
result = self._hub.read_input_registers(self._slave, register, self._count)
|
||||
result = await self._hub.async_read_input_registers(
|
||||
self._slave, register, self._count
|
||||
)
|
||||
else:
|
||||
result = self._hub.read_holding_registers(
|
||||
result = await self._hub.async_read_holding_registers(
|
||||
self._slave, register, self._count
|
||||
)
|
||||
if result is None:
|
||||
|
|
|
@ -33,6 +33,7 @@ from .const import (
|
|||
)
|
||||
from .modbus import ModbusHub
|
||||
|
||||
PARALLEL_UPDATES = 1
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -106,9 +107,7 @@ class ModbusCover(CoverEntity, RestoreEntity):
|
|||
if state:
|
||||
self._value = state.state
|
||||
|
||||
async_track_time_interval(
|
||||
self.hass, lambda arg: self.update(), self._scan_interval
|
||||
)
|
||||
async_track_time_interval(self.hass, self.async_update, self._scan_interval)
|
||||
|
||||
@property
|
||||
def device_class(self) -> str | None:
|
||||
|
@ -154,41 +153,43 @@ class ModbusCover(CoverEntity, RestoreEntity):
|
|||
# Handle polling directly in this entity
|
||||
return False
|
||||
|
||||
def open_cover(self, **kwargs: Any) -> None:
|
||||
async def async_open_cover(self, **kwargs: Any) -> None:
|
||||
"""Open cover."""
|
||||
if self._coil is not None:
|
||||
self._write_coil(True)
|
||||
await self._async_write_coil(True)
|
||||
else:
|
||||
self._write_register(self._state_open)
|
||||
await self._async_write_register(self._state_open)
|
||||
|
||||
self.update()
|
||||
self.async_update()
|
||||
|
||||
def close_cover(self, **kwargs: Any) -> None:
|
||||
async def async_close_cover(self, **kwargs: Any) -> None:
|
||||
"""Close cover."""
|
||||
if self._coil is not None:
|
||||
self._write_coil(False)
|
||||
await self._async_write_coil(False)
|
||||
else:
|
||||
self._write_register(self._state_closed)
|
||||
await self._async_write_register(self._state_closed)
|
||||
|
||||
self.update()
|
||||
self.async_update()
|
||||
|
||||
def update(self):
|
||||
async def async_update(self, now=None):
|
||||
"""Update the state of the cover."""
|
||||
# remark "now" is a dummy parameter to avoid problems with
|
||||
# async_track_time_interval
|
||||
if self._coil is not None and self._status_register is None:
|
||||
self._value = self._read_coil()
|
||||
self._value = await self._async_read_coil()
|
||||
else:
|
||||
self._value = self._read_status_register()
|
||||
self._value = await self._async_read_status_register()
|
||||
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
|
||||
def _read_status_register(self) -> int | None:
|
||||
async def _async_read_status_register(self) -> int | None:
|
||||
"""Read status register using the Modbus hub slave."""
|
||||
if self._status_register_type == CALL_TYPE_REGISTER_INPUT:
|
||||
result = self._hub.read_input_registers(
|
||||
result = await self._hub.async_read_input_registers(
|
||||
self._slave, self._status_register, 1
|
||||
)
|
||||
else:
|
||||
result = self._hub.read_holding_registers(
|
||||
result = await self._hub.async_read_holding_registers(
|
||||
self._slave, self._status_register, 1
|
||||
)
|
||||
if result is None:
|
||||
|
@ -200,13 +201,15 @@ class ModbusCover(CoverEntity, RestoreEntity):
|
|||
|
||||
return value
|
||||
|
||||
def _write_register(self, value):
|
||||
async def _async_write_register(self, value):
|
||||
"""Write holding register using the Modbus hub slave."""
|
||||
self._available = self._hub.write_register(self._slave, self._register, value)
|
||||
self._available = await self._hub.async_write_register(
|
||||
self._slave, self._register, value
|
||||
)
|
||||
|
||||
def _read_coil(self) -> bool | None:
|
||||
async def _async_read_coil(self) -> bool | None:
|
||||
"""Read coil using the Modbus hub slave."""
|
||||
result = self._hub.read_coils(self._slave, self._coil, 1)
|
||||
result = await self._hub.async_read_coils(self._slave, self._coil, 1)
|
||||
if result is None:
|
||||
self._available = False
|
||||
return None
|
||||
|
@ -214,6 +217,8 @@ class ModbusCover(CoverEntity, RestoreEntity):
|
|||
value = bool(result.bits[0] & 1)
|
||||
return value
|
||||
|
||||
def _write_coil(self, value):
|
||||
async def _async_write_coil(self, value):
|
||||
"""Write coil using the Modbus hub slave."""
|
||||
self._available = self._hub.write_coil(self._slave, self._coil, value)
|
||||
self._available = await self._hub.async_write_coil(
|
||||
self._slave, self._coil, value
|
||||
)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
"""Support for Modbus."""
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from pymodbus.client.sync import ModbusSerialClient, ModbusTcpClient, ModbusUdpClient
|
||||
from pymodbus.constants import Defaults
|
||||
|
@ -17,8 +17,9 @@ from homeassistant.const import (
|
|||
CONF_TYPE,
|
||||
EVENT_HOMEASSISTANT_STOP,
|
||||
)
|
||||
from homeassistant.helpers.discovery import load_platform
|
||||
from homeassistant.helpers.event import call_later
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.discovery import async_load_platform
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
|
||||
from .const import (
|
||||
ATTR_ADDRESS,
|
||||
|
@ -41,32 +42,37 @@ from .const import (
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def modbus_setup(
|
||||
async def async_modbus_setup(
|
||||
hass, config, service_write_register_schema, service_write_coil_schema
|
||||
):
|
||||
"""Set up Modbus component."""
|
||||
|
||||
hass.data[DOMAIN] = hub_collect = {}
|
||||
for conf_hub in config[DOMAIN]:
|
||||
hub_collect[conf_hub[CONF_NAME]] = ModbusHub(conf_hub)
|
||||
my_hub = ModbusHub(hass, conf_hub)
|
||||
hub_collect[conf_hub[CONF_NAME]] = my_hub
|
||||
|
||||
# modbus needs to be activated before components are loaded
|
||||
# to avoid a racing problem
|
||||
hub_collect[conf_hub[CONF_NAME]].setup(hass)
|
||||
await my_hub.async_setup()
|
||||
|
||||
# load platforms
|
||||
for component, conf_key in PLATFORMS:
|
||||
if conf_key in conf_hub:
|
||||
load_platform(hass, component, DOMAIN, conf_hub, config)
|
||||
hass.async_create_task(
|
||||
async_load_platform(hass, component, DOMAIN, conf_hub, config)
|
||||
)
|
||||
|
||||
def stop_modbus(event):
|
||||
async def async_stop_modbus(event):
|
||||
"""Stop Modbus service."""
|
||||
|
||||
for client in hub_collect.values():
|
||||
client.close()
|
||||
await client.async_close()
|
||||
del client
|
||||
|
||||
def write_register(service):
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_modbus)
|
||||
|
||||
async def async_write_register(service):
|
||||
"""Write Modbus registers."""
|
||||
unit = int(float(service.data[ATTR_UNIT]))
|
||||
address = int(float(service.data[ATTR_ADDRESS]))
|
||||
|
@ -75,13 +81,22 @@ def modbus_setup(
|
|||
service.data[ATTR_HUB] if ATTR_HUB in service.data else DEFAULT_HUB
|
||||
)
|
||||
if isinstance(value, list):
|
||||
hub_collect[client_name].write_registers(
|
||||
await hub_collect[client_name].async_write_registers(
|
||||
unit, address, [int(float(i)) for i in value]
|
||||
)
|
||||
else:
|
||||
hub_collect[client_name].write_register(unit, address, int(float(value)))
|
||||
await hub_collect[client_name].async_write_register(
|
||||
unit, address, int(float(value))
|
||||
)
|
||||
|
||||
def write_coil(service):
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
SERVICE_WRITE_REGISTER,
|
||||
async_write_register,
|
||||
schema=service_write_register_schema,
|
||||
)
|
||||
|
||||
async def async_write_coil(service):
|
||||
"""Write Modbus coil."""
|
||||
unit = service.data[ATTR_UNIT]
|
||||
address = service.data[ATTR_ADDRESS]
|
||||
|
@ -90,22 +105,12 @@ def modbus_setup(
|
|||
service.data[ATTR_HUB] if ATTR_HUB in service.data else DEFAULT_HUB
|
||||
)
|
||||
if isinstance(state, list):
|
||||
hub_collect[client_name].write_coils(unit, address, state)
|
||||
await hub_collect[client_name].async_write_coils(unit, address, state)
|
||||
else:
|
||||
hub_collect[client_name].write_coil(unit, address, state)
|
||||
await hub_collect[client_name].async_write_coil(unit, address, state)
|
||||
|
||||
# register function to gracefully stop modbus
|
||||
hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, stop_modbus)
|
||||
|
||||
# Register services for modbus
|
||||
hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_WRITE_REGISTER,
|
||||
write_register,
|
||||
schema=service_write_register_schema,
|
||||
)
|
||||
hass.services.register(
|
||||
DOMAIN, SERVICE_WRITE_COIL, write_coil, schema=service_write_coil_schema
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_WRITE_COIL, async_write_coil, schema=service_write_coil_schema
|
||||
)
|
||||
return True
|
||||
|
||||
|
@ -113,14 +118,15 @@ def modbus_setup(
|
|||
class ModbusHub:
|
||||
"""Thread safe wrapper class for pymodbus."""
|
||||
|
||||
def __init__(self, client_config):
|
||||
def __init__(self, hass, client_config):
|
||||
"""Initialize the Modbus hub."""
|
||||
|
||||
# generic configuration
|
||||
self._client = None
|
||||
self._cancel_listener = None
|
||||
self._async_cancel_listener = None
|
||||
self._in_error = False
|
||||
self._lock = threading.Lock()
|
||||
self._lock = asyncio.Lock()
|
||||
self.hass = hass
|
||||
self._config_name = client_config[CONF_NAME]
|
||||
self._config_type = client_config[CONF_TYPE]
|
||||
self._config_port = client_config[CONF_PORT]
|
||||
|
@ -152,7 +158,7 @@ class ModbusHub:
|
|||
_LOGGER.error(log_text)
|
||||
self._in_error = error_state
|
||||
|
||||
def setup(self, hass):
|
||||
async def async_setup(self):
|
||||
"""Set up pymodbus client."""
|
||||
try:
|
||||
if self._config_type == "serial":
|
||||
|
@ -193,166 +199,113 @@ class ModbusHub:
|
|||
self._log_error(exception_error, error_state=False)
|
||||
return
|
||||
|
||||
# Connect device
|
||||
self.connect()
|
||||
async with self._lock:
|
||||
await self.hass.async_add_executor_job(self._pymodbus_connect)
|
||||
|
||||
# Start counting down to allow modbus requests.
|
||||
if self._config_delay:
|
||||
self._cancel_listener = call_later(hass, self._config_delay, self.end_delay)
|
||||
self._async_cancel_listener = async_call_later(
|
||||
self.hass, self._config_delay, self.async_end_delay
|
||||
)
|
||||
|
||||
def end_delay(self, args):
|
||||
@callback
|
||||
def async_end_delay(self, args):
|
||||
"""End startup delay."""
|
||||
self._cancel_listener = None
|
||||
self._async_cancel_listener = None
|
||||
self._config_delay = 0
|
||||
|
||||
def close(self):
|
||||
"""Disconnect client."""
|
||||
if self._cancel_listener:
|
||||
self._cancel_listener()
|
||||
self._cancel_listener = None
|
||||
with self._lock:
|
||||
try:
|
||||
def _pymodbus_close(self):
|
||||
"""Close sync. pymodbus."""
|
||||
if self._client:
|
||||
try:
|
||||
self._client.close()
|
||||
self._client = None
|
||||
except ModbusException as exception_error:
|
||||
self._log_error(exception_error)
|
||||
return
|
||||
self._client = None
|
||||
|
||||
def connect(self):
|
||||
async def async_close(self):
|
||||
"""Disconnect client."""
|
||||
if self._async_cancel_listener:
|
||||
self._async_cancel_listener()
|
||||
self._async_cancel_listener = None
|
||||
|
||||
async with self._lock:
|
||||
return await self.hass.async_add_executor_job(self._pymodbus_close)
|
||||
|
||||
def _pymodbus_connect(self):
|
||||
"""Connect client."""
|
||||
with self._lock:
|
||||
try:
|
||||
self._client.connect()
|
||||
except ModbusException as exception_error:
|
||||
self._log_error(exception_error, error_state=False)
|
||||
return
|
||||
|
||||
def read_coils(self, unit, address, count):
|
||||
"""Read coils."""
|
||||
if self._config_delay:
|
||||
return None
|
||||
with self._lock:
|
||||
def _pymodbus_call(self, unit, address, value, check_attr, func):
|
||||
"""Call sync. pymodbus."""
|
||||
kwargs = {"unit": unit} if unit else {}
|
||||
try:
|
||||
result = self._client.read_coils(address, count, **kwargs)
|
||||
result = func(address, value, **kwargs)
|
||||
except ModbusException as exception_error:
|
||||
self._log_error(exception_error)
|
||||
result = exception_error
|
||||
if not hasattr(result, "bits"):
|
||||
if not hasattr(result, check_attr):
|
||||
self._log_error(result)
|
||||
return None
|
||||
self._in_error = False
|
||||
return result
|
||||
|
||||
def read_discrete_inputs(self, unit, address, count):
|
||||
async def async_pymodbus_call(self, unit, address, value, check_attr, func):
|
||||
"""Convert async to sync pymodbus call."""
|
||||
if self._config_delay:
|
||||
return None
|
||||
async with self._lock:
|
||||
return await self.hass.async_add_executor_job(
|
||||
self._pymodbus_call, unit, address, value, check_attr, func
|
||||
)
|
||||
|
||||
async def async_read_coils(self, unit, address, count):
|
||||
"""Read coils."""
|
||||
return await self.async_pymodbus_call(
|
||||
unit, address, count, "bits", self._client.read_coils
|
||||
)
|
||||
|
||||
async def async_read_discrete_inputs(self, unit, address, count):
|
||||
"""Read discrete inputs."""
|
||||
if self._config_delay:
|
||||
return None
|
||||
with self._lock:
|
||||
kwargs = {"unit": unit} if unit else {}
|
||||
try:
|
||||
result = self._client.read_discrete_inputs(address, count, **kwargs)
|
||||
except ModbusException as exception_error:
|
||||
result = exception_error
|
||||
if not hasattr(result, "bits"):
|
||||
self._log_error(result)
|
||||
return None
|
||||
self._in_error = False
|
||||
return result
|
||||
return await self.async_pymodbus_call(
|
||||
unit, address, count, "bits", self._client.read_discrete_inputs
|
||||
)
|
||||
|
||||
def read_input_registers(self, unit, address, count):
|
||||
async def async_read_input_registers(self, unit, address, count):
|
||||
"""Read input registers."""
|
||||
if self._config_delay:
|
||||
return None
|
||||
with self._lock:
|
||||
kwargs = {"unit": unit} if unit else {}
|
||||
try:
|
||||
result = self._client.read_input_registers(address, count, **kwargs)
|
||||
except ModbusException as exception_error:
|
||||
result = exception_error
|
||||
if not hasattr(result, "registers"):
|
||||
self._log_error(result)
|
||||
return None
|
||||
self._in_error = False
|
||||
return result
|
||||
return await self.async_pymodbus_call(
|
||||
unit, address, count, "registers", self._client.read_input_registers
|
||||
)
|
||||
|
||||
def read_holding_registers(self, unit, address, count):
|
||||
async def async_read_holding_registers(self, unit, address, count):
|
||||
"""Read holding registers."""
|
||||
if self._config_delay:
|
||||
return None
|
||||
with self._lock:
|
||||
kwargs = {"unit": unit} if unit else {}
|
||||
try:
|
||||
result = self._client.read_holding_registers(address, count, **kwargs)
|
||||
except ModbusException as exception_error:
|
||||
result = exception_error
|
||||
if not hasattr(result, "registers"):
|
||||
self._log_error(result)
|
||||
return None
|
||||
self._in_error = False
|
||||
return result
|
||||
return await self.async_pymodbus_call(
|
||||
unit, address, count, "registers", self._client.read_holding_registers
|
||||
)
|
||||
|
||||
def write_coil(self, unit, address, value) -> bool:
|
||||
async def async_write_coil(self, unit, address, value) -> bool:
|
||||
"""Write coil."""
|
||||
if self._config_delay:
|
||||
return False
|
||||
with self._lock:
|
||||
kwargs = {"unit": unit} if unit else {}
|
||||
try:
|
||||
result = self._client.write_coil(address, value, **kwargs)
|
||||
except ModbusException as exception_error:
|
||||
result = exception_error
|
||||
if not hasattr(result, "value"):
|
||||
self._log_error(result)
|
||||
return False
|
||||
self._in_error = False
|
||||
return True
|
||||
return await self.async_pymodbus_call(
|
||||
unit, address, value, "value", self._client.write_coil
|
||||
)
|
||||
|
||||
def write_coils(self, unit, address, values) -> bool:
|
||||
async def async_write_coils(self, unit, address, values) -> bool:
|
||||
"""Write coil."""
|
||||
if self._config_delay:
|
||||
return False
|
||||
with self._lock:
|
||||
kwargs = {"unit": unit} if unit else {}
|
||||
try:
|
||||
result = self._client.write_coils(address, values, **kwargs)
|
||||
except ModbusException as exception_error:
|
||||
result = exception_error
|
||||
if not hasattr(result, "count"):
|
||||
self._log_error(result)
|
||||
return False
|
||||
self._in_error = False
|
||||
return True
|
||||
return await self.async_pymodbus_call(
|
||||
unit, address, values, "count", self._client.write_coils
|
||||
)
|
||||
|
||||
def write_register(self, unit, address, value) -> bool:
|
||||
async def async_write_register(self, unit, address, value) -> bool:
|
||||
"""Write register."""
|
||||
if self._config_delay:
|
||||
return False
|
||||
with self._lock:
|
||||
kwargs = {"unit": unit} if unit else {}
|
||||
try:
|
||||
result = self._client.write_register(address, value, **kwargs)
|
||||
except ModbusException as exception_error:
|
||||
result = exception_error
|
||||
if not hasattr(result, "value"):
|
||||
self._log_error(result)
|
||||
return False
|
||||
self._in_error = False
|
||||
return True
|
||||
return await self.async_pymodbus_call(
|
||||
unit, address, value, "value", self._client.write_register
|
||||
)
|
||||
|
||||
def write_registers(self, unit, address, values) -> bool:
|
||||
async def async_write_registers(self, unit, address, values) -> bool:
|
||||
"""Write registers."""
|
||||
if self._config_delay:
|
||||
return False
|
||||
with self._lock:
|
||||
kwargs = {"unit": unit} if unit else {}
|
||||
try:
|
||||
result = self._client.write_registers(address, values, **kwargs)
|
||||
except ModbusException as exception_error:
|
||||
result = exception_error
|
||||
if not hasattr(result, "count"):
|
||||
self._log_error(result)
|
||||
return False
|
||||
self._in_error = False
|
||||
return True
|
||||
return await self.async_pymodbus_call(
|
||||
unit, address, values, "count", self._client.write_registers
|
||||
)
|
||||
|
|
|
@ -59,6 +59,7 @@ from .const import (
|
|||
MODBUS_DOMAIN,
|
||||
)
|
||||
|
||||
PARALLEL_UPDATES = 1
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -226,9 +227,7 @@ class ModbusRegisterSensor(RestoreEntity, SensorEntity):
|
|||
if state:
|
||||
self._value = state.state
|
||||
|
||||
async_track_time_interval(
|
||||
self.hass, lambda arg: self.update(), self._scan_interval
|
||||
)
|
||||
async_track_time_interval(self.hass, self.async_update, self._scan_interval)
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
|
@ -280,19 +279,21 @@ class ModbusRegisterSensor(RestoreEntity, SensorEntity):
|
|||
registers.reverse()
|
||||
return registers
|
||||
|
||||
def update(self):
|
||||
async def async_update(self, now=None):
|
||||
"""Update the state of the sensor."""
|
||||
# remark "now" is a dummy parameter to avoid problems with
|
||||
# async_track_time_interval
|
||||
if self._register_type == CALL_TYPE_REGISTER_INPUT:
|
||||
result = self._hub.read_input_registers(
|
||||
result = await self._hub.async_read_input_registers(
|
||||
self._slave, self._register, self._count
|
||||
)
|
||||
else:
|
||||
result = self._hub.read_holding_registers(
|
||||
result = await self._hub.async_read_holding_registers(
|
||||
self._slave, self._register, self._count
|
||||
)
|
||||
if result is None:
|
||||
self._available = False
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
return
|
||||
|
||||
registers = self._swap_registers(result.registers)
|
||||
|
@ -332,4 +333,4 @@ class ModbusRegisterSensor(RestoreEntity, SensorEntity):
|
|||
self._value = f"{float(val):.{self._precision}f}"
|
||||
|
||||
self._available = True
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
|
|
|
@ -34,6 +34,7 @@ from .const import (
|
|||
)
|
||||
from .modbus import ModbusHub
|
||||
|
||||
PARALLEL_UPDATES = 1
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -62,11 +63,11 @@ class ModbusSwitch(SwitchEntity, RestoreEntity):
|
|||
self._scan_interval = timedelta(seconds=config[CONF_SCAN_INTERVAL])
|
||||
self._address = config[CONF_ADDRESS]
|
||||
if config[CONF_WRITE_TYPE] == CALL_TYPE_COIL:
|
||||
self._write_func = self._hub.write_coil
|
||||
self._async_write_func = self._hub.async_write_coil
|
||||
self._command_on = 0x01
|
||||
self._command_off = 0x00
|
||||
else:
|
||||
self._write_func = self._hub.write_register
|
||||
self._async_write_func = self._hub.async_write_register
|
||||
self._command_on = config[CONF_COMMAND_ON]
|
||||
self._command_off = config[CONF_COMMAND_OFF]
|
||||
if CONF_VERIFY in config:
|
||||
|
@ -83,13 +84,13 @@ class ModbusSwitch(SwitchEntity, RestoreEntity):
|
|||
self._state_off = config[CONF_VERIFY].get(CONF_STATE_OFF, self._command_off)
|
||||
|
||||
if self._verify_type == CALL_TYPE_REGISTER_HOLDING:
|
||||
self._read_func = self._hub.read_holding_registers
|
||||
self._async_read_func = self._hub.async_read_holding_registers
|
||||
elif self._verify_type == CALL_TYPE_DISCRETE:
|
||||
self._read_func = self._hub.read_discrete_inputs
|
||||
self._async_read_func = self._hub.async_read_discrete_inputs
|
||||
elif self._verify_type == CALL_TYPE_REGISTER_INPUT:
|
||||
self._read_func = self._hub.read_input_registers
|
||||
self._async_read_func = self._hub.async_read_input_registers
|
||||
else: # self._verify_type == CALL_TYPE_COIL:
|
||||
self._read_func = self._hub.read_coils
|
||||
self._async_read_func = self._hub.async_read_coils
|
||||
else:
|
||||
self._verify_active = False
|
||||
|
||||
|
@ -99,9 +100,7 @@ class ModbusSwitch(SwitchEntity, RestoreEntity):
|
|||
if state:
|
||||
self._is_on = state.state == STATE_ON
|
||||
|
||||
async_track_time_interval(
|
||||
self.hass, lambda arg: self.update(), self._scan_interval
|
||||
)
|
||||
async_track_time_interval(self.hass, self.async_update, self._scan_interval)
|
||||
|
||||
@property
|
||||
def is_on(self):
|
||||
|
@ -123,46 +122,52 @@ class ModbusSwitch(SwitchEntity, RestoreEntity):
|
|||
"""Return True if entity is available."""
|
||||
return self._available
|
||||
|
||||
def turn_on(self, **kwargs):
|
||||
async def async_turn_on(self, **kwargs):
|
||||
"""Set switch on."""
|
||||
|
||||
result = self._write_func(self._slave, self._address, self._command_on)
|
||||
result = await self._async_write_func(
|
||||
self._slave, self._address, self._command_on
|
||||
)
|
||||
if result is False:
|
||||
self._available = False
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
else:
|
||||
self._available = True
|
||||
if self._verify_active:
|
||||
self.update()
|
||||
self.async_update()
|
||||
else:
|
||||
self._is_on = True
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
|
||||
def turn_off(self, **kwargs):
|
||||
async def async_turn_off(self, **kwargs):
|
||||
"""Set switch off."""
|
||||
result = self._write_func(self._slave, self._address, self._command_off)
|
||||
result = await self._async_write_func(
|
||||
self._slave, self._address, self._command_off
|
||||
)
|
||||
if result is False:
|
||||
self._available = False
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
else:
|
||||
self._available = True
|
||||
if self._verify_active:
|
||||
self.update()
|
||||
self.async_update()
|
||||
else:
|
||||
self._is_on = False
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
|
||||
def update(self):
|
||||
async def async_update(self, now=None):
|
||||
"""Update the entity state."""
|
||||
# remark "now" is a dummy parameter to avoid problems with
|
||||
# async_track_time_interval
|
||||
if not self._verify_active:
|
||||
self._available = True
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
return
|
||||
|
||||
result = self._read_func(self._slave, self._verify_address, 1)
|
||||
result = await self._async_read_func(self._slave, self._verify_address, 1)
|
||||
if result is None:
|
||||
self._available = False
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
return
|
||||
|
||||
self._available = True
|
||||
|
@ -182,4 +187,4 @@ class ModbusSwitch(SwitchEntity, RestoreEntity):
|
|||
self._verify_address,
|
||||
value,
|
||||
)
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
|
|
|
@ -480,11 +480,13 @@ async def test_pymodbus_connect_fail(hass, caplog, mock_pymodbus):
|
|||
|
||||
|
||||
async def test_delay(hass, mock_pymodbus):
|
||||
"""Run test for different read."""
|
||||
"""Run test for startup delay."""
|
||||
|
||||
# the purpose of this test is to test startup delay
|
||||
# We "hijiack" binary_sensor and sensor in order
|
||||
# to make a proper blackbox test.
|
||||
# We "hijiack" a binary_sensor to make a proper blackbox test.
|
||||
test_delay = 15
|
||||
test_scan_interval = 5
|
||||
entity_id = f"{BINARY_SENSOR_DOMAIN}.{TEST_SENSOR_NAME}"
|
||||
config = {
|
||||
DOMAIN: [
|
||||
{
|
||||
|
@ -492,101 +494,86 @@ async def test_delay(hass, mock_pymodbus):
|
|||
CONF_HOST: "modbusTestHost",
|
||||
CONF_PORT: 5501,
|
||||
CONF_NAME: TEST_MODBUS_NAME,
|
||||
CONF_DELAY: 15,
|
||||
CONF_DELAY: test_delay,
|
||||
CONF_BINARY_SENSORS: [
|
||||
{
|
||||
CONF_INPUT_TYPE: CALL_TYPE_COIL,
|
||||
CONF_NAME: f"{TEST_SENSOR_NAME}_2",
|
||||
CONF_NAME: f"{TEST_SENSOR_NAME}",
|
||||
CONF_ADDRESS: 52,
|
||||
CONF_SCAN_INTERVAL: 5,
|
||||
},
|
||||
{
|
||||
CONF_INPUT_TYPE: CALL_TYPE_DISCRETE,
|
||||
CONF_NAME: f"{TEST_SENSOR_NAME}_1",
|
||||
CONF_ADDRESS: 51,
|
||||
CONF_SCAN_INTERVAL: 5,
|
||||
},
|
||||
],
|
||||
CONF_SENSORS: [
|
||||
{
|
||||
CONF_INPUT_TYPE: CALL_TYPE_REGISTER_HOLDING,
|
||||
CONF_NAME: f"{TEST_SENSOR_NAME}_3",
|
||||
CONF_ADDRESS: 53,
|
||||
CONF_SCAN_INTERVAL: 5,
|
||||
},
|
||||
{
|
||||
CONF_INPUT_TYPE: CALL_TYPE_REGISTER_INPUT,
|
||||
CONF_NAME: f"{TEST_SENSOR_NAME}_4",
|
||||
CONF_ADDRESS: 54,
|
||||
CONF_SCAN_INTERVAL: 5,
|
||||
CONF_SCAN_INTERVAL: test_scan_interval,
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
mock_pymodbus.read_coils.return_value = ReadResult([0x01])
|
||||
mock_pymodbus.read_discrete_inputs.return_value = ReadResult([0x01])
|
||||
mock_pymodbus.read_holding_registers.return_value = ReadResult([7])
|
||||
mock_pymodbus.read_input_registers.return_value = ReadResult([7])
|
||||
now = dt_util.utcnow()
|
||||
with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now):
|
||||
assert await async_setup_component(hass, DOMAIN, config) is True
|
||||
await hass.async_block_till_done()
|
||||
|
||||
now = now + timedelta(seconds=10)
|
||||
# pass first scan_interval
|
||||
start_time = now
|
||||
now = now + timedelta(seconds=(test_scan_interval + 1))
|
||||
with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now):
|
||||
async_fire_time_changed(hass, now)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Check states
|
||||
entity_id = f"{BINARY_SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_1"
|
||||
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE
|
||||
entity_id = f"{BINARY_SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_2"
|
||||
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE
|
||||
entity_id = f"{SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_3"
|
||||
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE
|
||||
entity_id = f"{SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_4"
|
||||
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE
|
||||
|
||||
mock_pymodbus.reset_mock()
|
||||
data = {
|
||||
ATTR_HUB: TEST_MODBUS_NAME,
|
||||
ATTR_UNIT: 17,
|
||||
ATTR_ADDRESS: 16,
|
||||
ATTR_STATE: False,
|
||||
stop_time = start_time + timedelta(seconds=(test_delay + 1))
|
||||
step_timedelta = timedelta(seconds=1)
|
||||
while now < stop_time:
|
||||
now = now + step_timedelta
|
||||
with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now):
|
||||
async_fire_time_changed(hass, now)
|
||||
await hass.async_block_till_done()
|
||||
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE
|
||||
now = now + step_timedelta + timedelta(seconds=2)
|
||||
with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now):
|
||||
async_fire_time_changed(hass, now)
|
||||
await hass.async_block_till_done()
|
||||
assert hass.states.get(entity_id).state == STATE_ON
|
||||
|
||||
|
||||
async def test_thread_lock(hass, mock_pymodbus):
|
||||
"""Run test for block of threads."""
|
||||
|
||||
# the purpose of this test is to test the threads are not being blocked
|
||||
# We "hijiack" a binary_sensor to make a proper blackbox test.
|
||||
test_scan_interval = 5
|
||||
sensors = []
|
||||
for i in range(200):
|
||||
sensors.append(
|
||||
{
|
||||
CONF_INPUT_TYPE: CALL_TYPE_COIL,
|
||||
CONF_NAME: f"{TEST_SENSOR_NAME}_{i}",
|
||||
CONF_ADDRESS: 52 + i,
|
||||
CONF_SCAN_INTERVAL: test_scan_interval,
|
||||
}
|
||||
await hass.services.async_call(DOMAIN, SERVICE_WRITE_COIL, data, blocking=True)
|
||||
assert not mock_pymodbus.write_coil.called
|
||||
await hass.services.async_call(DOMAIN, SERVICE_WRITE_COIL, data, blocking=True)
|
||||
assert not mock_pymodbus.write_coil.called
|
||||
data[ATTR_STATE] = [True, False, True]
|
||||
await hass.services.async_call(DOMAIN, SERVICE_WRITE_COIL, data, blocking=True)
|
||||
assert not mock_pymodbus.write_coils.called
|
||||
|
||||
del data[ATTR_STATE]
|
||||
data[ATTR_VALUE] = 15
|
||||
await hass.services.async_call(DOMAIN, SERVICE_WRITE_REGISTER, data, blocking=True)
|
||||
assert not mock_pymodbus.write_register.called
|
||||
data[ATTR_VALUE] = [1, 2, 3]
|
||||
await hass.services.async_call(DOMAIN, SERVICE_WRITE_REGISTER, data, blocking=True)
|
||||
assert not mock_pymodbus.write_registers.called
|
||||
|
||||
# 2 times fire_changed is needed to secure "normal" update is called.
|
||||
now = now + timedelta(seconds=6)
|
||||
)
|
||||
config = {
|
||||
DOMAIN: [
|
||||
{
|
||||
CONF_TYPE: "tcp",
|
||||
CONF_HOST: "modbusTestHost",
|
||||
CONF_PORT: 5501,
|
||||
CONF_NAME: TEST_MODBUS_NAME,
|
||||
CONF_BINARY_SENSORS: sensors,
|
||||
}
|
||||
]
|
||||
}
|
||||
mock_pymodbus.read_coils.return_value = ReadResult([0x01])
|
||||
now = dt_util.utcnow()
|
||||
with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now):
|
||||
assert await async_setup_component(hass, DOMAIN, config) is True
|
||||
await hass.async_block_till_done()
|
||||
stop_time = now + timedelta(seconds=10)
|
||||
step_timedelta = timedelta(seconds=1)
|
||||
while now < stop_time:
|
||||
now = now + step_timedelta
|
||||
with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now):
|
||||
async_fire_time_changed(hass, now)
|
||||
await hass.async_block_till_done()
|
||||
now = now + timedelta(seconds=10)
|
||||
with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now):
|
||||
async_fire_time_changed(hass, now)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Check states
|
||||
entity_id = f"{BINARY_SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_1"
|
||||
assert not hass.states.get(entity_id).state == STATE_UNAVAILABLE
|
||||
entity_id = f"{BINARY_SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_2"
|
||||
assert not hass.states.get(entity_id).state == STATE_UNAVAILABLE
|
||||
entity_id = f"{SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_3"
|
||||
assert not hass.states.get(entity_id).state == STATE_UNAVAILABLE
|
||||
entity_id = f"{SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_4"
|
||||
assert not hass.states.get(entity_id).state == STATE_UNAVAILABLE
|
||||
for i in range(200):
|
||||
entity_id = f"{BINARY_SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_{i}"
|
||||
assert hass.states.get(entity_id).state == STATE_ON
|
||||
|
|
Loading…
Reference in New Issue