Cache matching MQTT subscriptions (#41433)
parent
85603dcd08
commit
392d5c673a
|
@ -1,6 +1,6 @@
|
|||
"""Support for MQTT message handling."""
|
||||
import asyncio
|
||||
from functools import partial, wraps
|
||||
from functools import lru_cache, partial, wraps
|
||||
import inspect
|
||||
from itertools import groupby
|
||||
import json
|
||||
|
@ -842,6 +842,7 @@ class MQTT:
|
|||
topic, _matcher_for_topic(topic), msg_callback, qos, encoding
|
||||
)
|
||||
self.subscriptions.append(subscription)
|
||||
self._matching_subscriptions.cache_clear()
|
||||
|
||||
# Only subscribe if currently connected.
|
||||
if self.connected:
|
||||
|
@ -854,6 +855,7 @@ class MQTT:
|
|||
if subscription not in self.subscriptions:
|
||||
raise HomeAssistantError("Can't remove subscription twice")
|
||||
self.subscriptions.remove(subscription)
|
||||
self._matching_subscriptions.cache_clear()
|
||||
|
||||
if any(other.topic == topic for other in self.subscriptions):
|
||||
# Other subscriptions on topic remaining - don't unsubscribe.
|
||||
|
@ -944,6 +946,14 @@ class MQTT:
|
|||
"""Message received callback."""
|
||||
self.hass.add_job(self._mqtt_handle_message, msg)
|
||||
|
||||
@lru_cache(2048)
|
||||
def _matching_subscriptions(self, topic):
|
||||
subscriptions = []
|
||||
for subscription in self.subscriptions:
|
||||
if subscription.matcher(topic):
|
||||
subscriptions.append(subscription)
|
||||
return subscriptions
|
||||
|
||||
@callback
|
||||
def _mqtt_handle_message(self, msg) -> None:
|
||||
_LOGGER.debug(
|
||||
|
@ -954,9 +964,9 @@ class MQTT:
|
|||
)
|
||||
timestamp = dt_util.utcnow()
|
||||
|
||||
for subscription in self.subscriptions:
|
||||
if not subscription.matcher(msg.topic):
|
||||
continue
|
||||
subscriptions = self._matching_subscriptions(msg.topic)
|
||||
|
||||
for subscription in subscriptions:
|
||||
|
||||
payload: SubscribePayloadType = msg.payload
|
||||
if subscription.encoding is not None:
|
||||
|
|
|
@ -384,9 +384,13 @@ async def mqtt_mock(hass, mqtt_client_mock, mqtt_config):
|
|||
assert result
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Workaround: asynctest==0.13 fails on @functools.lru_cache
|
||||
spec = dir(hass.data["mqtt"])
|
||||
spec.remove("_matching_subscriptions")
|
||||
|
||||
mqtt_component_mock = MagicMock(
|
||||
return_value=hass.data["mqtt"],
|
||||
spec_set=hass.data["mqtt"],
|
||||
spec_set=spec,
|
||||
wraps=hass.data["mqtt"],
|
||||
)
|
||||
mqtt_component_mock._mqttc = mqtt_client_mock
|
||||
|
|
Loading…
Reference in New Issue