Make mqtt internal subscription a normal function (#118092)

Co-authored-by: Jan Bouwhuis <jbouwh@users.noreply.github.com>
pull/118143/head
J. Nick Koston 2024-05-25 11:34:24 -10:00 committed by GitHub
parent ecd48cc447
commit 9be829ba1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 140 additions and 83 deletions

View File

@ -39,6 +39,7 @@ from .client import ( # noqa: F401
MQTT,
async_publish,
async_subscribe,
async_subscribe_internal,
publish,
subscribe,
)
@ -311,7 +312,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
def collect_msg(msg: ReceiveMessage) -> None:
messages.append((msg.topic, str(msg.payload).replace("\n", "")))
unsub = await async_subscribe(hass, call.data["topic"], collect_msg)
unsub = async_subscribe_internal(hass, call.data["topic"], collect_msg)
def write_dump() -> None:
with open(hass.config.path("mqtt_dump.txt"), "w", encoding="utf8") as fp:
@ -459,7 +460,7 @@ async def websocket_subscribe(
# Perform UTF-8 decoding directly in callback routine
qos: int = msg.get("qos", DEFAULT_QOS)
connection.subscriptions[msg["id"]] = await async_subscribe(
connection.subscriptions[msg["id"]] = async_subscribe_internal(
hass, msg["topic"], forward_messages, encoding=None, qos=qos
)

View File

@ -226,7 +226,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_alarm_disarm(self, code: str | None = None) -> None:
"""Send disarm command.

View File

@ -254,7 +254,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@callback
def _value_is_expired(self, *_: Any) -> None:

View File

@ -130,7 +130,7 @@ class MqttCamera(MqttEntity, Camera):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_camera_image(
self, width: int | None = None, height: int | None = None

View File

@ -191,13 +191,25 @@ async def async_subscribe(
Call the return value to unsubscribe.
"""
if not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled",
translation_key="mqtt_not_setup_cannot_subscribe",
translation_domain=DOMAIN,
translation_placeholders={"topic": topic},
)
return async_subscribe_internal(hass, topic, msg_callback, qos, encoding)
@callback
def async_subscribe_internal(
hass: HomeAssistant,
topic: str,
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int = DEFAULT_QOS,
encoding: str | None = DEFAULT_ENCODING,
) -> CALLBACK_TYPE:
"""Subscribe to an MQTT topic.
This function is internal to the MQTT integration
and may change at any time. It should not be considered
a stable API.
Call the return value to unsubscribe.
"""
try:
mqtt_data = hass.data[DATA_MQTT]
except KeyError as exc:
@ -208,12 +220,15 @@ async def async_subscribe(
translation_domain=DOMAIN,
translation_placeholders={"topic": topic},
) from exc
return await mqtt_data.client.async_subscribe(
topic,
msg_callback,
qos,
encoding,
)
client = mqtt_data.client
if not client.connected and not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled",
translation_key="mqtt_not_setup_cannot_subscribe",
translation_domain=DOMAIN,
translation_placeholders={"topic": topic},
)
return client.async_subscribe(topic, msg_callback, qos, encoding)
@bind_hass
@ -845,17 +860,15 @@ class MQTT:
f"'{msg.topic}': '{msg.payload}'" # type: ignore[str-bytes-safe]
)
async def async_subscribe(
@callback
def async_subscribe(
self,
topic: str,
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int,
encoding: str | None = None,
) -> Callable[[], None]:
"""Set up a subscription to a topic with the provided qos.
This method is a coroutine.
"""
"""Set up a subscription to a topic with the provided qos."""
if not isinstance(topic, str):
raise HomeAssistantError("Topic needs to be a string!")
@ -881,18 +894,18 @@ class MQTT:
if self.connected:
self._async_queue_subscriptions(((topic, qos),))
@callback
def async_remove() -> None:
"""Remove subscription."""
self._async_untrack_subscription(subscription)
self._matching_subscriptions.cache_clear()
if subscription in self._retained_topics:
del self._retained_topics[subscription]
# Only unsubscribe if currently connected
if self.connected:
self._async_unsubscribe(topic)
return partial(self._async_remove, subscription)
return async_remove
@callback
def _async_remove(self, subscription: Subscription) -> None:
"""Remove subscription."""
self._async_untrack_subscription(subscription)
self._matching_subscriptions.cache_clear()
if subscription in self._retained_topics:
del self._retained_topics[subscription]
# Only unsubscribe if currently connected
if self.connected:
self._async_unsubscribe(subscription.topic)
@callback
def _async_unsubscribe(self, topic: str) -> None:

View File

@ -511,7 +511,7 @@ class MqttTemperatureControlEntity(MqttEntity, ABC):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def _publish(self, topic: str, payload: PublishPayloadType) -> None:
if self._topic[topic] is not None:

View File

@ -512,7 +512,7 @@ class MqttCover(MqttEntity, CoverEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_open_cover(self, **kwargs: Any) -> None:
"""Move the cover up.

View File

@ -166,7 +166,7 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@property
def latitude(self) -> float | None:

View File

@ -208,4 +208,4 @@ class MqttEvent(MqttEntity, EventEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)

View File

@ -477,7 +477,7 @@ class MqttFan(MqttEntity, FanEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@property
def is_on(self) -> bool | None:

View File

@ -447,7 +447,7 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn on the entity.

View File

@ -214,7 +214,7 @@ class MqttImage(MqttEntity, ImageEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_image(self) -> bytes | None:
"""Return bytes of image."""

View File

@ -198,7 +198,7 @@ class MqttLawnMower(MqttEntity, LawnMowerEntity, RestoreEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
if self._attr_assumed_state and (
last_state := await self.async_get_last_state()

View File

@ -627,7 +627,7 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
last_state = await self.async_get_last_state()
def restore_state(

View File

@ -528,7 +528,7 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
last_state = await self.async_get_last_state()
if self._optimistic and last_state:

View File

@ -288,7 +288,7 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
last_state = await self.async_get_last_state()
if self._optimistic and last_state:

View File

@ -243,7 +243,7 @@ class MqttLock(MqttEntity, LockEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_lock(self, **kwargs: Any) -> None:
"""Lock the device.

View File

@ -114,7 +114,7 @@ from .models import (
from .subscription import (
EntitySubscription,
async_prepare_subscribe_topics,
async_subscribe_topics,
async_subscribe_topics_internal,
async_unsubscribe_topics,
)
from .util import mqtt_config_entry_enabled
@ -413,7 +413,7 @@ class MqttAttributesMixin(Entity):
"""Subscribe MQTT events."""
await super().async_added_to_hass()
self._attributes_prepare_subscribe_topics()
await self._attributes_subscribe_topics()
self._attributes_subscribe_topics()
def attributes_prepare_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message."""
@ -422,7 +422,7 @@ class MqttAttributesMixin(Entity):
async def attributes_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message."""
await self._attributes_subscribe_topics()
self._attributes_subscribe_topics()
def _attributes_prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@ -447,9 +447,10 @@ class MqttAttributesMixin(Entity):
},
)
async def _attributes_subscribe_topics(self) -> None:
@callback
def _attributes_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await async_subscribe_topics(self.hass, self._attributes_sub_state)
async_subscribe_topics_internal(self.hass, self._attributes_sub_state)
async def async_will_remove_from_hass(self) -> None:
"""Unsubscribe when removed."""
@ -494,7 +495,7 @@ class MqttAvailabilityMixin(Entity):
"""Subscribe MQTT events."""
await super().async_added_to_hass()
self._availability_prepare_subscribe_topics()
await self._availability_subscribe_topics()
self._availability_subscribe_topics()
self.async_on_remove(
async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect)
)
@ -511,7 +512,7 @@ class MqttAvailabilityMixin(Entity):
async def availability_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message."""
await self._availability_subscribe_topics()
self._availability_subscribe_topics()
def _availability_setup_from_config(self, config: ConfigType) -> None:
"""(Re)Setup."""
@ -579,9 +580,10 @@ class MqttAvailabilityMixin(Entity):
self._available[topic] = False
self._available_latest = False
async def _availability_subscribe_topics(self) -> None:
@callback
def _availability_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await async_subscribe_topics(self.hass, self._availability_sub_state)
async_subscribe_topics_internal(self.hass, self._availability_sub_state)
@callback
def async_mqtt_connect(self) -> None:

View File

@ -220,7 +220,7 @@ class MqttNumber(MqttEntity, RestoreNumber):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
if self._attr_assumed_state and (
last_number_data := await self.async_get_last_number_data()

View File

@ -160,7 +160,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
if self._attr_assumed_state and (
last_state := await self.async_get_last_state()

View File

@ -305,7 +305,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@callback
def _value_is_expired(self, *_: datetime) -> None:

View File

@ -288,7 +288,7 @@ class MqttSiren(MqttEntity, SirenEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@property
def extra_state_attributes(self) -> dict[str, Any] | None:

View File

@ -2,14 +2,15 @@
from __future__ import annotations
from collections.abc import Callable, Coroutine
from collections.abc import Callable
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from .. import mqtt
from . import debug_info
from .client import async_subscribe_internal
from .const import DEFAULT_QOS
from .models import MessageCallbackType
@ -21,7 +22,7 @@ class EntitySubscription:
hass: HomeAssistant
topic: str | None
message_callback: MessageCallbackType
subscribe_task: Coroutine[Any, Any, Callable[[], None]] | None
should_subscribe: bool | None
unsubscribe_callback: Callable[[], None] | None
qos: int = 0
encoding: str = "utf-8"
@ -53,15 +54,16 @@ class EntitySubscription:
self.hass, self.message_callback, self.topic, self.entity_id
)
self.subscribe_task = mqtt.async_subscribe(
hass, self.topic, self.message_callback, self.qos, self.encoding
)
self.should_subscribe = True
async def subscribe(self) -> None:
@callback
def subscribe(self) -> None:
"""Subscribe to a topic."""
if not self.subscribe_task:
if not self.should_subscribe or not self.topic:
return
self.unsubscribe_callback = await self.subscribe_task
self.unsubscribe_callback = async_subscribe_internal(
self.hass, self.topic, self.message_callback, self.qos, self.encoding
)
def _should_resubscribe(self, other: EntitySubscription | None) -> bool:
"""Check if we should re-subscribe to the topic using the old state."""
@ -79,6 +81,7 @@ class EntitySubscription:
)
@callback
def async_prepare_subscribe_topics(
hass: HomeAssistant,
new_state: dict[str, EntitySubscription] | None,
@ -107,7 +110,7 @@ def async_prepare_subscribe_topics(
qos=value.get("qos", DEFAULT_QOS),
encoding=value.get("encoding", "utf-8"),
hass=hass,
subscribe_task=None,
should_subscribe=None,
entity_id=value.get("entity_id", None),
)
# Get the current subscription state
@ -135,12 +138,29 @@ async def async_subscribe_topics(
sub_state: dict[str, EntitySubscription],
) -> None:
"""(Re)Subscribe to a set of MQTT topics."""
async_subscribe_topics_internal(hass, sub_state)
@callback
def async_subscribe_topics_internal(
hass: HomeAssistant,
sub_state: dict[str, EntitySubscription],
) -> None:
"""(Re)Subscribe to a set of MQTT topics.
This function is internal to the MQTT integration and should not be called
from outside the integration.
"""
for sub in sub_state.values():
await sub.subscribe()
sub.subscribe()
def async_unsubscribe_topics(
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
) -> dict[str, EntitySubscription]:
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
return async_prepare_subscribe_topics(hass, sub_state, {})
if TYPE_CHECKING:
def async_unsubscribe_topics(
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
) -> dict[str, EntitySubscription]:
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
async_unsubscribe_topics = partial(async_prepare_subscribe_topics, topics={})

View File

@ -151,7 +151,7 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
if self._optimistic and (last_state := await self.async_get_last_state()):
self._attr_is_on = last_state.state == STATE_ON

View File

@ -167,7 +167,7 @@ class MQTTTagScanner(MqttDiscoveryDeviceUpdateMixin):
}
},
)
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_tear_down(self) -> None:
"""Cleanup tag scanner."""

View File

@ -198,7 +198,7 @@ class MqttTextEntity(MqttEntity, TextEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_set_value(self, value: str) -> None:
"""Change the text."""

View File

@ -257,7 +257,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_install(
self, version: str | None, backup: bool, **kwargs: Any

View File

@ -353,7 +353,7 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def _async_publish_command(self, feature: VacuumEntityFeature) -> None:
"""Publish a command."""

View File

@ -371,7 +371,7 @@ class MqttValve(MqttEntity, ValveEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_open_valve(self) -> None:
"""Move the valve up.

View File

@ -1051,6 +1051,27 @@ async def test_subscribe_topic_not_initialize(
await mqtt.async_subscribe(hass, "test-topic", record_calls)
async def test_subscribe_mqtt_config_entry_disabled(
hass: HomeAssistant, mqtt_mock: MqttMockHAClient
) -> None:
"""Test the subscription of a topic when MQTT config entry is disabled."""
mqtt_mock.connected = True
mqtt_config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
assert mqtt_config_entry.state is ConfigEntryState.LOADED
assert await hass.config_entries.async_unload(mqtt_config_entry.entry_id)
assert mqtt_config_entry.state is ConfigEntryState.NOT_LOADED
await hass.config_entries.async_set_disabled_by(
mqtt_config_entry.entry_id, ConfigEntryDisabler.USER
)
mqtt_mock.connected = False
with pytest.raises(HomeAssistantError, match=r".*MQTT is not enabled"):
await mqtt.async_subscribe(hass, "test-topic", record_calls)
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.2)
async def test_subscribe_and_resubscribe(
@ -3824,7 +3845,7 @@ async def test_unload_config_entry(
async def test_publish_or_subscribe_without_valid_config_entry(
hass: HomeAssistant, record_calls: MessageCallbackType
) -> None:
"""Test internal publish function with bas use cases."""
"""Test internal publish function with bad use cases."""
with pytest.raises(HomeAssistantError):
await mqtt.async_publish(
hass, "some-topic", "test-payload", qos=0, retain=False, encoding=None