Fix mqtt timer churn (#117885)

Borrows the same design from homeassistant.helpers.storage to avoid
rescheduling the timer every time async_schedule is called if a timer
is already running.

Instead of the timer fires too early it gets rescheduled for the time
we wanted it. This avoids 1000s of timer add/cancel during startup
pull/117831/head
J. Nick Koston 2024-05-21 15:05:33 -10:00 committed by GitHub
parent 1800a60a6d
commit f429bfa903
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 39 additions and 9 deletions

View File

@ -328,6 +328,7 @@ class EnsureJobAfterCooldown:
self._callback = callback_job self._callback = callback_job
self._task: asyncio.Task | None = None self._task: asyncio.Task | None = None
self._timer: asyncio.TimerHandle | None = None self._timer: asyncio.TimerHandle | None = None
self._next_execute_time = 0.0
def set_timeout(self, timeout: float) -> None: def set_timeout(self, timeout: float) -> None:
"""Set a new timeout period.""" """Set a new timeout period."""
@ -371,8 +372,28 @@ class EnsureJobAfterCooldown:
"""Ensure we execute after a cooldown period.""" """Ensure we execute after a cooldown period."""
# We want to reschedule the timer in the future # We want to reschedule the timer in the future
# every time this is called. # every time this is called.
self._async_cancel_timer() next_when = self._loop.time() + self._timeout
self._timer = self._loop.call_later(self._timeout, self.async_execute) if not self._timer:
self._timer = self._loop.call_at(next_when, self._async_timer_reached)
return
if self._timer.when() < next_when:
# Timer already running, set the next execute time
# if it fires too early, it will get rescheduled
self._next_execute_time = next_when
@callback
def _async_timer_reached(self) -> None:
"""Handle timer fire."""
self._timer = None
if self._loop.time() >= self._next_execute_time:
self.async_execute()
return
# Timer fired too early because there were multiple
# calls async_schedule. Reschedule the timer.
self._timer = self._loop.call_at(
self._next_execute_time, self._async_timer_reached
)
async def async_cleanup(self) -> None: async def async_cleanup(self) -> None:
"""Cleanup any pending task.""" """Cleanup any pending task."""

View File

@ -1839,6 +1839,7 @@ async def test_restore_all_active_subscriptions_on_reconnect(
mqtt_client_mock: MqttMockPahoClient, mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry: MqttMockHAClientGenerator, mqtt_mock_entry: MqttMockHAClientGenerator,
record_calls: MessageCallbackType, record_calls: MessageCallbackType,
freezer: FrozenDateTimeFactory,
) -> None: ) -> None:
"""Test active subscriptions are restored correctly on reconnect.""" """Test active subscriptions are restored correctly on reconnect."""
mqtt_mock = await mqtt_mock_entry() mqtt_mock = await mqtt_mock_entry()
@ -1849,7 +1850,8 @@ async def test_restore_all_active_subscriptions_on_reconnect(
await mqtt.async_subscribe(hass, "test/state", record_calls, qos=1) await mqtt.async_subscribe(hass, "test/state", record_calls, qos=1)
await mqtt.async_subscribe(hass, "test/state", record_calls, qos=0) await mqtt.async_subscribe(hass, "test/state", record_calls, qos=0)
await hass.async_block_till_done() await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown freezer.tick(3)
async_fire_time_changed(hass) # cooldown
await hass.async_block_till_done() await hass.async_block_till_done()
# the subscribtion with the highest QoS should survive # the subscribtion with the highest QoS should survive
@ -1865,15 +1867,18 @@ async def test_restore_all_active_subscriptions_on_reconnect(
mqtt_client_mock.on_disconnect(None, None, 0) mqtt_client_mock.on_disconnect(None, None, 0)
await hass.async_block_till_done() await hass.async_block_till_done()
mqtt_client_mock.on_connect(None, None, None, 0) mqtt_client_mock.on_connect(None, None, None, 0)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown freezer.tick(3)
async_fire_time_changed(hass) # cooldown
await hass.async_block_till_done() await hass.async_block_till_done()
expected.append(call([("test/state", 1)])) expected.append(call([("test/state", 1)]))
assert mqtt_client_mock.subscribe.mock_calls == expected assert mqtt_client_mock.subscribe.mock_calls == expected
async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown freezer.tick(3)
async_fire_time_changed(hass) # cooldown
await hass.async_block_till_done() await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown freezer.tick(3)
async_fire_time_changed(hass) # cooldown
await hass.async_block_till_done() await hass.async_block_till_done()
@ -1889,6 +1894,7 @@ async def test_subscribed_at_highest_qos(
mqtt_client_mock: MqttMockPahoClient, mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry: MqttMockHAClientGenerator, mqtt_mock_entry: MqttMockHAClientGenerator,
record_calls: MessageCallbackType, record_calls: MessageCallbackType,
freezer: FrozenDateTimeFactory,
) -> None: ) -> None:
"""Test the highest qos as assigned when subscribing to the same topic.""" """Test the highest qos as assigned when subscribing to the same topic."""
mqtt_mock = await mqtt_mock_entry() mqtt_mock = await mqtt_mock_entry()
@ -1897,18 +1903,21 @@ async def test_subscribed_at_highest_qos(
await mqtt.async_subscribe(hass, "test/state", record_calls, qos=0) await mqtt.async_subscribe(hass, "test/state", record_calls, qos=0)
await hass.async_block_till_done() await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown freezer.tick(5)
async_fire_time_changed(hass) # cooldown
await hass.async_block_till_done() await hass.async_block_till_done()
assert ("test/state", 0) in help_all_subscribe_calls(mqtt_client_mock) assert ("test/state", 0) in help_all_subscribe_calls(mqtt_client_mock)
mqtt_client_mock.reset_mock() mqtt_client_mock.reset_mock()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown freezer.tick(5)
async_fire_time_changed(hass) # cooldown
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
await mqtt.async_subscribe(hass, "test/state", record_calls, qos=1) await mqtt.async_subscribe(hass, "test/state", record_calls, qos=1)
await mqtt.async_subscribe(hass, "test/state", record_calls, qos=2) await mqtt.async_subscribe(hass, "test/state", record_calls, qos=2)
await hass.async_block_till_done() await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown freezer.tick(5)
async_fire_time_changed(hass) # cooldown
await hass.async_block_till_done() await hass.async_block_till_done()
# the subscribtion with the highest QoS should survive # the subscribtion with the highest QoS should survive
assert help_all_subscribe_calls(mqtt_client_mock) == [("test/state", 2)] assert help_all_subscribe_calls(mqtt_client_mock) == [("test/state", 2)]