From 9be829ba1f5e7e1c2f080079efb6ee6322f84291 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 25 May 2024 11:34:24 -1000 Subject: [PATCH] Make mqtt internal subscription a normal function (#118092) Co-authored-by: Jan Bouwhuis --- homeassistant/components/mqtt/__init__.py | 5 +- .../components/mqtt/alarm_control_panel.py | 2 +- .../components/mqtt/binary_sensor.py | 2 +- homeassistant/components/mqtt/camera.py | 2 +- homeassistant/components/mqtt/client.py | 71 +++++++++++-------- homeassistant/components/mqtt/climate.py | 2 +- homeassistant/components/mqtt/cover.py | 2 +- .../components/mqtt/device_tracker.py | 2 +- homeassistant/components/mqtt/event.py | 2 +- homeassistant/components/mqtt/fan.py | 2 +- homeassistant/components/mqtt/humidifier.py | 2 +- homeassistant/components/mqtt/image.py | 2 +- homeassistant/components/mqtt/lawn_mower.py | 2 +- .../components/mqtt/light/schema_basic.py | 2 +- .../components/mqtt/light/schema_json.py | 2 +- .../components/mqtt/light/schema_template.py | 2 +- homeassistant/components/mqtt/lock.py | 2 +- homeassistant/components/mqtt/mixins.py | 20 +++--- homeassistant/components/mqtt/number.py | 2 +- homeassistant/components/mqtt/select.py | 2 +- homeassistant/components/mqtt/sensor.py | 2 +- homeassistant/components/mqtt/siren.py | 2 +- homeassistant/components/mqtt/subscription.py | 54 +++++++++----- homeassistant/components/mqtt/switch.py | 2 +- homeassistant/components/mqtt/tag.py | 2 +- homeassistant/components/mqtt/text.py | 2 +- homeassistant/components/mqtt/update.py | 2 +- homeassistant/components/mqtt/vacuum.py | 2 +- homeassistant/components/mqtt/valve.py | 2 +- tests/components/mqtt/test_init.py | 23 +++++- 30 files changed, 140 insertions(+), 83 deletions(-) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 3391312bdd0..39e2660ca03 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -39,6 +39,7 @@ from .client import ( # noqa: F401 MQTT, async_publish, async_subscribe, + async_subscribe_internal, publish, subscribe, ) @@ -311,7 +312,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: def collect_msg(msg: ReceiveMessage) -> None: messages.append((msg.topic, str(msg.payload).replace("\n", ""))) - unsub = await async_subscribe(hass, call.data["topic"], collect_msg) + unsub = async_subscribe_internal(hass, call.data["topic"], collect_msg) def write_dump() -> None: with open(hass.config.path("mqtt_dump.txt"), "w", encoding="utf8") as fp: @@ -459,7 +460,7 @@ async def websocket_subscribe( # Perform UTF-8 decoding directly in callback routine qos: int = msg.get("qos", DEFAULT_QOS) - connection.subscriptions[msg["id"]] = await async_subscribe( + connection.subscriptions[msg["id"]] = async_subscribe_internal( hass, msg["topic"], forward_messages, encoding=None, qos=qos ) diff --git a/homeassistant/components/mqtt/alarm_control_panel.py b/homeassistant/components/mqtt/alarm_control_panel.py index e341d54e349..fe6650cbd0f 100644 --- a/homeassistant/components/mqtt/alarm_control_panel.py +++ b/homeassistant/components/mqtt/alarm_control_panel.py @@ -226,7 +226,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_alarm_disarm(self, code: str | None = None) -> None: """Send disarm command. diff --git a/homeassistant/components/mqtt/binary_sensor.py b/homeassistant/components/mqtt/binary_sensor.py index ce772855e78..61e5074378d 100644 --- a/homeassistant/components/mqtt/binary_sensor.py +++ b/homeassistant/components/mqtt/binary_sensor.py @@ -254,7 +254,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) @callback def _value_is_expired(self, *_: Any) -> None: diff --git a/homeassistant/components/mqtt/camera.py b/homeassistant/components/mqtt/camera.py index f8ec099a295..2c6346f5794 100644 --- a/homeassistant/components/mqtt/camera.py +++ b/homeassistant/components/mqtt/camera.py @@ -130,7 +130,7 @@ class MqttCamera(MqttEntity, Camera): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_camera_image( self, width: int | None = None, height: int | None = None diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index 0e9f7f06e21..16db9a45b58 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -191,13 +191,25 @@ async def async_subscribe( Call the return value to unsubscribe. """ - if not mqtt_config_entry_enabled(hass): - raise HomeAssistantError( - f"Cannot subscribe to topic '{topic}', MQTT is not enabled", - translation_key="mqtt_not_setup_cannot_subscribe", - translation_domain=DOMAIN, - translation_placeholders={"topic": topic}, - ) + return async_subscribe_internal(hass, topic, msg_callback, qos, encoding) + + +@callback +def async_subscribe_internal( + hass: HomeAssistant, + topic: str, + msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None], + qos: int = DEFAULT_QOS, + encoding: str | None = DEFAULT_ENCODING, +) -> CALLBACK_TYPE: + """Subscribe to an MQTT topic. + + This function is internal to the MQTT integration + and may change at any time. It should not be considered + a stable API. + + Call the return value to unsubscribe. + """ try: mqtt_data = hass.data[DATA_MQTT] except KeyError as exc: @@ -208,12 +220,15 @@ async def async_subscribe( translation_domain=DOMAIN, translation_placeholders={"topic": topic}, ) from exc - return await mqtt_data.client.async_subscribe( - topic, - msg_callback, - qos, - encoding, - ) + client = mqtt_data.client + if not client.connected and not mqtt_config_entry_enabled(hass): + raise HomeAssistantError( + f"Cannot subscribe to topic '{topic}', MQTT is not enabled", + translation_key="mqtt_not_setup_cannot_subscribe", + translation_domain=DOMAIN, + translation_placeholders={"topic": topic}, + ) + return client.async_subscribe(topic, msg_callback, qos, encoding) @bind_hass @@ -845,17 +860,15 @@ class MQTT: f"'{msg.topic}': '{msg.payload}'" # type: ignore[str-bytes-safe] ) - async def async_subscribe( + @callback + def async_subscribe( self, topic: str, msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None], qos: int, encoding: str | None = None, ) -> Callable[[], None]: - """Set up a subscription to a topic with the provided qos. - - This method is a coroutine. - """ + """Set up a subscription to a topic with the provided qos.""" if not isinstance(topic, str): raise HomeAssistantError("Topic needs to be a string!") @@ -881,18 +894,18 @@ class MQTT: if self.connected: self._async_queue_subscriptions(((topic, qos),)) - @callback - def async_remove() -> None: - """Remove subscription.""" - self._async_untrack_subscription(subscription) - self._matching_subscriptions.cache_clear() - if subscription in self._retained_topics: - del self._retained_topics[subscription] - # Only unsubscribe if currently connected - if self.connected: - self._async_unsubscribe(topic) + return partial(self._async_remove, subscription) - return async_remove + @callback + def _async_remove(self, subscription: Subscription) -> None: + """Remove subscription.""" + self._async_untrack_subscription(subscription) + self._matching_subscriptions.cache_clear() + if subscription in self._retained_topics: + del self._retained_topics[subscription] + # Only unsubscribe if currently connected + if self.connected: + self._async_unsubscribe(subscription.topic) @callback def _async_unsubscribe(self, topic: str) -> None: diff --git a/homeassistant/components/mqtt/climate.py b/homeassistant/components/mqtt/climate.py index b09ee17af68..57f71008ecc 100644 --- a/homeassistant/components/mqtt/climate.py +++ b/homeassistant/components/mqtt/climate.py @@ -511,7 +511,7 @@ class MqttTemperatureControlEntity(MqttEntity, ABC): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def _publish(self, topic: str, payload: PublishPayloadType) -> None: if self._topic[topic] is not None: diff --git a/homeassistant/components/mqtt/cover.py b/homeassistant/components/mqtt/cover.py index d741f602670..a4c7c1d8b3b 100644 --- a/homeassistant/components/mqtt/cover.py +++ b/homeassistant/components/mqtt/cover.py @@ -512,7 +512,7 @@ class MqttCover(MqttEntity, CoverEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_open_cover(self, **kwargs: Any) -> None: """Move the cover up. diff --git a/homeassistant/components/mqtt/device_tracker.py b/homeassistant/components/mqtt/device_tracker.py index 9af85d5ab9f..87abba2ac95 100644 --- a/homeassistant/components/mqtt/device_tracker.py +++ b/homeassistant/components/mqtt/device_tracker.py @@ -166,7 +166,7 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) @property def latitude(self) -> float | None: diff --git a/homeassistant/components/mqtt/event.py b/homeassistant/components/mqtt/event.py index 0fa82c7e12b..a09579fccef 100644 --- a/homeassistant/components/mqtt/event.py +++ b/homeassistant/components/mqtt/event.py @@ -208,4 +208,4 @@ class MqttEvent(MqttEntity, EventEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) diff --git a/homeassistant/components/mqtt/fan.py b/homeassistant/components/mqtt/fan.py index 1ee7bc63796..a418131d5c5 100644 --- a/homeassistant/components/mqtt/fan.py +++ b/homeassistant/components/mqtt/fan.py @@ -477,7 +477,7 @@ class MqttFan(MqttEntity, FanEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) @property def is_on(self) -> bool | None: diff --git a/homeassistant/components/mqtt/humidifier.py b/homeassistant/components/mqtt/humidifier.py index 7956a05d20a..097018f008f 100644 --- a/homeassistant/components/mqtt/humidifier.py +++ b/homeassistant/components/mqtt/humidifier.py @@ -447,7 +447,7 @@ class MqttHumidifier(MqttEntity, HumidifierEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_turn_on(self, **kwargs: Any) -> None: """Turn on the entity. diff --git a/homeassistant/components/mqtt/image.py b/homeassistant/components/mqtt/image.py index 3b7834a9876..4fa410c4595 100644 --- a/homeassistant/components/mqtt/image.py +++ b/homeassistant/components/mqtt/image.py @@ -214,7 +214,7 @@ class MqttImage(MqttEntity, ImageEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_image(self) -> bytes | None: """Return bytes of image.""" diff --git a/homeassistant/components/mqtt/lawn_mower.py b/homeassistant/components/mqtt/lawn_mower.py index 3ce04ca29d5..2452b511144 100644 --- a/homeassistant/components/mqtt/lawn_mower.py +++ b/homeassistant/components/mqtt/lawn_mower.py @@ -198,7 +198,7 @@ class MqttLawnMower(MqttEntity, LawnMowerEntity, RestoreEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) if self._attr_assumed_state and ( last_state := await self.async_get_last_state() diff --git a/homeassistant/components/mqtt/light/schema_basic.py b/homeassistant/components/mqtt/light/schema_basic.py index 650ca1eff6a..583374c8d20 100644 --- a/homeassistant/components/mqtt/light/schema_basic.py +++ b/homeassistant/components/mqtt/light/schema_basic.py @@ -627,7 +627,7 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) last_state = await self.async_get_last_state() def restore_state( diff --git a/homeassistant/components/mqtt/light/schema_json.py b/homeassistant/components/mqtt/light/schema_json.py index 14e477d0c35..f6dec17f8f3 100644 --- a/homeassistant/components/mqtt/light/schema_json.py +++ b/homeassistant/components/mqtt/light/schema_json.py @@ -528,7 +528,7 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) last_state = await self.async_get_last_state() if self._optimistic and last_state: diff --git a/homeassistant/components/mqtt/light/schema_template.py b/homeassistant/components/mqtt/light/schema_template.py index 647bf6df401..193b4d23931 100644 --- a/homeassistant/components/mqtt/light/schema_template.py +++ b/homeassistant/components/mqtt/light/schema_template.py @@ -288,7 +288,7 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) last_state = await self.async_get_last_state() if self._optimistic and last_state: diff --git a/homeassistant/components/mqtt/lock.py b/homeassistant/components/mqtt/lock.py index 33d25b168a8..52c2bea2cc3 100644 --- a/homeassistant/components/mqtt/lock.py +++ b/homeassistant/components/mqtt/lock.py @@ -243,7 +243,7 @@ class MqttLock(MqttEntity, LockEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_lock(self, **kwargs: Any) -> None: """Lock the device. diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index f1fb0de6f4e..0331b49c2a6 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -114,7 +114,7 @@ from .models import ( from .subscription import ( EntitySubscription, async_prepare_subscribe_topics, - async_subscribe_topics, + async_subscribe_topics_internal, async_unsubscribe_topics, ) from .util import mqtt_config_entry_enabled @@ -413,7 +413,7 @@ class MqttAttributesMixin(Entity): """Subscribe MQTT events.""" await super().async_added_to_hass() self._attributes_prepare_subscribe_topics() - await self._attributes_subscribe_topics() + self._attributes_subscribe_topics() def attributes_prepare_discovery_update(self, config: DiscoveryInfoType) -> None: """Handle updated discovery message.""" @@ -422,7 +422,7 @@ class MqttAttributesMixin(Entity): async def attributes_discovery_update(self, config: DiscoveryInfoType) -> None: """Handle updated discovery message.""" - await self._attributes_subscribe_topics() + self._attributes_subscribe_topics() def _attributes_prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" @@ -447,9 +447,10 @@ class MqttAttributesMixin(Entity): }, ) - async def _attributes_subscribe_topics(self) -> None: + @callback + def _attributes_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await async_subscribe_topics(self.hass, self._attributes_sub_state) + async_subscribe_topics_internal(self.hass, self._attributes_sub_state) async def async_will_remove_from_hass(self) -> None: """Unsubscribe when removed.""" @@ -494,7 +495,7 @@ class MqttAvailabilityMixin(Entity): """Subscribe MQTT events.""" await super().async_added_to_hass() self._availability_prepare_subscribe_topics() - await self._availability_subscribe_topics() + self._availability_subscribe_topics() self.async_on_remove( async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect) ) @@ -511,7 +512,7 @@ class MqttAvailabilityMixin(Entity): async def availability_discovery_update(self, config: DiscoveryInfoType) -> None: """Handle updated discovery message.""" - await self._availability_subscribe_topics() + self._availability_subscribe_topics() def _availability_setup_from_config(self, config: ConfigType) -> None: """(Re)Setup.""" @@ -579,9 +580,10 @@ class MqttAvailabilityMixin(Entity): self._available[topic] = False self._available_latest = False - async def _availability_subscribe_topics(self) -> None: + @callback + def _availability_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await async_subscribe_topics(self.hass, self._availability_sub_state) + async_subscribe_topics_internal(self.hass, self._availability_sub_state) @callback def async_mqtt_connect(self) -> None: diff --git a/homeassistant/components/mqtt/number.py b/homeassistant/components/mqtt/number.py index f381087bd37..17e7cfe69e0 100644 --- a/homeassistant/components/mqtt/number.py +++ b/homeassistant/components/mqtt/number.py @@ -220,7 +220,7 @@ class MqttNumber(MqttEntity, RestoreNumber): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) if self._attr_assumed_state and ( last_number_data := await self.async_get_last_number_data() diff --git a/homeassistant/components/mqtt/select.py b/homeassistant/components/mqtt/select.py index f37a2b1e231..a2814055a7c 100644 --- a/homeassistant/components/mqtt/select.py +++ b/homeassistant/components/mqtt/select.py @@ -160,7 +160,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) if self._attr_assumed_state and ( last_state := await self.async_get_last_state() diff --git a/homeassistant/components/mqtt/sensor.py b/homeassistant/components/mqtt/sensor.py index d37da597ffb..c8fe932ed71 100644 --- a/homeassistant/components/mqtt/sensor.py +++ b/homeassistant/components/mqtt/sensor.py @@ -305,7 +305,7 @@ class MqttSensor(MqttEntity, RestoreSensor): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) @callback def _value_is_expired(self, *_: datetime) -> None: diff --git a/homeassistant/components/mqtt/siren.py b/homeassistant/components/mqtt/siren.py index 5920efbc3c1..06cb2677c09 100644 --- a/homeassistant/components/mqtt/siren.py +++ b/homeassistant/components/mqtt/siren.py @@ -288,7 +288,7 @@ class MqttSiren(MqttEntity, SirenEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) @property def extra_state_attributes(self) -> dict[str, Any] | None: diff --git a/homeassistant/components/mqtt/subscription.py b/homeassistant/components/mqtt/subscription.py index d0dc98484b3..9e3ea21222f 100644 --- a/homeassistant/components/mqtt/subscription.py +++ b/homeassistant/components/mqtt/subscription.py @@ -2,14 +2,15 @@ from __future__ import annotations -from collections.abc import Callable, Coroutine +from collections.abc import Callable from dataclasses import dataclass +from functools import partial from typing import TYPE_CHECKING, Any -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback -from .. import mqtt from . import debug_info +from .client import async_subscribe_internal from .const import DEFAULT_QOS from .models import MessageCallbackType @@ -21,7 +22,7 @@ class EntitySubscription: hass: HomeAssistant topic: str | None message_callback: MessageCallbackType - subscribe_task: Coroutine[Any, Any, Callable[[], None]] | None + should_subscribe: bool | None unsubscribe_callback: Callable[[], None] | None qos: int = 0 encoding: str = "utf-8" @@ -53,15 +54,16 @@ class EntitySubscription: self.hass, self.message_callback, self.topic, self.entity_id ) - self.subscribe_task = mqtt.async_subscribe( - hass, self.topic, self.message_callback, self.qos, self.encoding - ) + self.should_subscribe = True - async def subscribe(self) -> None: + @callback + def subscribe(self) -> None: """Subscribe to a topic.""" - if not self.subscribe_task: + if not self.should_subscribe or not self.topic: return - self.unsubscribe_callback = await self.subscribe_task + self.unsubscribe_callback = async_subscribe_internal( + self.hass, self.topic, self.message_callback, self.qos, self.encoding + ) def _should_resubscribe(self, other: EntitySubscription | None) -> bool: """Check if we should re-subscribe to the topic using the old state.""" @@ -79,6 +81,7 @@ class EntitySubscription: ) +@callback def async_prepare_subscribe_topics( hass: HomeAssistant, new_state: dict[str, EntitySubscription] | None, @@ -107,7 +110,7 @@ def async_prepare_subscribe_topics( qos=value.get("qos", DEFAULT_QOS), encoding=value.get("encoding", "utf-8"), hass=hass, - subscribe_task=None, + should_subscribe=None, entity_id=value.get("entity_id", None), ) # Get the current subscription state @@ -135,12 +138,29 @@ async def async_subscribe_topics( sub_state: dict[str, EntitySubscription], ) -> None: """(Re)Subscribe to a set of MQTT topics.""" + async_subscribe_topics_internal(hass, sub_state) + + +@callback +def async_subscribe_topics_internal( + hass: HomeAssistant, + sub_state: dict[str, EntitySubscription], +) -> None: + """(Re)Subscribe to a set of MQTT topics. + + This function is internal to the MQTT integration and should not be called + from outside the integration. + """ for sub in sub_state.values(): - await sub.subscribe() + sub.subscribe() -def async_unsubscribe_topics( - hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None -) -> dict[str, EntitySubscription]: - """Unsubscribe from all MQTT topics managed by async_subscribe_topics.""" - return async_prepare_subscribe_topics(hass, sub_state, {}) +if TYPE_CHECKING: + + def async_unsubscribe_topics( + hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None + ) -> dict[str, EntitySubscription]: + """Unsubscribe from all MQTT topics managed by async_subscribe_topics.""" + + +async_unsubscribe_topics = partial(async_prepare_subscribe_topics, topics={}) diff --git a/homeassistant/components/mqtt/switch.py b/homeassistant/components/mqtt/switch.py index 8289b11adca..9f266a0e9ab 100644 --- a/homeassistant/components/mqtt/switch.py +++ b/homeassistant/components/mqtt/switch.py @@ -151,7 +151,7 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) if self._optimistic and (last_state := await self.async_get_last_state()): self._attr_is_on = last_state.state == STATE_ON diff --git a/homeassistant/components/mqtt/tag.py b/homeassistant/components/mqtt/tag.py index 4ecf0862827..55f7e775ae9 100644 --- a/homeassistant/components/mqtt/tag.py +++ b/homeassistant/components/mqtt/tag.py @@ -167,7 +167,7 @@ class MQTTTagScanner(MqttDiscoveryDeviceUpdateMixin): } }, ) - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_tear_down(self) -> None: """Cleanup tag scanner.""" diff --git a/homeassistant/components/mqtt/text.py b/homeassistant/components/mqtt/text.py index c563195e6e0..abced8b8744 100644 --- a/homeassistant/components/mqtt/text.py +++ b/homeassistant/components/mqtt/text.py @@ -198,7 +198,7 @@ class MqttTextEntity(MqttEntity, TextEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_set_value(self, value: str) -> None: """Change the text.""" diff --git a/homeassistant/components/mqtt/update.py b/homeassistant/components/mqtt/update.py index 9b6ee901eaf..ee29601e585 100644 --- a/homeassistant/components/mqtt/update.py +++ b/homeassistant/components/mqtt/update.py @@ -257,7 +257,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_install( self, version: str | None, backup: bool, **kwargs: Any diff --git a/homeassistant/components/mqtt/vacuum.py b/homeassistant/components/mqtt/vacuum.py index b41242b4855..5c8c2fd2ba5 100644 --- a/homeassistant/components/mqtt/vacuum.py +++ b/homeassistant/components/mqtt/vacuum.py @@ -353,7 +353,7 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def _async_publish_command(self, feature: VacuumEntityFeature) -> None: """Publish a command.""" diff --git a/homeassistant/components/mqtt/valve.py b/homeassistant/components/mqtt/valve.py index 89a60eef852..2536d9beb40 100644 --- a/homeassistant/components/mqtt/valve.py +++ b/homeassistant/components/mqtt/valve.py @@ -371,7 +371,7 @@ class MqttValve(MqttEntity, ValveEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_open_valve(self) -> None: """Move the valve up. diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 57056819784..9421cddc6a2 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -1051,6 +1051,27 @@ async def test_subscribe_topic_not_initialize( await mqtt.async_subscribe(hass, "test-topic", record_calls) +async def test_subscribe_mqtt_config_entry_disabled( + hass: HomeAssistant, mqtt_mock: MqttMockHAClient +) -> None: + """Test the subscription of a topic when MQTT config entry is disabled.""" + mqtt_mock.connected = True + + mqtt_config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0] + assert mqtt_config_entry.state is ConfigEntryState.LOADED + + assert await hass.config_entries.async_unload(mqtt_config_entry.entry_id) + assert mqtt_config_entry.state is ConfigEntryState.NOT_LOADED + + await hass.config_entries.async_set_disabled_by( + mqtt_config_entry.entry_id, ConfigEntryDisabler.USER + ) + mqtt_mock.connected = False + + with pytest.raises(HomeAssistantError, match=r".*MQTT is not enabled"): + await mqtt.async_subscribe(hass, "test-topic", record_calls) + + @patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.2) async def test_subscribe_and_resubscribe( @@ -3824,7 +3845,7 @@ async def test_unload_config_entry( async def test_publish_or_subscribe_without_valid_config_entry( hass: HomeAssistant, record_calls: MessageCallbackType ) -> None: - """Test internal publish function with bas use cases.""" + """Test internal publish function with bad use cases.""" with pytest.raises(HomeAssistantError): await mqtt.async_publish( hass, "some-topic", "test-payload", qos=0, retain=False, encoding=None