core/homeassistant/components/mqtt/__init__.py

1404 lines
46 KiB
Python

"""Support for MQTT message handling."""
from __future__ import annotations
from ast import literal_eval
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
import datetime as dt
from functools import lru_cache, partial, wraps
import inspect
from itertools import groupby
import logging
from operator import attrgetter
import ssl
import time
from typing import TYPE_CHECKING, Any, Union, cast
import uuid
import attr
import certifi
import jinja2
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.components import websocket_api
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
ATTR_ENTITY_ID,
ATTR_NAME,
CONF_CLIENT_ID,
CONF_DISCOVERY,
CONF_PASSWORD,
CONF_PAYLOAD,
CONF_PORT,
CONF_PROTOCOL,
CONF_USERNAME,
CONF_VALUE_TEMPLATE,
EVENT_HOMEASSISTANT_STARTED,
EVENT_HOMEASSISTANT_STOP,
SERVICE_RELOAD,
Platform,
)
from homeassistant.core import (
CoreState,
Event,
HassJob,
HomeAssistant,
ServiceCall,
callback,
)
from homeassistant.data_entry_flow import BaseServiceInfo
from homeassistant.exceptions import HomeAssistantError, TemplateError, Unauthorized
from homeassistant.helpers import config_validation as cv, event, template
from homeassistant.helpers.device_registry import DeviceEntry
from homeassistant.helpers.dispatcher import async_dispatcher_connect, dispatcher_send
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
from homeassistant.loader import bind_hass
from homeassistant.util import dt as dt_util
from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.logging import catch_log_exception
# Loading the config flow file will register the flow
from . import debug_info, discovery
from .const import (
ATTR_PAYLOAD,
ATTR_QOS,
ATTR_RETAIN,
ATTR_TOPIC,
CONF_BIRTH_MESSAGE,
CONF_BROKER,
CONF_CERTIFICATE,
CONF_CLIENT_CERT,
CONF_CLIENT_KEY,
CONF_COMMAND_TOPIC,
CONF_ENCODING,
CONF_QOS,
CONF_RETAIN,
CONF_STATE_TOPIC,
CONF_TLS_INSECURE,
CONF_TLS_VERSION,
CONF_TOPIC,
CONF_WILL_MESSAGE,
CONFIG_ENTRY_IS_SETUP,
DATA_CONFIG_ENTRY_LOCK,
DATA_MQTT_CONFIG,
DATA_MQTT_RELOAD_NEEDED,
DEFAULT_BIRTH,
DEFAULT_DISCOVERY,
DEFAULT_ENCODING,
DEFAULT_PREFIX,
DEFAULT_QOS,
DEFAULT_RETAIN,
DEFAULT_WILL,
DOMAIN,
MQTT_CONNECTED,
MQTT_DISCONNECTED,
PROTOCOL_31,
PROTOCOL_311,
)
from .discovery import LAST_DISCOVERY
from .models import (
AsyncMessageCallbackType,
MessageCallbackType,
PublishMessage,
PublishPayloadType,
ReceiveMessage,
ReceivePayloadType,
)
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
if TYPE_CHECKING:
# Only import for paho-mqtt type checking here, imports are done locally
# because integrations should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt
# Module-level logger for this integration.
_LOGGER = logging.getLogger(__name__)
# Unique marker object used as the "no default supplied" value of
# MqttValueTemplate.async_render_with_possible_json_value.
_SENTINEL = object()
# Key under which async_setup_entry stores the MQTT client in hass.data.
DATA_MQTT = "mqtt"
# Names of the services registered by async_setup_entry.
SERVICE_PUBLISH = "publish"
SERVICE_DUMP = "dump"
# Configuration keys defined locally in this module.
CONF_DISCOVERY_PREFIX = "discovery_prefix"
CONF_KEEPALIVE = "keepalive"
# Defaults for the broker connection options.
DEFAULT_PORT = 1883
DEFAULT_KEEPALIVE = 60
DEFAULT_PROTOCOL = PROTOCOL_311
DEFAULT_TLS_PROTOCOL = "auto"
# Defaults that _merge_extended_config layers underneath the YAML config.
DEFAULT_VALUES = {
CONF_BIRTH_MESSAGE: DEFAULT_BIRTH,
CONF_DISCOVERY: DEFAULT_DISCOVERY,
CONF_PORT: DEFAULT_PORT,
CONF_TLS_VERSION: DEFAULT_TLS_PROTOCOL,
CONF_WILL_MESSAGE: DEFAULT_WILL,
}
# Keys whose DEFAULT_VALUES entry _merge_basic_config writes into the config
# entry when the entry does not contain them.
MANDATORY_DEFAULT_VALUES = (CONF_PORT,)
# Publish service attributes carrying a template instead of a literal value.
ATTR_TOPIC_TEMPLATE = "topic_template"
ATTR_PAYLOAD_TEMPLATE = "payload_template"
MAX_RECONNECT_WAIT = 300 # seconds
# NOTE(review): the three CONNECTION_* values and MAX_RECONNECT_WAIT are not
# referenced in the visible part of this file; confirm their consumers.
CONNECTION_SUCCESS = "connection_success"
CONNECTION_FAILED = "connection_failed"
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"
# NOTE(review): presumably seconds; the consumer is not visible in this chunk.
DISCOVERY_COOLDOWN = 2
# Seconds _wait_for_mid waits for a broker acknowledgement before warning.
TIMEOUT_ACK = 10
# Entity platforms that async_setup_entry forwards the config entry to.
PLATFORMS = [
Platform.ALARM_CONTROL_PANEL,
Platform.BINARY_SENSOR,
Platform.BUTTON,
Platform.CAMERA,
Platform.CLIMATE,
Platform.COVER,
Platform.FAN,
Platform.HUMIDIFIER,
Platform.LIGHT,
Platform.LOCK,
Platform.NUMBER,
Platform.SELECT,
Platform.SCENE,
Platform.SENSOR,
Platform.SIREN,
Platform.SWITCH,
Platform.VACUUM,
]
# Validation message shared by the client_key / client_cert inclusive group
# of CONFIG_SCHEMA_BASE.
CLIENT_KEY_AUTH_MSG = (
"client_key and client_cert must both be present in "
"the MQTT broker configuration"
)
# Schema for the will and birth message options: topic and payload must be
# given together (same inclusive group); qos and retain fall back to defaults.
MQTT_WILL_BIRTH_SCHEMA = vol.Schema(
{
vol.Inclusive(ATTR_TOPIC, "topic_payload"): valid_publish_topic,
vol.Inclusive(ATTR_PAYLOAD, "topic_payload"): cv.string,
vol.Optional(ATTR_QOS, default=DEFAULT_QOS): _VALID_QOS_SCHEMA,
vol.Optional(ATTR_RETAIN, default=DEFAULT_RETAIN): cv.boolean,
},
required=True,
)
# Optional per-platform sections; each value is coerced to a list.
PLATFORM_CONFIG_SCHEMA_BASE = vol.Schema(
{
vol.Optional(Platform.ALARM_CONTROL_PANEL.value): cv.ensure_list,
vol.Optional(Platform.BINARY_SENSOR.value): cv.ensure_list,
vol.Optional(Platform.BUTTON.value): cv.ensure_list,
vol.Optional(Platform.CAMERA.value): cv.ensure_list,
vol.Optional(Platform.CLIMATE.value): cv.ensure_list,
vol.Optional(Platform.COVER.value): cv.ensure_list,
vol.Optional(Platform.FAN.value): cv.ensure_list,
vol.Optional(Platform.HUMIDIFIER.value): cv.ensure_list,
vol.Optional(Platform.LIGHT.value): cv.ensure_list,
vol.Optional(Platform.LOCK.value): cv.ensure_list,
vol.Optional(Platform.SCENE.value): cv.ensure_list,
vol.Optional(Platform.SELECT.value): cv.ensure_list,
vol.Optional(Platform.SIREN.value): cv.ensure_list,
vol.Optional(Platform.SENSOR.value): cv.ensure_list,
vol.Optional(Platform.SWITCH.value): cv.ensure_list,
vol.Optional(Platform.VACUUM.value): cv.ensure_list,
vol.Optional(Platform.NUMBER.value): cv.ensure_list,
}
)
# Platform sections plus the broker/connection options. Also called directly
# on config entry data (async_setup_entry, MQTT.async_config_entry_updated)
# to fill in the defaults declared here.
CONFIG_SCHEMA_BASE = PLATFORM_CONFIG_SCHEMA_BASE.extend(
{
vol.Optional(CONF_CLIENT_ID): cv.string,
vol.Optional(CONF_KEEPALIVE, default=DEFAULT_KEEPALIVE): vol.All(
vol.Coerce(int), vol.Range(min=15)
),
vol.Optional(CONF_BROKER): cv.string,
vol.Optional(CONF_PORT): cv.port,
vol.Optional(CONF_USERNAME): cv.string,
vol.Optional(CONF_PASSWORD): cv.string,
vol.Optional(CONF_CERTIFICATE): vol.Any("auto", cv.isfile),
# Client key and certificate share one inclusive group, so supplying only
# one of them fails validation with CLIENT_KEY_AUTH_MSG.
vol.Inclusive(
CONF_CLIENT_KEY, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): cv.isfile,
vol.Inclusive(
CONF_CLIENT_CERT, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): cv.isfile,
vol.Optional(CONF_TLS_INSECURE): cv.boolean,
vol.Optional(CONF_TLS_VERSION): vol.Any("auto", "1.0", "1.1", "1.2"),
vol.Optional(CONF_PROTOCOL, default=DEFAULT_PROTOCOL): vol.All(
cv.string, vol.In([PROTOCOL_31, PROTOCOL_311])
),
vol.Optional(CONF_WILL_MESSAGE): MQTT_WILL_BIRTH_SCHEMA,
vol.Optional(CONF_BIRTH_MESSAGE): MQTT_WILL_BIRTH_SCHEMA,
vol.Optional(CONF_DISCOVERY): cv.boolean,
# discovery_prefix must be a valid publish topic because if no
# state topic is specified, it will be created with the given prefix.
vol.Optional(
CONF_DISCOVERY_PREFIX, default=DEFAULT_PREFIX
): valid_publish_topic,
}
)
# YAML keys that _merge_basic_config copies into the config entry when the
# entry does not already contain them.
DEPRECATED_CONFIG_KEYS = [
CONF_BIRTH_MESSAGE,
CONF_BROKER,
CONF_DISCOVERY,
CONF_PASSWORD,
CONF_PORT,
CONF_TLS_VERSION,
CONF_USERNAME,
CONF_WILL_MESSAGE,
]
# Schema for the integration's configuration.yaml section: the deprecation
# markers are applied first, then the section is validated against
# CONFIG_SCHEMA_BASE. Keys of other domains are allowed (ALLOW_EXTRA).
CONFIG_SCHEMA = vol.Schema(
{
DOMAIN: vol.All(
cv.deprecated(CONF_BIRTH_MESSAGE), # Deprecated in HA Core 2022.3
cv.deprecated(CONF_BROKER), # Deprecated in HA Core 2022.3
cv.deprecated(CONF_DISCOVERY), # Deprecated in HA Core 2022.3
cv.deprecated(CONF_PASSWORD), # Deprecated in HA Core 2022.3
cv.deprecated(CONF_PORT), # Deprecated in HA Core 2022.3
cv.deprecated(CONF_TLS_VERSION), # Deprecated June 2020
cv.deprecated(CONF_USERNAME), # Deprecated in HA Core 2022.3
cv.deprecated(CONF_WILL_MESSAGE), # Deprecated in HA Core 2022.3
CONFIG_SCHEMA_BASE,
)
},
extra=vol.ALLOW_EXTRA,
)
# Options common to every MQTT entity schema below: qos and payload encoding.
SCHEMA_BASE = {
vol.Optional(CONF_QOS, default=DEFAULT_QOS): _VALID_QOS_SCHEMA,
vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string,
}
MQTT_BASE_SCHEMA = vol.Schema(SCHEMA_BASE)
# Will be removed when all platforms support a modern platform schema
MQTT_BASE_PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(SCHEMA_BASE)
# Will be removed when all platforms support a modern platform schema
# Read-only variant: a state topic is required, a value template is optional.
MQTT_RO_PLATFORM_SCHEMA = MQTT_BASE_PLATFORM_SCHEMA.extend(
{
vol.Required(CONF_STATE_TOPIC): valid_subscribe_topic,
vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
}
)
# Will be removed when all platforms support a modern platform schema
# Read-write variant: a command topic is required, a state topic is optional.
MQTT_RW_PLATFORM_SCHEMA = MQTT_BASE_PLATFORM_SCHEMA.extend(
{
vol.Required(CONF_COMMAND_TOPIC): valid_publish_topic,
vol.Optional(CONF_RETAIN, default=DEFAULT_RETAIN): cv.boolean,
vol.Optional(CONF_STATE_TOPIC): valid_subscribe_topic,
}
)
# Sensor type platforms subscribe to MQTT events
MQTT_RO_SCHEMA = MQTT_BASE_SCHEMA.extend(
{
vol.Required(CONF_STATE_TOPIC): valid_subscribe_topic,
vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
}
)
# Switch type platforms publish to MQTT and may subscribe
MQTT_RW_SCHEMA = MQTT_BASE_SCHEMA.extend(
{
vol.Required(CONF_COMMAND_TOPIC): valid_publish_topic,
vol.Optional(CONF_RETAIN, default=DEFAULT_RETAIN): cv.boolean,
vol.Optional(CONF_STATE_TOPIC): valid_subscribe_topic,
}
)
# Service call validation schema
# The topic and its template are mutually exclusive, as are the payload and
# its template; at least one of topic / topic_template must be present.
MQTT_PUBLISH_SCHEMA = vol.All(
vol.Schema(
{
vol.Exclusive(ATTR_TOPIC, CONF_TOPIC): valid_publish_topic,
vol.Exclusive(ATTR_TOPIC_TEMPLATE, CONF_TOPIC): cv.string,
vol.Exclusive(ATTR_PAYLOAD, CONF_PAYLOAD): cv.string,
vol.Exclusive(ATTR_PAYLOAD_TEMPLATE, CONF_PAYLOAD): cv.string,
vol.Optional(ATTR_QOS, default=DEFAULT_QOS): _VALID_QOS_SCHEMA,
vol.Optional(ATTR_RETAIN, default=DEFAULT_RETAIN): cv.boolean,
},
required=True,
),
cv.has_at_least_one_key(ATTR_TOPIC, ATTR_TOPIC_TEMPLATE),
)
# Payload type handed to subscribers by MQTT._mqtt_handle_message.
SubscribePayloadType = Union[str, bytes] # Only bytes if encoding is None
class MqttCommandTemplate:
    """Render outgoing MQTT payloads through an optional command template."""

    def __init__(
        self,
        command_template: template.Template | None,
        *,
        hass: HomeAssistant | None = None,
        entity: Entity | None = None,
    ) -> None:
        """Store the template and attach a Home Assistant instance to it.

        When no template is given nothing else is initialised and
        async_render passes values through unchanged. When an entity is
        given, its hass instance is used instead of the ``hass`` argument.
        """
        self._attr_command_template = command_template
        if command_template is None:
            return

        self._entity = entity
        command_template.hass = entity.hass if entity else hass

    @callback
    def async_render(
        self,
        value: PublishPayloadType = None,
        variables: TemplateVarsType = None,
    ) -> PublishPayloadType:
        """Return the payload to publish for ``value``.

        Without a template ``value`` is returned as is. Otherwise the
        template is rendered with ``value``, the entity id and name (when an
        entity is set) and any extra ``variables``. A rendered string that is
        the literal representation of a bytes object is returned as bytes.
        """
        if self._attr_command_template is None:
            return value

        render_vars = {"value": value}
        if self._entity:
            render_vars[ATTR_ENTITY_ID] = self._entity.entity_id
            render_vars[ATTR_NAME] = self._entity.name
        if variables is not None:
            render_vars.update(variables)

        rendered = self._attr_command_template.async_render(
            render_vars, parse_result=False
        )

        # Ensure a raw MQTT payload is passed on as bytes for publishing.
        if not isinstance(rendered, str):
            return rendered
        try:
            literal = literal_eval(rendered)
        except (ValueError, TypeError, SyntaxError, MemoryError):
            # Not a Python literal: publish the rendered text unchanged.
            return rendered
        return literal if isinstance(literal, bytes) else rendered
class MqttValueTemplate:
    """Class for rendering MQTT value template with possible json values."""

    def __init__(
        self,
        value_template: template.Template | None,
        *,
        hass: HomeAssistant | None = None,
        entity: Entity | None = None,
        config_attributes: TemplateVarsType = None,
    ) -> None:
        """Instantiate a value template.

        When ``value_template`` is None only the template and the config
        attributes are stored and rendering passes payloads through. When an
        entity is given, its hass instance replaces the ``hass`` argument on
        the template.
        """
        self._value_template = value_template
        self._config_attributes = config_attributes
        if value_template is None:
            return

        value_template.hass = hass
        self._entity = entity
        if entity:
            value_template.hass = entity.hass

    @callback
    def async_render_with_possible_json_value(
        self,
        payload: ReceivePayloadType,
        default: ReceivePayloadType | object = _SENTINEL,
        variables: TemplateVarsType = None,
    ) -> ReceivePayloadType:
        """Render with possible json value or pass-through a received MQTT value.

        Template variables are built from ``variables``, then the config
        attributes, then the entity id and name; later sources override
        earlier ones. ``default`` is only forwarded to the template when the
        caller supplied one.
        """
        if self._value_template is None:
            return payload

        values: dict[str, Any] = {}

        if variables is not None:
            values.update(variables)

        if self._config_attributes is not None:
            values.update(self._config_attributes)

        if self._entity:
            values[ATTR_ENTITY_ID] = self._entity.entity_id
            values[ATTR_NAME] = self._entity.name

        # Identity check: the sentinel is a unique object, and ``==`` would
        # invoke the __eq__ of whatever default the caller passed in.
        if default is _SENTINEL:
            return self._value_template.async_render_with_possible_json_value(
                payload, variables=values
            )

        return self._value_template.async_render_with_possible_json_value(
            payload, default, variables=values
        )
@dataclass
class MqttServiceInfo(BaseServiceInfo):
    """Service info record holding the details of one MQTT message."""

    # Field order defines the generated __init__ signature; keep it stable.
    topic: str
    payload: ReceivePayloadType
    qos: int
    retain: bool
    subscribed_topic: str
    timestamp: dt.datetime
def publish(
    hass: HomeAssistant,
    topic: str,
    payload: PublishPayloadType,
    qos: int | None = 0,
    retain: bool | None = False,
    encoding: str | None = DEFAULT_ENCODING,
) -> None:
    """Schedule a message for publishing to an MQTT topic.

    Hands async_publish together with all arguments to ``hass.add_job``;
    nothing is returned to the caller.
    """
    publish_args = (hass, topic, payload, qos, retain, encoding)
    hass.add_job(async_publish, *publish_args)
async def async_publish(
    hass: HomeAssistant,
    topic: str,
    payload: PublishPayloadType,
    qos: int | None = 0,
    retain: bool | None = False,
    encoding: str | None = DEFAULT_ENCODING,
) -> None:
    """Publish a message to an MQTT topic through the client in hass.data.

    Bytes payloads are forwarded untouched. Any other payload is converted
    with ``str()`` and, when ``encoding`` differs from the default, encoded
    to bytes. An error is logged and nothing is published when a non-bytes
    payload is given without an encoding, or when encoding fails.
    """
    outgoing_payload = payload
    needs_conversion = not isinstance(payload, bytes)

    if needs_conversion and not encoding:
        _LOGGER.error(
            "Can't pass-through payload for publishing %s on %s with no encoding set, need 'bytes' got %s",
            payload,
            topic,
            type(payload),
        )
        return

    if needs_conversion:
        outgoing_payload = str(payload)
        if encoding != DEFAULT_ENCODING:
            # A string is sent with the default encoding; any other encoding
            # requires the payload to be handed over as bytes.
            try:
                outgoing_payload = outgoing_payload.encode(encoding)
            except (AttributeError, LookupError, UnicodeEncodeError):
                _LOGGER.error(
                    "Can't encode payload for publishing %s on %s with encoding %s",
                    payload,
                    topic,
                    encoding,
                )
                return

    mqtt_client = hass.data[DATA_MQTT]
    await mqtt_client.async_publish(topic, outgoing_payload, qos, retain)
# Legacy callback signatures taking (topic, payload, qos) as separate
# positional arguments; wrap_msg_callback adapts them to callbacks that take a
# single ReceiveMessage.
AsyncDeprecatedMessageCallbackType = Callable[
[str, ReceivePayloadType, int], Awaitable[None]
]
DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None]
def wrap_msg_callback(
    msg_callback: AsyncDeprecatedMessageCallbackType | DeprecatedMessageCallbackType,
) -> AsyncMessageCallbackType | MessageCallbackType:
    """Adapt a callback using the deprecated (topic, payload, qos) signature.

    Returns a wrapper that accepts a single message object and forwards its
    topic, payload and qos to ``msg_callback``. The wrapper is a coroutine
    function when the wrapped callback is one.
    """
    # Peel off any functools.partial layers so that the coroutine check is
    # made against the underlying function.
    innermost = msg_callback
    while isinstance(innermost, partial):
        innermost = innermost.func

    if not asyncio.iscoroutinefunction(innermost):

        @wraps(msg_callback)
        def wrapper(msg: ReceiveMessage) -> None:
            """Call with deprecated signature."""
            msg_callback(msg.topic, msg.payload, msg.qos)

        return wrapper

    @wraps(msg_callback)
    async def async_wrapper(msg: ReceiveMessage) -> None:
        """Call with deprecated signature."""
        legacy_callback = cast(AsyncDeprecatedMessageCallbackType, msg_callback)
        await legacy_callback(msg.topic, msg.payload, msg.qos)

    return async_wrapper
@bind_hass
async def async_subscribe(
    hass: HomeAssistant,
    topic: str,
    msg_callback: AsyncMessageCallbackType
    | MessageCallbackType
    | DeprecatedMessageCallbackType
    | AsyncDeprecatedMessageCallbackType,
    qos: int = DEFAULT_QOS,
    encoding: str | None = "utf-8",
) -> Callable[[], None]:
    """Subscribe to an MQTT topic.

    Callbacks with three parameters that have no default value are treated
    as using the deprecated (topic, payload, qos) signature: a warning is
    logged and the callback is wrapped. Exceptions raised by the callback are
    logged through catch_log_exception.

    Call the return value to unsubscribe.
    """
    # functools.partial objects and other callable instances have no
    # __name__, so fall back to their repr for log messages.
    callback_name = getattr(msg_callback, "__name__", repr(msg_callback))

    # Count callback parameters which don't have a default value
    non_default = 0
    if msg_callback:
        non_default = sum(
            param.default is inspect.Parameter.empty
            for param in inspect.signature(msg_callback).parameters.values()
        )

    wrapped_msg_callback = msg_callback
    # If we have 3 parameters with no default value, wrap the callback
    if non_default == 3:
        module = inspect.getmodule(msg_callback)
        _LOGGER.warning(
            "Signature of MQTT msg_callback '%s.%s' is deprecated",
            module.__name__ if module else "<unknown>",
            callback_name,
        )
        wrapped_msg_callback = wrap_msg_callback(
            cast(DeprecatedMessageCallbackType, msg_callback)
        )

    async_remove = await hass.data[DATA_MQTT].async_subscribe(
        topic,
        catch_log_exception(
            wrapped_msg_callback,
            lambda msg: (
                f"Exception in {callback_name} when handling msg on "
                f"'{msg.topic}': '{msg.payload}'"
            ),
        ),
        qos,
        encoding,
    )
    return async_remove
@bind_hass
def subscribe(
    hass: HomeAssistant,
    topic: str,
    msg_callback: MessageCallbackType,
    qos: int = DEFAULT_QOS,
    encoding: str = "utf-8",
) -> Callable[[], None]:
    """Subscribe to an MQTT topic from outside the event loop.

    Runs async_subscribe on ``hass.loop`` and blocks until it has finished.
    The returned callable removes the subscription, again by dispatching to
    the loop and waiting for the result.
    """
    subscribe_future = asyncio.run_coroutine_threadsafe(
        async_subscribe(hass, topic, msg_callback, qos, encoding), hass.loop
    )
    async_remove = subscribe_future.result()

    def remove():
        """Remove the listener via the event loop."""
        removal = run_callback_threadsafe(hass.loop, async_remove)
        removal.result()

    return remove
async def _async_setup_discovery(
    hass: HomeAssistant, conf: ConfigType, config_entry
) -> None:
    """Start MQTT discovery using the configured discovery prefix.

    This method is a coroutine.
    """
    discovery_prefix = conf[CONF_DISCOVERY_PREFIX]
    await discovery.async_start(hass, discovery_prefix, config_entry)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Set up the MQTT integration from configuration.yaml.

    Registers the websocket commands and debug info, stores a copy of the
    YAML section (when present) in hass.data, and starts a config flow when
    no config entry exists for the domain yet. Always returns True.
    """
    for ws_command in (websocket_subscribe, websocket_mqtt_info):
        websocket_api.async_register_command(hass, ws_command)
    debug_info.initialize(hass)

    conf: ConfigType | None = config.get(DOMAIN)
    if conf:
        hass.data[DATA_MQTT_CONFIG] = dict(conf)

    existing_entries = hass.config_entries.async_entries(DOMAIN)
    if not bool(existing_entries):
        # Create an import flow if the user has yaml configured entities etc.
        # but no broker configuration. Note: The intention is not for this to
        # import broker configuration from YAML because that has been deprecated.
        flow_init = hass.config_entries.flow.async_init(
            DOMAIN,
            context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
            data={},
        )
        hass.async_create_task(flow_init)

    return True
def _merge_basic_config(
    hass: HomeAssistant, entry: ConfigEntry, yaml_config: dict[str, Any]
) -> None:
    """Complete the config entry with basic options from configuration.yaml.

    Copies every deprecated YAML key that the entry lacks into the entry
    data and then fills in the mandatory defaults that are still missing.
    The entry is only updated when at least one key was added.

    This mends incomplete migration from old version of HA Core.
    """
    merged = dict(entry.data)

    from_yaml = {
        key: yaml_config[key]
        for key in DEPRECATED_CONFIG_KEYS
        if key in yaml_config and key not in merged
    }
    merged.update(from_yaml)

    from_defaults = {
        key: DEFAULT_VALUES[key]
        for key in MANDATORY_DEFAULT_VALUES
        if key not in merged
    }
    merged.update(from_defaults)

    if from_yaml or from_defaults:
        hass.config_entries.async_update_entry(entry, data=merged)
def _merge_extended_config(entry, conf):
    """Combine defaults, configuration.yaml options and config entry data.

    Precedence, lowest to highest: DEFAULT_VALUES, ``conf``, ``entry.data``.
    A new dict is returned; the inputs are not modified.
    """
    merged = dict(DEFAULT_VALUES)
    merged.update(conf)
    merged.update(entry.data)
    return merged
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Set up MQTT from a config entry.

    Merges YAML and config entry options, creates and connects the MQTT
    client, registers the publish and dump services, forwards the entry to
    the entity platforms and starts discovery when enabled. Returns False
    when no broker is configured, True otherwise.
    """
    # Merge basic configuration, and add missing defaults for basic options
    _merge_basic_config(hass, entry, hass.data.get(DATA_MQTT_CONFIG, {}))

    # Bail out if broker setting is missing
    if CONF_BROKER not in entry.data:
        _LOGGER.error("MQTT broker is not configured, please configure it")
        return False

    conf = hass.data.get(DATA_MQTT_CONFIG)
    if conf is None:
        # No configuration.yaml config: generate default values for the
        # options that are not part of the config entry data.
        conf = CONFIG_SCHEMA_BASE(dict(entry.data))
    elif any(key in conf for key in entry.data):
        # configuration.yaml config exists: warn about the keys that the
        # config entry overrides, masking the password.
        shared_keys = conf.keys() & entry.data.keys()
        override = {
            key: "********" if key == CONF_PASSWORD else entry.data[key]
            for key in shared_keys
        }
        _LOGGER.warning(
            "Deprecated configuration settings found in configuration.yaml. "
            "These settings from your configuration entry will override: %s",
            override,
        )

    # Merge advanced configuration values from configuration.yaml
    conf = _merge_extended_config(entry, conf)

    mqtt_client = MQTT(hass, entry, conf)
    hass.data[DATA_MQTT] = mqtt_client
    await mqtt_client.async_connect()

    async def async_stop_mqtt(_event: Event):
        """Stop MQTT component."""
        await hass.data[DATA_MQTT].async_disconnect()

    hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)

    async def async_publish_service(call: ServiceCall) -> None:
        """Handle MQTT publish service calls."""
        service_data = call.data
        msg_topic = service_data.get(ATTR_TOPIC)
        msg_topic_template = service_data.get(ATTR_TOPIC_TEMPLATE)
        payload = service_data.get(ATTR_PAYLOAD)
        payload_template = service_data.get(ATTR_PAYLOAD_TEMPLATE)
        qos: int = service_data[ATTR_QOS]
        retain: bool = service_data[ATTR_RETAIN]

        if msg_topic_template is not None:
            topic_tpl = template.Template(msg_topic_template, hass)
            try:
                rendered_topic = topic_tpl.async_render(parse_result=False)
                msg_topic = valid_publish_topic(rendered_topic)
            except (jinja2.TemplateError, TemplateError) as exc:
                _LOGGER.error(
                    "Unable to publish: rendering topic template of %s "
                    "failed because %s",
                    msg_topic_template,
                    exc,
                )
                return
            except vol.Invalid as err:
                _LOGGER.error(
                    "Unable to publish: topic template '%s' produced an "
                    "invalid topic '%s' after rendering (%s)",
                    msg_topic_template,
                    rendered_topic,
                    err,
                )
                return

        if payload_template is not None:
            payload_renderer = MqttCommandTemplate(
                template.Template(payload_template), hass=hass
            )
            try:
                payload = payload_renderer.async_render()
            except (jinja2.TemplateError, TemplateError) as exc:
                _LOGGER.error(
                    "Unable to publish to %s: rendering payload template of "
                    "%s failed because %s",
                    msg_topic,
                    payload_template,
                    exc,
                )
                return

        await hass.data[DATA_MQTT].async_publish(msg_topic, payload, qos, retain)

    hass.services.async_register(
        DOMAIN, SERVICE_PUBLISH, async_publish_service, schema=MQTT_PUBLISH_SCHEMA
    )

    async def async_dump_service(call: ServiceCall) -> None:
        """Handle MQTT dump service calls."""
        messages = []

        @callback
        def collect_msg(msg):
            """Record topic and payload (newlines removed) of a message."""
            flattened_payload = msg.payload.replace("\n", "")
            messages.append((msg.topic, flattened_payload))

        unsub = await async_subscribe(hass, call.data["topic"], collect_msg)

        def write_dump():
            """Write one 'topic,payload' line per collected message."""
            dump_path = hass.config.path("mqtt_dump.txt")
            with open(dump_path, "wt", encoding="utf8") as fp:
                fp.writelines(",".join(entry) + "\n" for entry in messages)

        async def finish_dump(_):
            """Write dump to file."""
            unsub()
            await hass.async_add_executor_job(write_dump)

        event.async_call_later(hass, call.data["duration"], finish_dump)

    dump_schema = vol.Schema(
        {
            vol.Required("topic"): valid_subscribe_topic,
            vol.Optional("duration", default=5): int,
        }
    )
    hass.services.async_register(
        DOMAIN,
        SERVICE_DUMP,
        async_dump_service,
        schema=dump_schema,
    )

    # setup platforms and discovery
    hass.data[DATA_CONFIG_ENTRY_LOCK] = asyncio.Lock()
    hass.data[CONFIG_ENTRY_IS_SETUP] = set()

    async def async_forward_entry_setup():
        """Forward the config entry setup to the platforms."""
        async with hass.data[DATA_CONFIG_ENTRY_LOCK]:
            for component in PLATFORMS:
                config_entries_key = f"{component}.mqtt"
                if config_entries_key in hass.data[CONFIG_ENTRY_IS_SETUP]:
                    continue
                hass.data[CONFIG_ENTRY_IS_SETUP].add(config_entries_key)
                await hass.config_entries.async_forward_entry_setup(
                    entry, component
                )

    hass.async_create_task(async_forward_entry_setup())

    if conf.get(CONF_DISCOVERY):
        await _async_setup_discovery(hass, conf, entry)

    if DATA_MQTT_RELOAD_NEEDED in hass.data:
        hass.data.pop(DATA_MQTT_RELOAD_NEEDED)
        await hass.services.async_call(
            DOMAIN,
            SERVICE_RELOAD,
            {},
            blocking=False,
        )

    return True
@attr.s(slots=True, frozen=True)
class Subscription:
    """Immutable record describing one active topic subscription."""

    # Topic filter as passed to MQTT.async_subscribe.
    topic: str = attr.ib()
    # Callable applied to an incoming topic; a truthy result means a match.
    matcher: Any = attr.ib()
    # Job run for every matching message.
    job: HassJob = attr.ib()
    qos: int = attr.ib(default=0)
    # Encoding used to decode payloads; None delivers the raw payload.
    encoding: str | None = attr.ib(default="utf-8")
class MqttClientSetup:
    """Build a configured paho MQTT client from a config mapping."""

    def __init__(self, config: ConfigType) -> None:
        """Create the paho client and apply protocol, auth and TLS options."""
        # We don't import on the top because some integrations
        # should be able to optionally rely on MQTT.
        import paho.mqtt.client as mqtt  # pylint: disable=import-outside-toplevel

        proto = (
            mqtt.MQTTv31 if config[CONF_PROTOCOL] == PROTOCOL_31 else mqtt.MQTTv311
        )

        client_id = config.get(CONF_CLIENT_ID)
        if client_id is None:
            # PAHO MQTT relies on the MQTT server to generate random client IDs.
            # However, that feature is not mandatory so we generate our own.
            client_id = mqtt.base62(uuid.uuid4().int, padding=22)

        self._client = mqtt.Client(client_id, protocol=proto)

        # Enable logging
        self._client.enable_logger()

        username = config.get(CONF_USERNAME)
        if username is not None:
            self._client.username_pw_set(username, config.get(CONF_PASSWORD))

        certificate = config.get(CONF_CERTIFICATE)
        if certificate == "auto":
            certificate = certifi.where()
        if certificate is None:
            # No certificate configured: leave TLS untouched.
            return

        self._client.tls_set(
            certificate,
            certfile=config.get(CONF_CLIENT_CERT),
            keyfile=config.get(CONF_CLIENT_KEY),
            tls_version=ssl.PROTOCOL_TLS,
        )

        tls_insecure = config.get(CONF_TLS_INSECURE)
        if tls_insecure is not None:
            self._client.tls_insecure_set(tls_insecure)

    @property
    def client(self) -> mqtt.Client:
        """Return the configured paho MQTT client."""
        return self._client
class MQTT:
"""Home Assistant MQTT client."""
def __init__(
    self,
    hass: HomeAssistant,
    config_entry,
    conf,
) -> None:
    """Set up client state, create the paho client and watch the config entry.

    The internal "Home Assistant started" event is set right away when the
    core is already running; otherwise it is set by a one-time listener for
    the started event.
    """
    # We don't import on the top because some integrations
    # should be able to optionally rely on MQTT.
    import paho.mqtt.client as mqtt  # pylint: disable=import-outside-toplevel

    self.hass = hass
    self.config_entry = config_entry
    self.conf = conf
    self.subscriptions: list[Subscription] = []
    self.connected = False
    self._ha_started = asyncio.Event()
    self._last_subscribe = time.time()
    self._mqttc: mqtt.Client = None
    self._paho_lock = asyncio.Lock()
    self._pending_operations: dict[str, asyncio.Event] = {}

    if self.hass.state != CoreState.running:

        @callback
        def ha_started(_):
            self._ha_started.set()

        self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, ha_started)
    else:
        self._ha_started.set()

    self.init_client()
    self.config_entry.add_update_listener(self.async_config_entry_updated)
@staticmethod
async def async_config_entry_updated(
    hass: HomeAssistant, entry: ConfigEntry
) -> None:
    """React to an updated config entry, e.g. changed options.

    Rebuilds the merged configuration, recreates and reconnects the paho
    client and restarts discovery when it is enabled.

    This is a static method because a bound method cannot be used with
    weak references; the client is looked up in hass.data instead.
    """
    mqtt_client = hass.data[DATA_MQTT]

    yaml_conf = hass.data.get(DATA_MQTT_CONFIG)
    if yaml_conf is None:
        yaml_conf = CONFIG_SCHEMA_BASE(dict(entry.data))
    mqtt_client.conf = _merge_extended_config(entry, yaml_conf)

    await mqtt_client.async_disconnect()
    mqtt_client.init_client()
    await mqtt_client.async_connect()

    await discovery.async_stop(hass)
    if mqtt_client.conf.get(CONF_DISCOVERY):
        await _async_setup_discovery(hass, mqtt_client.conf, entry)
def init_client(self):
    """Create a fresh paho client, hook up callbacks and set the will."""
    client = MqttClientSetup(self.conf).client
    self._mqttc = client

    client.on_connect = self._mqtt_on_connect
    client.on_disconnect = self._mqtt_on_disconnect
    client.on_message = self._mqtt_on_message
    # Publish, subscribe and unsubscribe acknowledgements share one handler.
    client.on_publish = self._mqtt_on_callback
    client.on_subscribe = self._mqtt_on_callback
    client.on_unsubscribe = self._mqtt_on_callback

    # A will is only registered when one with a topic is configured.
    if (
        CONF_WILL_MESSAGE not in self.conf
        or ATTR_TOPIC not in self.conf[CONF_WILL_MESSAGE]
    ):
        return

    will_message = PublishMessage(**self.conf[CONF_WILL_MESSAGE])
    client.will_set(
        topic=will_message.topic,
        payload=will_message.payload,
        qos=will_message.qos,
        retain=will_message.retain,
    )
async def async_publish(
    self, topic: str, payload: PublishPayloadType, qos: int, retain: bool
) -> None:
    """Publish a message and wait for the broker's acknowledgement.

    The paho publish call runs in the executor while the paho lock is held;
    waiting for the acknowledgement happens after the lock is released.
    """
    async with self._paho_lock:
        publish_info = await self.hass.async_add_executor_job(
            self._mqttc.publish, topic, payload, qos, retain
        )
        mid = publish_info.mid
        _LOGGER.debug(
            "Transmitting message on %s: '%s', mid: %s",
            topic,
            payload,
            mid,
        )
        _raise_on_error(publish_info.rc)
    await self._wait_for_mid(mid)
async def async_connect(self) -> None:
    """Connect to the broker and start the paho network loop.

    Connection failures are logged, not raised; the network loop is started
    in either case. Does not process messages yet.
    """
    # pylint: disable-next=import-outside-toplevel
    import paho.mqtt.client as mqtt

    result: int | None = None
    try:
        result = await self.hass.async_add_executor_job(
            self._mqttc.connect,
            self.conf[CONF_BROKER],
            self.conf[CONF_PORT],
            self.conf[CONF_KEEPALIVE],
        )
    except OSError as err:
        _LOGGER.error("Failed to connect to MQTT server due to exception: %s", err)
    else:
        # A non-zero return code from paho signals a failed connect attempt.
        if result is not None and result != 0:
            _LOGGER.error(
                "Failed to connect to MQTT server: %s", mqtt.error_string(result)
            )

    self._mqttc.loop_start()
async def async_disconnect(self):
    """Stop the paho network loop in the executor."""

    def _stop_loop():
        """Stop the loop only, without sending a disconnect."""
        # Do not disconnect, we want the broker to always publish will
        self._mqttc.loop_stop()

    await self.hass.async_add_executor_job(_stop_loop)
async def async_subscribe(
    self,
    topic: str,
    msg_callback: MessageCallbackType,
    qos: int,
    encoding: str | None = None,
) -> Callable[[], None]:
    """Register a subscription for ``topic`` with the requested qos.

    The subscription is recorded immediately; the broker is only contacted
    when the client is currently connected. Returns a callback that removes
    the subscription again. Raises HomeAssistantError when ``topic`` is not
    a string.

    This method is a coroutine.
    """
    if not isinstance(topic, str):
        raise HomeAssistantError("Topic needs to be a string!")

    new_subscription = Subscription(
        topic, _matcher_for_topic(topic), HassJob(msg_callback), qos, encoding
    )
    self.subscriptions.append(new_subscription)
    # The set of matching subscriptions changed, drop cached lookups.
    self._matching_subscriptions.cache_clear()

    # Only subscribe if currently connected.
    if self.connected:
        self._last_subscribe = time.time()
        await self._async_perform_subscription(topic, qos)

    @callback
    def async_remove() -> None:
        """Remove subscription."""
        if new_subscription not in self.subscriptions:
            raise HomeAssistantError("Can't remove subscription twice")
        self.subscriptions.remove(new_subscription)
        self._matching_subscriptions.cache_clear()

        # Only unsubscribe if currently connected.
        if self.connected:
            self.hass.async_create_task(self._async_unsubscribe(topic))

    return async_remove
async def _async_unsubscribe(self, topic: str) -> None:
    """Unsubscribe from ``topic`` unless another subscription still uses it.

    This method is a coroutine.
    """
    for remaining in self.subscriptions:
        if remaining.topic == topic:
            # Other subscriptions on topic remaining - don't unsubscribe.
            return

    async with self._paho_lock:
        result, mid = await self.hass.async_add_executor_job(
            self._mqttc.unsubscribe, topic
        )
        _LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
        _raise_on_error(result)
    await self._wait_for_mid(mid)
async def _async_perform_subscription(self, topic: str, qos: int) -> None:
    """Subscribe through paho and wait for the broker's acknowledgement."""
    async with self._paho_lock:
        subscribe_result = await self.hass.async_add_executor_job(
            self._mqttc.subscribe, topic, qos
        )
        result, mid = subscribe_result
        _LOGGER.debug("Subscribing to %s, mid: %s", topic, mid)
        _raise_on_error(result)
    await self._wait_for_mid(mid)
def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None:
    """Handle the paho on-connect callback.

    On a refused connection the reason is logged and nothing else happens.
    Otherwise the client is marked connected, every known topic is
    re-subscribed once, and the birth message (when configured with a
    topic) is scheduled for publishing.
    """
    # pylint: disable-next=import-outside-toplevel
    import paho.mqtt.client as mqtt

    if result_code != mqtt.CONNACK_ACCEPTED:
        _LOGGER.error(
            "Unable to connect to the MQTT broker: %s",
            mqtt.connack_string(result_code),
        )
        return

    self.connected = True
    dispatcher_send(self.hass, MQTT_CONNECTED)
    _LOGGER.info(
        "Connected to MQTT server %s:%s (%s)",
        self.conf[CONF_BROKER],
        self.conf[CONF_PORT],
        result_code,
    )

    # Re-subscribe only once per topic, using the highest requested qos.
    highest_qos: dict[str, int] = {}
    for subscription in self.subscriptions:
        known_qos = highest_qos.get(subscription.topic, subscription.qos)
        highest_qos[subscription.topic] = max(known_qos, subscription.qos)
    for topic in sorted(highest_qos):
        self.hass.add_job(
            self._async_perform_subscription, topic, highest_qos[topic]
        )

    if (
        CONF_BIRTH_MESSAGE not in self.conf
        or ATTR_TOPIC not in self.conf[CONF_BIRTH_MESSAGE]
    ):
        return

    async def publish_birth_message(birth_message):
        await self._ha_started.wait()  # Wait for Home Assistant to start
        await self._discovery_cooldown()  # Wait for MQTT discovery to cool down
        await self.async_publish(
            topic=birth_message.topic,
            payload=birth_message.payload,
            qos=birth_message.qos,
            retain=birth_message.retain,
        )

    birth_message = PublishMessage(**self.conf[CONF_BIRTH_MESSAGE])
    asyncio.run_coroutine_threadsafe(
        publish_birth_message(birth_message), self.hass.loop
    )
def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None:
    """Hand a received paho message over to the message handler job."""
    handler = self._mqtt_handle_message
    self.hass.add_job(handler, msg)
@lru_cache(2048)
def _matching_subscriptions(self, topic):
    """Return the subscriptions whose matcher accepts ``topic``.

    Results are cached per topic; the cache is cleared by the code that
    adds or removes subscriptions.
    """
    return [
        subscription
        for subscription in self.subscriptions
        if subscription.matcher(topic)
    ]
@callback
def _mqtt_handle_message(self, msg) -> None:
    """Dispatch a received message to every matching subscription.

    The payload is decoded with each subscription's encoding; subscriptions
    without an encoding receive the raw payload. A subscription whose
    encoding cannot decode the payload is skipped with a warning.
    """
    _LOGGER.debug(
        "Received message on %s%s: %s",
        msg.topic,
        " (retained)" if msg.retain else "",
        msg.payload[0:8192],
    )
    received_at = dt_util.utcnow()

    for subscription in self._matching_subscriptions(msg.topic):
        decoded: SubscribePayloadType
        if subscription.encoding is None:
            decoded = msg.payload
        else:
            try:
                decoded = msg.payload.decode(subscription.encoding)
            except (AttributeError, UnicodeDecodeError):
                _LOGGER.warning(
                    "Can't decode payload %s on %s with encoding %s (for %s)",
                    msg.payload[0:8192],
                    msg.topic,
                    subscription.encoding,
                    subscription.job,
                )
                continue

        received = ReceiveMessage(
            msg.topic,
            decoded,
            msg.qos,
            msg.retain,
            subscription.topic,
            received_at,
        )
        self.hass.async_run_hass_job(subscription.job, received)
def _mqtt_on_callback(self, _mqttc, _userdata, mid, _granted_qos=None) -> None:
"""Publish / Subscribe / Unsubscribe callback."""
self.hass.add_job(self._mqtt_handle_mid, mid)
@callback
def _mqtt_handle_mid(self, mid) -> None:
    """Set the pending-operation event registered for `mid`."""
    # Either this method or _wait_for_mid may run first, so the event is
    # created here when it does not exist yet.
    try:
        ack_event = self._pending_operations[mid]
    except KeyError:
        ack_event = self._pending_operations[mid] = asyncio.Event()
    ack_event.set()
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
    """Record the disconnect, signal MQTT_DISCONNECTED and log a warning."""
    self.connected = False
    dispatcher_send(self.hass, MQTT_DISCONNECTED)

    broker = self.conf[CONF_BROKER]
    port = self.conf[CONF_PORT]
    _LOGGER.warning(
        "Disconnected from MQTT server %s:%s (%s)", broker, port, result_code
    )
async def _wait_for_mid(self, mid):
    """Wait until the event for `mid` is set, for at most TIMEOUT_ACK seconds.

    A timeout is logged as a warning and not raised. The pending-operation
    entry for `mid` is removed in every case.
    """
    # Either this method or _mqtt_handle_mid may run first, so reuse the
    # event if it was already registered.
    ack_event = self._pending_operations.get(mid)
    if ack_event is None:
        ack_event = self._pending_operations[mid] = asyncio.Event()

    try:
        await asyncio.wait_for(ack_event.wait(), TIMEOUT_ACK)
    except asyncio.TimeoutError:
        _LOGGER.warning(
            "No ACK from MQTT server in %s seconds (mid: %s)", TIMEOUT_ACK, mid
        )
    finally:
        del self._pending_operations[mid]
async def _discovery_cooldown(self):
    """Sleep until DISCOVERY_COOLDOWN has passed since the last discovery and subscribe.

    Both timestamps are reset to the current time on entry. They are read
    again after every sleep, so an update made while sleeping extends the
    wait.
    """
    now = time.time()
    # Reset discovery and subscribe cooldowns
    self.hass.data[LAST_DISCOVERY] = now
    self._last_subscribe = now

    while True:
        # Recompute the deadline from the latest timestamps on every pass.
        wait_until = max(
            self.hass.data[LAST_DISCOVERY] + DISCOVERY_COOLDOWN,
            self._last_subscribe + DISCOVERY_COOLDOWN,
        )
        if now >= wait_until:
            break
        await asyncio.sleep(wait_until - now)
        now = time.time()
def _raise_on_error(result_code: int | None) -> None:
    """Raise HomeAssistantError if `result_code` reports a failure.

    A result code of None or 0 counts as success and returns silently.
    """
    # pylint: disable-next=import-outside-toplevel
    import paho.mqtt.client as mqtt

    if result_code is None or result_code == 0:
        return

    error_text = mqtt.error_string(result_code)
    raise HomeAssistantError(f"Error talking to MQTT: {error_text}")
def _matcher_for_topic(subscription: str) -> Any:
    """Build a callable that tests a topic against `subscription`."""
    # pylint: disable-next=import-outside-toplevel
    from paho.mqtt.matcher import MQTTMatcher

    topic_matcher = MQTTMatcher()
    topic_matcher[subscription] = True

    def _matches(topic):
        # next() falls back to False when iter_match yields nothing.
        return next(topic_matcher.iter_match(topic), False)

    return _matches
@websocket_api.websocket_command(
    {vol.Required("type"): "mqtt/device/debug_info", vol.Required("device_id"): str}
)
@callback
def websocket_mqtt_info(hass, connection, msg):
    """Send the MQTT debug info of the requested device as the command result."""
    info = debug_info.info_for_device(hass, msg["device_id"])
    connection.send_result(msg["id"], info)
@websocket_api.websocket_command(
    {
        vol.Required("type"): "mqtt/subscribe",
        vol.Required("topic"): valid_subscribe_topic,
    }
)
@websocket_api.async_response
async def websocket_subscribe(hass, connection, msg):
    """Subscribe a websocket connection to an MQTT topic.

    Raises Unauthorized when the connection's user is not an admin.
    """
    if not connection.user.is_admin:
        raise Unauthorized

    async def forward_messages(mqttmsg: ReceiveMessage):
        """Send one received MQTT message to the websocket as an event."""
        raw = mqttmsg.payload
        try:
            # The subscription below passes encoding=None, so the payload is
            # decoded here rather than by the subscription.
            text = cast(bytes, raw).decode(DEFAULT_ENCODING)
        except (AttributeError, UnicodeDecodeError):
            # Convert non UTF-8 payload to a string presentation
            text = str(raw)

        event_data = {
            "topic": mqttmsg.topic,
            "payload": text,
            "qos": mqttmsg.qos,
            "retain": mqttmsg.retain,
        }
        connection.send_message(websocket_api.event_message(msg["id"], event_data))

    # Store what async_subscribe returns under the command id.
    unsubscribe = await async_subscribe(
        hass, msg["topic"], forward_messages, encoding=None
    )
    connection.subscriptions[msg["id"]] = unsubscribe

    connection.send_message(websocket_api.result_message(msg["id"]))
# Signature of a connection status listener: it is called with True when the
# MQTT client connects and with False when it disconnects.
ConnectionStatusCallback = Callable[[bool], None]
@callback
def async_subscribe_connection_status(
    hass: HomeAssistant, connection_status_callback: ConnectionStatusCallback
) -> Callable[[], None]:
    """Register a callback for MQTT connection state changes.

    The callback is run with True on MQTT_CONNECTED and with False on
    MQTT_DISCONNECTED. Returns a function that removes both listeners.
    """
    status_job = HassJob(connection_status_callback)

    async def _notify(is_up: bool) -> None:
        # async_run_hass_job may return something awaitable; wait for it if so.
        pending = hass.async_run_hass_job(status_job, is_up)
        if pending:
            await pending

    async def connected():
        await _notify(True)

    async def disconnected():
        await _notify(False)

    # Kept in connect, disconnect order so removal happens in that order too.
    removers = (
        async_dispatcher_connect(hass, MQTT_CONNECTED, connected),
        async_dispatcher_connect(hass, MQTT_DISCONNECTED, disconnected),
    )

    @callback
    def unsubscribe():
        for remove_listener in removers:
            remove_listener()

    return unsubscribe
def is_connected(hass: HomeAssistant) -> bool:
    """Return the `connected` attribute of the object stored under DATA_MQTT."""
    mqtt_data = hass.data[DATA_MQTT]
    return mqtt_data.connected
async def async_remove_config_entry_device(
    hass: HomeAssistant, config_entry: ConfigEntry, device_entry: DeviceEntry
) -> bool:
    """Remove the MQTT config entry from a device.

    Notifies device_automation that the device was removed, then returns True.
    """
    # pylint: disable-next=import-outside-toplevel
    from . import device_automation

    device_id = device_entry.id
    await device_automation.async_removed_from_device(hass, device_id)
    return True