Refactor modbus polling to prevent dupe updates and memory leak (#136211)

pull/136272/head^2
J. Nick Koston 2025-01-23 08:52:40 -10:00 committed by GitHub
parent 2466df2b78
commit 9d83bbfec6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 64 additions and 38 deletions

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime
import logging import logging
from typing import Any from typing import Any
@ -104,7 +103,7 @@ class ModbusBinarySensor(BasePlatform, RestoreEntity, BinarySensorEntity):
if state := await self.async_get_last_state(): if state := await self.async_get_last_state():
self._attr_is_on = state.state == STATE_ON self._attr_is_on = state.state == STATE_ON
async def async_update(self, now: datetime | None = None) -> None: async def _async_update(self) -> None:
"""Update the state of the sensor.""" """Update the state of the sensor."""
# do not allow multiple active calls to the same platform # do not allow multiple active calls to the same platform
@ -126,7 +125,6 @@ class ModbusBinarySensor(BasePlatform, RestoreEntity, BinarySensorEntity):
self._result = result.registers self._result = result.registers
self._attr_is_on = bool(self._result[0] & 1) self._attr_is_on = bool(self._result[0] & 1)
self.async_write_ha_state()
if self._coordinator: if self._coordinator:
self._coordinator.async_set_updated_data(self._result) self._coordinator.async_set_updated_data(self._result)
@ -159,7 +157,6 @@ class SlaveSensor(
"""Handle entity which will be added.""" """Handle entity which will be added."""
if state := await self.async_get_last_state(): if state := await self.async_get_last_state():
self._attr_is_on = state.state == STATE_ON self._attr_is_on = state.state == STATE_ON
self.async_write_ha_state()
await super().async_added_to_hass() await super().async_added_to_hass()
@callback @callback

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime
import logging import logging
import struct import struct
from typing import Any, cast from typing import Any, cast
@ -313,7 +312,7 @@ class ModbusThermostat(BaseStructPlatform, RestoreEntity, ClimateEntity):
) )
break break
await self.async_update() await self._async_update_write_state()
async def async_set_fan_mode(self, fan_mode: str) -> None: async def async_set_fan_mode(self, fan_mode: str) -> None:
"""Set new target fan mode.""" """Set new target fan mode."""
@ -335,7 +334,7 @@ class ModbusThermostat(BaseStructPlatform, RestoreEntity, ClimateEntity):
CALL_TYPE_WRITE_REGISTER, CALL_TYPE_WRITE_REGISTER,
) )
await self.async_update() await self._async_update_write_state()
async def async_set_swing_mode(self, swing_mode: str) -> None: async def async_set_swing_mode(self, swing_mode: str) -> None:
"""Set new target swing mode.""" """Set new target swing mode."""
@ -358,7 +357,7 @@ class ModbusThermostat(BaseStructPlatform, RestoreEntity, ClimateEntity):
CALL_TYPE_WRITE_REGISTER, CALL_TYPE_WRITE_REGISTER,
) )
break break
await self.async_update() await self._async_update_write_state()
async def async_set_temperature(self, **kwargs: Any) -> None: async def async_set_temperature(self, **kwargs: Any) -> None:
"""Set new target temperature.""" """Set new target temperature."""
@ -413,9 +412,9 @@ class ModbusThermostat(BaseStructPlatform, RestoreEntity, ClimateEntity):
CALL_TYPE_WRITE_REGISTERS, CALL_TYPE_WRITE_REGISTERS,
) )
self._attr_available = result is not None self._attr_available = result is not None
await self.async_update() await self._async_update_write_state()
async def async_update(self, now: datetime | None = None) -> None: async def _async_update(self) -> None:
"""Update Target & Current Temperature.""" """Update Target & Current Temperature."""
# remark "now" is a dummy parameter to avoid problems with # remark "now" is a dummy parameter to avoid problems with
# async_track_time_interval # async_track_time_interval
@ -490,8 +489,6 @@ class ModbusThermostat(BaseStructPlatform, RestoreEntity, ClimateEntity):
if onoff == self._hvac_off_value: if onoff == self._hvac_off_value:
self._attr_hvac_mode = HVACMode.OFF self._attr_hvac_mode = HVACMode.OFF
self.async_write_ha_state()
async def _async_read_register( async def _async_read_register(
self, register_type: str, register: int, raw: bool | None = False self, register_type: str, register: int, raw: bool | None = False
) -> float | None: ) -> float | None:

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime
from typing import Any from typing import Any
from homeassistant.components.cover import CoverEntity, CoverEntityFeature, CoverState from homeassistant.components.cover import CoverEntity, CoverEntityFeature, CoverState
@ -117,7 +116,7 @@ class ModbusCover(BasePlatform, CoverEntity, RestoreEntity):
self._slave, self._write_address, self._state_open, self._write_type self._slave, self._write_address, self._state_open, self._write_type
) )
self._attr_available = result is not None self._attr_available = result is not None
await self.async_update() await self._async_update_write_state()
async def async_close_cover(self, **kwargs: Any) -> None: async def async_close_cover(self, **kwargs: Any) -> None:
"""Close cover.""" """Close cover."""
@ -125,9 +124,9 @@ class ModbusCover(BasePlatform, CoverEntity, RestoreEntity):
self._slave, self._write_address, self._state_closed, self._write_type self._slave, self._write_address, self._state_closed, self._write_type
) )
self._attr_available = result is not None self._attr_available = result is not None
await self.async_update() await self._async_update_write_state()
async def async_update(self, now: datetime | None = None) -> None: async def _async_update(self) -> None:
"""Update the state of the cover.""" """Update the state of the cover."""
# remark "now" is a dummy parameter to avoid problems with # remark "now" is a dummy parameter to avoid problems with
# async_track_time_interval # async_track_time_interval
@ -136,11 +135,9 @@ class ModbusCover(BasePlatform, CoverEntity, RestoreEntity):
) )
if result is None: if result is None:
self._attr_available = False self._attr_available = False
self.async_write_ha_state()
return return
self._attr_available = True self._attr_available = True
if self._input_type == CALL_TYPE_COIL: if self._input_type == CALL_TYPE_COIL:
self._set_attr_state(bool(result.bits[0] & 1)) self._set_attr_state(bool(result.bits[0] & 1))
else: else:
self._set_attr_state(int(result.registers[0])) self._set_attr_state(int(result.registers[0]))
self.async_write_ha_state()

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
import asyncio
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
import logging import logging
@ -107,35 +108,71 @@ class BasePlatform(Entity):
self._max_value = get_optional_numeric_config(CONF_MAX_VALUE) self._max_value = get_optional_numeric_config(CONF_MAX_VALUE)
self._nan_value = entry.get(CONF_NAN_VALUE) self._nan_value = entry.get(CONF_NAN_VALUE)
self._zero_suppress = get_optional_numeric_config(CONF_ZERO_SUPPRESS) self._zero_suppress = get_optional_numeric_config(CONF_ZERO_SUPPRESS)
self._update_lock = asyncio.Lock()
@abstractmethod @abstractmethod
async def async_update(self, now: datetime | None = None) -> None: async def _async_update(self) -> None:
"""Virtual function to be overwritten.""" """Virtual function to be overwritten."""
async def async_update(self, now: datetime | None = None) -> None:
"""Update the entity state."""
async with self._update_lock:
await self._async_update()
async def _async_update_write_state(self) -> None:
"""Update the entity state and write it to the state machine."""
await self.async_update()
self.async_write_ha_state()
async def _async_update_if_not_in_progress(
self, now: datetime | None = None
) -> None:
"""Update the entity state if not already in progress."""
if self._update_lock.locked():
_LOGGER.debug("Update for entity %s is already in progress", self.name)
return
await self._async_update_write_state()
@callback @callback
def async_run(self) -> None: def async_run(self) -> None:
"""Remote start entity.""" """Remote start entity."""
self.async_hold(update=False) self._async_cancel_update_polling()
self._cancel_call = async_call_later( self._async_schedule_future_update(0.1)
self.hass, timedelta(milliseconds=100), self.async_update
)
if self._scan_interval > 0: if self._scan_interval > 0:
self._cancel_timer = async_track_time_interval( self._cancel_timer = async_track_time_interval(
self.hass, self.async_update, timedelta(seconds=self._scan_interval) self.hass,
self._async_update_if_not_in_progress,
timedelta(seconds=self._scan_interval),
) )
self._attr_available = True self._attr_available = True
self.async_write_ha_state() self.async_write_ha_state()
@callback @callback
def async_hold(self, update: bool = True) -> None: def _async_schedule_future_update(self, delay: float) -> None:
"""Remote stop entity.""" """Schedule an update in the future."""
self._async_cancel_future_pending_update()
self._cancel_call = async_call_later(
self.hass, delay, self._async_update_if_not_in_progress
)
@callback
def _async_cancel_future_pending_update(self) -> None:
"""Cancel a future pending update."""
if self._cancel_call: if self._cancel_call:
self._cancel_call() self._cancel_call()
self._cancel_call = None self._cancel_call = None
def _async_cancel_update_polling(self) -> None:
"""Cancel the polling."""
if self._cancel_timer: if self._cancel_timer:
self._cancel_timer() self._cancel_timer()
self._cancel_timer = None self._cancel_timer = None
if update:
@callback
def async_hold(self) -> None:
"""Remote stop entity."""
self._async_cancel_future_pending_update()
self._async_cancel_update_polling()
self._attr_available = False self._attr_available = False
self.async_write_ha_state() self.async_write_ha_state()
@ -312,6 +349,7 @@ class BaseSwitch(BasePlatform, ToggleEntity, RestoreEntity):
self._attr_is_on = True self._attr_is_on = True
elif state.state == STATE_OFF: elif state.state == STATE_OFF:
self._attr_is_on = False self._attr_is_on = False
await super().async_added_to_hass()
async def async_turn(self, command: int) -> None: async def async_turn(self, command: int) -> None:
"""Evaluate switch result.""" """Evaluate switch result."""
@ -330,21 +368,21 @@ class BaseSwitch(BasePlatform, ToggleEntity, RestoreEntity):
return return
if self._verify_delay: if self._verify_delay:
async_call_later(self.hass, self._verify_delay, self.async_update) self._async_schedule_future_update(self._verify_delay)
else: return
await self.async_update()
await self._async_update_write_state()
async def async_turn_off(self, **kwargs: Any) -> None: async def async_turn_off(self, **kwargs: Any) -> None:
"""Set switch off.""" """Set switch off."""
await self.async_turn(self._command_off) await self.async_turn(self._command_off)
async def async_update(self, now: datetime | None = None) -> None: async def _async_update(self) -> None:
"""Update the entity state.""" """Update the entity state."""
# remark "now" is a dummy parameter to avoid problems with # remark "now" is a dummy parameter to avoid problems with
# async_track_time_interval # async_track_time_interval
if not self._verify_active: if not self._verify_active:
self._attr_available = True self._attr_available = True
self.async_write_ha_state()
return return
# do not allow multiple active calls to the same platform # do not allow multiple active calls to the same platform
@ -357,7 +395,6 @@ class BaseSwitch(BasePlatform, ToggleEntity, RestoreEntity):
self._call_active = False self._call_active = False
if result is None: if result is None:
self._attr_available = False self._attr_available = False
self.async_write_ha_state()
return return
self._attr_available = True self._attr_available = True
@ -379,4 +416,3 @@ class BaseSwitch(BasePlatform, ToggleEntity, RestoreEntity):
self._verify_address, self._verify_address,
value, value,
) )
self.async_write_ha_state()

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime
import logging import logging
from typing import Any from typing import Any
@ -106,7 +105,7 @@ class ModbusRegisterSensor(BaseStructPlatform, RestoreSensor, SensorEntity):
if state: if state:
self._attr_native_value = state.native_value self._attr_native_value = state.native_value
async def async_update(self, now: datetime | None = None) -> None: async def _async_update(self) -> None:
"""Update the state of the sensor.""" """Update the state of the sensor."""
# remark "now" is a dummy parameter to avoid problems with # remark "now" is a dummy parameter to avoid problems with
# async_track_time_interval # async_track_time_interval