"""Support for MQTT message handling.""" from __future__ import annotations import asyncio from collections.abc import Callable from datetime import datetime import logging from typing import Any, cast import jinja2 import voluptuous as vol from homeassistant import config as conf_util from homeassistant.components import websocket_api from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( CONF_CLIENT_ID, CONF_DISCOVERY, CONF_PASSWORD, CONF_PAYLOAD, CONF_PORT, CONF_PROTOCOL, CONF_USERNAME, SERVICE_RELOAD, ) from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback from homeassistant.exceptions import TemplateError, Unauthorized from homeassistant.helpers import config_validation as cv, event, template from homeassistant.helpers.device_registry import DeviceEntry from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity_platform import async_get_platforms from homeassistant.helpers.reload import async_integration_yaml_config from homeassistant.helpers.service import async_register_admin_service from homeassistant.helpers.typing import ConfigType # Loading the config flow file will register the flow from . import debug_info, discovery from .client import ( # noqa: F401 MQTT, async_publish, async_subscribe, publish, subscribe, ) from .config_integration import PLATFORM_CONFIG_SCHEMA_BASE from .const import ( # noqa: F401 ATTR_PAYLOAD, ATTR_QOS, ATTR_RETAIN, ATTR_TOPIC, CONF_BIRTH_MESSAGE, CONF_BROKER, CONF_CERTIFICATE, CONF_CLIENT_CERT, CONF_CLIENT_KEY, CONF_COMMAND_TOPIC, CONF_DISCOVERY_PREFIX, CONF_KEEPALIVE, CONF_QOS, CONF_STATE_TOPIC, CONF_TLS_INSECURE, CONF_TOPIC, CONF_TRANSPORT, CONF_WILL_MESSAGE, CONF_WS_HEADERS, CONF_WS_PATH, DATA_MQTT, DATA_MQTT_AVAILABLE, DEFAULT_DISCOVERY, DEFAULT_ENCODING, DEFAULT_PREFIX, DEFAULT_QOS, DEFAULT_RETAIN, DOMAIN, MQTT_CONNECTED, MQTT_DISCONNECTED, PLATFORMS, RELOADABLE_PLATFORMS, ) from .models import ( # noqa: F401 MqttCommandTemplate, MqttData, MqttValueTemplate, PublishPayloadType, ReceiveMessage, ReceivePayloadType, ) from .util import ( # noqa: F401 async_create_certificate_temp_files, async_wait_for_mqtt_client, get_mqtt_data, mqtt_config_entry_enabled, valid_publish_topic, valid_qos_schema, valid_subscribe_topic, ) _LOGGER = logging.getLogger(__name__) SERVICE_PUBLISH = "publish" SERVICE_DUMP = "dump" ATTR_TOPIC_TEMPLATE = "topic_template" ATTR_PAYLOAD_TEMPLATE = "payload_template" MAX_RECONNECT_WAIT = 300 # seconds CONNECTION_SUCCESS = "connection_success" CONNECTION_FAILED = "connection_failed" CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable" CONFIG_ENTRY_CONFIG_KEYS = [ CONF_BIRTH_MESSAGE, CONF_BROKER, CONF_CERTIFICATE, CONF_CLIENT_ID, CONF_CLIENT_CERT, CONF_CLIENT_KEY, CONF_DISCOVERY, CONF_DISCOVERY_PREFIX, CONF_KEEPALIVE, CONF_PASSWORD, CONF_PORT, CONF_PROTOCOL, CONF_TLS_INSECURE, CONF_TRANSPORT, CONF_WS_PATH, CONF_WS_HEADERS, CONF_USERNAME, CONF_WILL_MESSAGE, ] CONFIG_SCHEMA = vol.Schema( { DOMAIN: vol.All( cv.removed(CONF_BIRTH_MESSAGE), # Removed in HA Core 2023.4 cv.removed(CONF_BROKER), # Removed in HA Core 2023.4 cv.removed(CONF_CERTIFICATE), # Removed in HA Core 2023.4 cv.removed(CONF_CLIENT_ID), # Removed in HA Core 2023.4 cv.removed(CONF_CLIENT_CERT), # Removed in HA Core 2023.4 cv.removed(CONF_CLIENT_KEY), # Removed in HA Core 2023.4 cv.removed(CONF_DISCOVERY), # Removed in HA Core 2022.3 cv.removed(CONF_DISCOVERY_PREFIX), # Removed in HA Core 2023.4 cv.removed(CONF_KEEPALIVE), # Removed in HA Core 2023.4 cv.removed(CONF_PASSWORD), # Removed in HA Core 2023.4 cv.removed(CONF_PORT), # Removed in HA Core 2023.4 cv.removed(CONF_PROTOCOL), # Removed in HA Core 2023.4 cv.removed(CONF_TLS_INSECURE), # Removed in HA Core 2023.4 cv.removed(CONF_USERNAME), # Removed in HA Core 2023.4 cv.removed(CONF_WILL_MESSAGE), # Removed in HA Core 2023.4 PLATFORM_CONFIG_SCHEMA_BASE, ) }, extra=vol.ALLOW_EXTRA, ) # Service call validation schema MQTT_PUBLISH_SCHEMA = vol.All( vol.Schema( { vol.Exclusive(ATTR_TOPIC, CONF_TOPIC): valid_publish_topic, vol.Exclusive(ATTR_TOPIC_TEMPLATE, CONF_TOPIC): cv.string, vol.Exclusive(ATTR_PAYLOAD, CONF_PAYLOAD): cv.string, vol.Exclusive(ATTR_PAYLOAD_TEMPLATE, CONF_PAYLOAD): cv.string, vol.Optional(ATTR_QOS, default=DEFAULT_QOS): valid_qos_schema, vol.Optional(ATTR_RETAIN, default=DEFAULT_RETAIN): cv.boolean, }, required=True, ), cv.has_at_least_one_key(ATTR_TOPIC, ATTR_TOPIC_TEMPLATE), ) async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -> None: """Handle signals of config entry being updated. Causes for this is config entry options changing. """ await hass.config_entries.async_reload(entry.entry_id) async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Load a config entry.""" conf: dict[str, Any] mqtt_data: MqttData async def _setup_client() -> tuple[MqttData, dict[str, Any]]: """Set up the MQTT client.""" # Fetch configuration conf = dict(entry.data) hass_config = await conf_util.async_hass_config_yaml(hass) mqtt_yaml = PLATFORM_CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {})) await async_create_certificate_temp_files(hass, conf) client = MQTT(hass, entry, conf) if DOMAIN in hass.data: mqtt_data = get_mqtt_data(hass) mqtt_data.config = mqtt_yaml mqtt_data.client = client else: # Initial setup websocket_api.async_register_command(hass, websocket_subscribe) websocket_api.async_register_command(hass, websocket_mqtt_info) hass.data[DATA_MQTT] = mqtt_data = MqttData(config=mqtt_yaml, client=client) client.start(mqtt_data) # Restore saved subscriptions if mqtt_data.subscriptions_to_restore: mqtt_data.client.async_restore_tracked_subscriptions( mqtt_data.subscriptions_to_restore ) mqtt_data.subscriptions_to_restore = [] mqtt_data.reload_dispatchers.append( entry.add_update_listener(_async_config_entry_updated) ) await mqtt_data.client.async_connect() return (mqtt_data, conf) client_available: asyncio.Future[bool] if DATA_MQTT_AVAILABLE not in hass.data: client_available = hass.data[DATA_MQTT_AVAILABLE] = asyncio.Future() else: client_available = hass.data[DATA_MQTT_AVAILABLE] setup_ok: bool = False try: mqtt_data, conf = await _setup_client() setup_ok = True finally: if not client_available.done(): client_available.set_result(setup_ok) async def async_publish_service(call: ServiceCall) -> None: """Handle MQTT publish service calls.""" msg_topic: str | None = call.data.get(ATTR_TOPIC) msg_topic_template: str | None = call.data.get(ATTR_TOPIC_TEMPLATE) payload: PublishPayloadType = call.data.get(ATTR_PAYLOAD) payload_template: str | None = call.data.get(ATTR_PAYLOAD_TEMPLATE) qos: int = call.data[ATTR_QOS] retain: bool = call.data[ATTR_RETAIN] if msg_topic_template is not None: try: rendered_topic: Any = template.Template( msg_topic_template, hass ).async_render(parse_result=False) msg_topic = valid_publish_topic(rendered_topic) except (jinja2.TemplateError, TemplateError) as exc: _LOGGER.error( ( "Unable to publish: rendering topic template of %s " "failed because %s" ), msg_topic_template, exc, ) return except vol.Invalid as err: _LOGGER.error( ( "Unable to publish: topic template '%s' produced an " "invalid topic '%s' after rendering (%s)" ), msg_topic_template, rendered_topic, err, ) return if payload_template is not None: try: payload = MqttCommandTemplate( template.Template(payload_template), hass=hass ).async_render() except (jinja2.TemplateError, TemplateError) as exc: _LOGGER.error( ( "Unable to publish to %s: rendering payload template of " "%s failed because %s" ), msg_topic, payload_template, exc, ) return assert msg_topic is not None await mqtt_data.client.async_publish(msg_topic, payload, qos, retain) hass.services.async_register( DOMAIN, SERVICE_PUBLISH, async_publish_service, schema=MQTT_PUBLISH_SCHEMA ) async def async_dump_service(call: ServiceCall) -> None: """Handle MQTT dump service calls.""" messages: list[tuple[str, str]] = [] @callback def collect_msg(msg: ReceiveMessage) -> None: messages.append((msg.topic, str(msg.payload).replace("\n", ""))) unsub = await async_subscribe(hass, call.data["topic"], collect_msg) def write_dump() -> None: with open(hass.config.path("mqtt_dump.txt"), "w", encoding="utf8") as fp: for msg in messages: fp.write(",".join(msg) + "\n") async def finish_dump(_: datetime) -> None: """Write dump to file.""" unsub() await hass.async_add_executor_job(write_dump) event.async_call_later(hass, call.data["duration"], finish_dump) hass.services.async_register( DOMAIN, SERVICE_DUMP, async_dump_service, schema=vol.Schema( { vol.Required("topic"): valid_subscribe_topic, vol.Optional("duration", default=5): int, } ), ) # setup platforms and discovery async def async_setup_reload_service() -> None: """Create the reload service for the MQTT domain.""" if hass.services.has_service(DOMAIN, SERVICE_RELOAD): return async def _reload_config(call: ServiceCall) -> None: """Reload the platforms.""" # Fetch updated manual configured items and validate config_yaml = await async_integration_yaml_config(hass, DOMAIN) or {} mqtt_data.updated_config = config_yaml.get(DOMAIN, {}) # Reload the modern yaml platforms mqtt_platforms = async_get_platforms(hass, DOMAIN) tasks = [ entity.async_remove() for mqtt_platform in mqtt_platforms for entity in mqtt_platform.entities.values() # pylint: disable-next=protected-access if not entity._discovery_data # type: ignore[attr-defined] if mqtt_platform.config_entry and mqtt_platform.domain in RELOADABLE_PLATFORMS ] await asyncio.gather(*tasks) await asyncio.gather( *( [ mqtt_data.reload_handlers[component]() for component in RELOADABLE_PLATFORMS if component in mqtt_data.reload_handlers ] ) ) # Fire event hass.bus.async_fire(f"event_{DOMAIN}_reloaded", context=call.context) async_register_admin_service(hass, DOMAIN, SERVICE_RELOAD, _reload_config) async def async_forward_entry_setup_and_setup_discovery( config_entry: ConfigEntry, conf: ConfigType, ) -> None: """Forward the config entry setup to the platforms and set up discovery.""" # Local import to avoid circular dependencies # pylint: disable-next=import-outside-toplevel from . import device_automation, tag # Forward the entry setup to the MQTT platforms await asyncio.gather( *( [ device_automation.async_setup_entry(hass, config_entry), tag.async_setup_entry(hass, config_entry), ] + [ hass.config_entries.async_forward_entry_setup(entry, component) for component in PLATFORMS ] ) ) # Setup discovery if conf.get(CONF_DISCOVERY, DEFAULT_DISCOVERY): await discovery.async_start( hass, conf.get(CONF_DISCOVERY_PREFIX, DEFAULT_PREFIX), entry ) # Setup reload service after all platforms have loaded await async_setup_reload_service() await async_forward_entry_setup_and_setup_discovery(entry, conf) return True @websocket_api.websocket_command( {vol.Required("type"): "mqtt/device/debug_info", vol.Required("device_id"): str} ) @callback def websocket_mqtt_info( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] ) -> None: """Get MQTT debug info for device.""" device_id = msg["device_id"] mqtt_info = debug_info.info_for_device(hass, device_id) connection.send_result(msg["id"], mqtt_info) @websocket_api.websocket_command( { vol.Required("type"): "mqtt/subscribe", vol.Required("topic"): valid_subscribe_topic, vol.Optional("qos"): valid_qos_schema, } ) @websocket_api.async_response async def websocket_subscribe( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] ) -> None: """Subscribe to a MQTT topic.""" if not connection.user.is_admin: raise Unauthorized @callback def forward_messages(mqttmsg: ReceiveMessage) -> None: """Forward events to websocket.""" try: payload = cast(bytes, mqttmsg.payload).decode( DEFAULT_ENCODING ) # not str because encoding is set to None except (AttributeError, UnicodeDecodeError): # Convert non UTF-8 payload to a string presentation payload = str(mqttmsg.payload) connection.send_message( websocket_api.event_message( msg["id"], { "topic": mqttmsg.topic, "payload": payload, "qos": mqttmsg.qos, "retain": mqttmsg.retain, }, ) ) # Perform UTF-8 decoding directly in callback routine qos: int = msg["qos"] if "qos" in msg else DEFAULT_QOS connection.subscriptions[msg["id"]] = await async_subscribe( hass, msg["topic"], forward_messages, encoding=None, qos=qos ) connection.send_message(websocket_api.result_message(msg["id"])) ConnectionStatusCallback = Callable[[bool], None] @callback def async_subscribe_connection_status( hass: HomeAssistant, connection_status_callback: ConnectionStatusCallback ) -> Callable[[], None]: """Subscribe to MQTT connection changes.""" connection_status_callback_job = HassJob(connection_status_callback) async def connected() -> None: task = hass.async_run_hass_job(connection_status_callback_job, True) if task: await task async def disconnected() -> None: task = hass.async_run_hass_job(connection_status_callback_job, False) if task: await task subscriptions = { "connect": async_dispatcher_connect(hass, MQTT_CONNECTED, connected), "disconnect": async_dispatcher_connect(hass, MQTT_DISCONNECTED, disconnected), } @callback def unsubscribe() -> None: subscriptions["connect"]() subscriptions["disconnect"]() return unsubscribe def is_connected(hass: HomeAssistant) -> bool: """Return if MQTT client is connected.""" mqtt_data = get_mqtt_data(hass) return mqtt_data.client.connected async def async_remove_config_entry_device( hass: HomeAssistant, config_entry: ConfigEntry, device_entry: DeviceEntry ) -> bool: """Remove MQTT config entry from a device.""" # pylint: disable-next=import-outside-toplevel from . import device_automation await device_automation.async_removed_from_device(hass, device_entry.id) return True async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload MQTT dump and publish service when the config entry is unloaded.""" mqtt_data = get_mqtt_data(hass) mqtt_client = mqtt_data.client # Unload publish and dump services. hass.services.async_remove( DOMAIN, SERVICE_PUBLISH, ) hass.services.async_remove( DOMAIN, SERVICE_DUMP, ) # Stop the discovery await discovery.async_stop(hass) # Unload the platforms await asyncio.gather( *( hass.config_entries.async_forward_entry_unload(entry, component) for component in PLATFORMS ) ) await asyncio.sleep(0) # Unsubscribe reload dispatchers while reload_dispatchers := mqtt_data.reload_dispatchers: reload_dispatchers.pop()() # Cleanup listeners mqtt_client.cleanup() # Cleanup entity registry hooks registry_hooks = mqtt_data.discovery_registry_hooks while registry_hooks: registry_hooks.popitem()[1]() # Wait for all ACKs and stop the loop await mqtt_client.async_disconnect() # Cleanup MQTT client availability hass.data.pop(DATA_MQTT_AVAILABLE, None) # Store remaining subscriptions to be able to restore or reload them # when the entry is set up again if subscriptions := mqtt_client.subscriptions: mqtt_data.subscriptions_to_restore = subscriptions return True