Refactor modbus polling to prevent dupe updates and memory leak (#136211)

pull/136272/head^2
J. Nick Koston 2025-01-23 08:52:40 -10:00 committed by GitHub
parent 2466df2b78
commit 9d83bbfec6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 64 additions and 38 deletions

View File

@ -2,7 +2,6 @@
from __future__ import annotations
from datetime import datetime
import logging
from typing import Any
@ -104,7 +103,7 @@ class ModbusBinarySensor(BasePlatform, RestoreEntity, BinarySensorEntity):
if state := await self.async_get_last_state():
self._attr_is_on = state.state == STATE_ON
async def async_update(self, now: datetime | None = None) -> None:
async def _async_update(self) -> None:
"""Update the state of the sensor."""
# do not allow multiple active calls to the same platform
@ -126,7 +125,6 @@ class ModbusBinarySensor(BasePlatform, RestoreEntity, BinarySensorEntity):
self._result = result.registers
self._attr_is_on = bool(self._result[0] & 1)
self.async_write_ha_state()
if self._coordinator:
self._coordinator.async_set_updated_data(self._result)
@ -159,7 +157,6 @@ class SlaveSensor(
"""Handle entity which will be added."""
if state := await self.async_get_last_state():
self._attr_is_on = state.state == STATE_ON
self.async_write_ha_state()
await super().async_added_to_hass()
@callback

View File

@ -2,7 +2,6 @@
from __future__ import annotations
from datetime import datetime
import logging
import struct
from typing import Any, cast
@ -313,7 +312,7 @@ class ModbusThermostat(BaseStructPlatform, RestoreEntity, ClimateEntity):
)
break
await self.async_update()
await self._async_update_write_state()
async def async_set_fan_mode(self, fan_mode: str) -> None:
"""Set new target fan mode."""
@ -335,7 +334,7 @@ class ModbusThermostat(BaseStructPlatform, RestoreEntity, ClimateEntity):
CALL_TYPE_WRITE_REGISTER,
)
await self.async_update()
await self._async_update_write_state()
async def async_set_swing_mode(self, swing_mode: str) -> None:
"""Set new target swing mode."""
@ -358,7 +357,7 @@ class ModbusThermostat(BaseStructPlatform, RestoreEntity, ClimateEntity):
CALL_TYPE_WRITE_REGISTER,
)
break
await self.async_update()
await self._async_update_write_state()
async def async_set_temperature(self, **kwargs: Any) -> None:
"""Set new target temperature."""
@ -413,9 +412,9 @@ class ModbusThermostat(BaseStructPlatform, RestoreEntity, ClimateEntity):
CALL_TYPE_WRITE_REGISTERS,
)
self._attr_available = result is not None
await self.async_update()
await self._async_update_write_state()
async def async_update(self, now: datetime | None = None) -> None:
async def _async_update(self) -> None:
"""Update Target & Current Temperature."""
# remark "now" is a dummy parameter to avoid problems with
# async_track_time_interval
@ -490,8 +489,6 @@ class ModbusThermostat(BaseStructPlatform, RestoreEntity, ClimateEntity):
if onoff == self._hvac_off_value:
self._attr_hvac_mode = HVACMode.OFF
self.async_write_ha_state()
async def _async_read_register(
self, register_type: str, register: int, raw: bool | None = False
) -> float | None:

View File

@ -2,7 +2,6 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from homeassistant.components.cover import CoverEntity, CoverEntityFeature, CoverState
@ -117,7 +116,7 @@ class ModbusCover(BasePlatform, CoverEntity, RestoreEntity):
self._slave, self._write_address, self._state_open, self._write_type
)
self._attr_available = result is not None
await self.async_update()
await self._async_update_write_state()
async def async_close_cover(self, **kwargs: Any) -> None:
"""Close cover."""
@ -125,9 +124,9 @@ class ModbusCover(BasePlatform, CoverEntity, RestoreEntity):
self._slave, self._write_address, self._state_closed, self._write_type
)
self._attr_available = result is not None
await self.async_update()
await self._async_update_write_state()
async def async_update(self, now: datetime | None = None) -> None:
async def _async_update(self) -> None:
"""Update the state of the cover."""
# remark "now" is a dummy parameter to avoid problems with
# async_track_time_interval
@ -136,11 +135,9 @@ class ModbusCover(BasePlatform, CoverEntity, RestoreEntity):
)
if result is None:
self._attr_available = False
self.async_write_ha_state()
return
self._attr_available = True
if self._input_type == CALL_TYPE_COIL:
self._set_attr_state(bool(result.bits[0] & 1))
else:
self._set_attr_state(int(result.registers[0]))
self.async_write_ha_state()

View File

@ -3,6 +3,7 @@
from __future__ import annotations
from abc import abstractmethod
import asyncio
from collections.abc import Callable
from datetime import datetime, timedelta
import logging
@ -107,37 +108,73 @@ class BasePlatform(Entity):
self._max_value = get_optional_numeric_config(CONF_MAX_VALUE)
self._nan_value = entry.get(CONF_NAN_VALUE)
self._zero_suppress = get_optional_numeric_config(CONF_ZERO_SUPPRESS)
self._update_lock = asyncio.Lock()
@abstractmethod
async def async_update(self, now: datetime | None = None) -> None:
async def _async_update(self) -> None:
"""Virtual function to be overwritten."""
async def async_update(self, now: datetime | None = None) -> None:
"""Update the entity state."""
async with self._update_lock:
await self._async_update()
async def _async_update_write_state(self) -> None:
"""Update the entity state and write it to the state machine."""
await self.async_update()
self.async_write_ha_state()
async def _async_update_if_not_in_progress(
self, now: datetime | None = None
) -> None:
"""Update the entity state if not already in progress."""
if self._update_lock.locked():
_LOGGER.debug("Update for entity %s is already in progress", self.name)
return
await self._async_update_write_state()
@callback
def async_run(self) -> None:
"""Remote start entity."""
self.async_hold(update=False)
self._cancel_call = async_call_later(
self.hass, timedelta(milliseconds=100), self.async_update
)
self._async_cancel_update_polling()
self._async_schedule_future_update(0.1)
if self._scan_interval > 0:
self._cancel_timer = async_track_time_interval(
self.hass, self.async_update, timedelta(seconds=self._scan_interval)
self.hass,
self._async_update_if_not_in_progress,
timedelta(seconds=self._scan_interval),
)
self._attr_available = True
self.async_write_ha_state()
@callback
def async_hold(self, update: bool = True) -> None:
"""Remote stop entity."""
def _async_schedule_future_update(self, delay: float) -> None:
"""Schedule an update in the future."""
self._async_cancel_future_pending_update()
self._cancel_call = async_call_later(
self.hass, delay, self._async_update_if_not_in_progress
)
@callback
def _async_cancel_future_pending_update(self) -> None:
"""Cancel a future pending update."""
if self._cancel_call:
self._cancel_call()
self._cancel_call = None
def _async_cancel_update_polling(self) -> None:
"""Cancel the polling."""
if self._cancel_timer:
self._cancel_timer()
self._cancel_timer = None
if update:
self._attr_available = False
self.async_write_ha_state()
@callback
def async_hold(self) -> None:
"""Remote stop entity."""
self._async_cancel_future_pending_update()
self._async_cancel_update_polling()
self._attr_available = False
self.async_write_ha_state()
async def async_base_added_to_hass(self) -> None:
"""Handle entity which will be added."""
@ -312,6 +349,7 @@ class BaseSwitch(BasePlatform, ToggleEntity, RestoreEntity):
self._attr_is_on = True
elif state.state == STATE_OFF:
self._attr_is_on = False
await super().async_added_to_hass()
async def async_turn(self, command: int) -> None:
"""Evaluate switch result."""
@ -330,21 +368,21 @@ class BaseSwitch(BasePlatform, ToggleEntity, RestoreEntity):
return
if self._verify_delay:
async_call_later(self.hass, self._verify_delay, self.async_update)
else:
await self.async_update()
self._async_schedule_future_update(self._verify_delay)
return
await self._async_update_write_state()
async def async_turn_off(self, **kwargs: Any) -> None:
"""Set switch off."""
await self.async_turn(self._command_off)
async def async_update(self, now: datetime | None = None) -> None:
async def _async_update(self) -> None:
"""Update the entity state."""
# remark "now" is a dummy parameter to avoid problems with
# async_track_time_interval
if not self._verify_active:
self._attr_available = True
self.async_write_ha_state()
return
# do not allow multiple active calls to the same platform
@ -357,7 +395,6 @@ class BaseSwitch(BasePlatform, ToggleEntity, RestoreEntity):
self._call_active = False
if result is None:
self._attr_available = False
self.async_write_ha_state()
return
self._attr_available = True
@ -379,4 +416,3 @@ class BaseSwitch(BasePlatform, ToggleEntity, RestoreEntity):
self._verify_address,
value,
)
self.async_write_ha_state()

View File

@ -2,7 +2,6 @@
from __future__ import annotations
from datetime import datetime
import logging
from typing import Any
@ -106,7 +105,7 @@ class ModbusRegisterSensor(BaseStructPlatform, RestoreSensor, SensorEntity):
if state:
self._attr_native_value = state.native_value
async def async_update(self, now: datetime | None = None) -> None:
async def _async_update(self) -> None:
"""Update the state of the sensor."""
# remark "now" is a dummy parameter to avoid problems with
# async_track_time_interval