Cache matching MQTT subscriptions (#41433)

pull/41450/head^2
Erik Montnemery 2020-10-08 08:52:23 +02:00 committed by GitHub
parent 85603dcd08
commit 392d5c673a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 5 deletions

View File

@ -1,6 +1,6 @@
"""Support for MQTT message handling."""
import asyncio
from functools import partial, wraps
from functools import lru_cache, partial, wraps
import inspect
from itertools import groupby
import json
@ -842,6 +842,7 @@ class MQTT:
topic, _matcher_for_topic(topic), msg_callback, qos, encoding
)
self.subscriptions.append(subscription)
self._matching_subscriptions.cache_clear()
# Only subscribe if currently connected.
if self.connected:
@ -854,6 +855,7 @@ class MQTT:
if subscription not in self.subscriptions:
raise HomeAssistantError("Can't remove subscription twice")
self.subscriptions.remove(subscription)
self._matching_subscriptions.cache_clear()
if any(other.topic == topic for other in self.subscriptions):
# Other subscriptions on topic remaining - don't unsubscribe.
@ -944,6 +946,14 @@ class MQTT:
"""Message received callback."""
self.hass.add_job(self._mqtt_handle_message, msg)
@lru_cache(2048)
def _matching_subscriptions(self, topic):
subscriptions = []
for subscription in self.subscriptions:
if subscription.matcher(topic):
subscriptions.append(subscription)
return subscriptions
@callback
def _mqtt_handle_message(self, msg) -> None:
_LOGGER.debug(
@ -954,9 +964,9 @@ class MQTT:
)
timestamp = dt_util.utcnow()
for subscription in self.subscriptions:
if not subscription.matcher(msg.topic):
continue
subscriptions = self._matching_subscriptions(msg.topic)
for subscription in subscriptions:
payload: SubscribePayloadType = msg.payload
if subscription.encoding is not None:

View File

@ -384,9 +384,13 @@ async def mqtt_mock(hass, mqtt_client_mock, mqtt_config):
assert result
await hass.async_block_till_done()
# Workaround: asynctest==0.13 fails on @functools.lru_cache
spec = dir(hass.data["mqtt"])
spec.remove("_matching_subscriptions")
mqtt_component_mock = MagicMock(
return_value=hass.data["mqtt"],
spec_set=hass.data["mqtt"],
spec_set=spec,
wraps=hass.data["mqtt"],
)
mqtt_component_mock._mqttc = mqtt_client_mock