Allow MQTT device based auto discovery (#109030)

* Add MQTT device based auto discovery

* Respect override of component options over shared ones

* Add state_topic, command_topic, qos and encoding as shared options

* Add shared option test

* Rename device.py to schemas.py

* Remove unused legacy `platform` attribute to avoid confusion

* Split validation device and origin info

* Require `origin` info on device based discovery

* Log origin info for only once for device discovery

* Fix tests and linters

* ruff

* speed up _replace_all_abbreviations

* Fix imports and merging errors - add slots attr

* Fix unrelated const changes

* More unrelated changes

* join string

* fix merge

* Undo move

* Adjust logger statement

* fix task storm to load platforms

* Revert "fix task storm to load platforms"

This reverts commit 8f12a5f251.

* bail if logging is disabled

* Correct mixup object_id and node_id

* Auto migrate entities to device discovery

* Add device discovery test for device_trigger

* Add migration support for non entity platforms

* Use helper to remove discovery payload

* Fix tests after update branch

* Add discovery migration test

* Refactor

* Repair after rebase

* Fix discovery is broken after migration

* Improve comments

* More comment improvements

* Split long lines

* Add comment to indicate payload dict can be empty

* typo

* Add walrus and update comment

* Add tag to migration test

* Join try blocks

* Refactor

* Cleanup not used attribute

* Refactor

* Move _replace_all_abbreviations out of try block

---------

Co-authored-by: J. Nick Koston <nick@koston.org>
pull/118431/head
Jan Bouwhuis 2024-05-29 11:12:05 +02:00 committed by GitHub
parent 83e62c5239
commit 585892f067
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 1109 additions and 174 deletions

View File

@ -33,6 +33,7 @@ ABBREVIATIONS = {
"cmd_on_tpl": "command_on_template", "cmd_on_tpl": "command_on_template",
"cmd_t": "command_topic", "cmd_t": "command_topic",
"cmd_tpl": "command_template", "cmd_tpl": "command_template",
"cmp": "components",
"cod_arm_req": "code_arm_required", "cod_arm_req": "code_arm_required",
"cod_dis_req": "code_disarm_required", "cod_dis_req": "code_disarm_required",
"cod_form": "code_format", "cod_form": "code_format",

View File

@ -86,6 +86,7 @@ CONF_TEMP_MIN = "min_temp"
CONF_CERTIFICATE = "certificate" CONF_CERTIFICATE = "certificate"
CONF_CLIENT_KEY = "client_key" CONF_CLIENT_KEY = "client_key"
CONF_CLIENT_CERT = "client_cert" CONF_CLIENT_CERT = "client_cert"
CONF_COMPONENTS = "components"
CONF_TLS_INSECURE = "tls_insecure" CONF_TLS_INSECURE = "tls_insecure"
# Device and integration info options # Device and integration info options

View File

@ -10,6 +10,8 @@ import re
import time import time
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import voluptuous as vol
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
from homeassistant.core import HassJobType, HomeAssistant, callback from homeassistant.core import HassJobType, HomeAssistant, callback
@ -19,7 +21,7 @@ from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_connect,
async_dispatcher_send, async_dispatcher_send,
) )
from homeassistant.helpers.service_info.mqtt import MqttServiceInfo from homeassistant.helpers.service_info.mqtt import MqttServiceInfo, ReceivePayloadType
from homeassistant.helpers.typing import DiscoveryInfoType from homeassistant.helpers.typing import DiscoveryInfoType
from homeassistant.loader import async_get_mqtt from homeassistant.loader import async_get_mqtt
from homeassistant.util.json import json_loads_object from homeassistant.util.json import json_loads_object
@ -32,15 +34,21 @@ from .const import (
ATTR_DISCOVERY_PAYLOAD, ATTR_DISCOVERY_PAYLOAD,
ATTR_DISCOVERY_TOPIC, ATTR_DISCOVERY_TOPIC,
CONF_AVAILABILITY, CONF_AVAILABILITY,
CONF_COMPONENTS,
CONF_ORIGIN, CONF_ORIGIN,
CONF_TOPIC, CONF_TOPIC,
DOMAIN, DOMAIN,
SUPPORTED_COMPONENTS, SUPPORTED_COMPONENTS,
) )
from .models import DATA_MQTT, MqttOriginInfo, ReceiveMessage from .models import DATA_MQTT, MqttComponentConfig, MqttOriginInfo, ReceiveMessage
from .schemas import MQTT_ORIGIN_INFO_SCHEMA from .schemas import DEVICE_DISCOVERY_SCHEMA, MQTT_ORIGIN_INFO_SCHEMA, SHARED_OPTIONS
from .util import async_forward_entry_setup_and_setup_discovery from .util import async_forward_entry_setup_and_setup_discovery
ABBREVIATIONS_SET = set(ABBREVIATIONS)
DEVICE_ABBREVIATIONS_SET = set(DEVICE_ABBREVIATIONS)
ORIGIN_ABBREVIATIONS_SET = set(ORIGIN_ABBREVIATIONS)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
TOPIC_MATCHER = re.compile( TOPIC_MATCHER = re.compile(
@ -64,6 +72,7 @@ TOPIC_BASE = "~"
class MQTTDiscoveryPayload(dict[str, Any]): class MQTTDiscoveryPayload(dict[str, Any]):
"""Class to hold and MQTT discovery payload and discovery data.""" """Class to hold and MQTT discovery payload and discovery data."""
device_discovery: bool = False
discovery_data: DiscoveryInfoType discovery_data: DiscoveryInfoType
@ -82,6 +91,13 @@ def async_log_discovery_origin_info(
message: str, discovery_payload: MQTTDiscoveryPayload, level: int = logging.INFO message: str, discovery_payload: MQTTDiscoveryPayload, level: int = logging.INFO
) -> None: ) -> None:
"""Log information about the discovery and origin.""" """Log information about the discovery and origin."""
# We only log origin info once per device discovery
if not _LOGGER.isEnabledFor(level):
# bail early if logging is disabled
return
if discovery_payload.device_discovery:
_LOGGER.log(level, message)
return
if CONF_ORIGIN not in discovery_payload: if CONF_ORIGIN not in discovery_payload:
_LOGGER.log(level, message) _LOGGER.log(level, message)
return return
@ -102,6 +118,151 @@ def async_log_discovery_origin_info(
) )
@callback
def _replace_abbreviations(
payload: Any | dict[str, Any],
abbreviations: dict[str, str],
abbreviations_set: set[str],
) -> None:
"""Replace abbreviations in an MQTT discovery payload."""
if not isinstance(payload, dict):
return
for key in abbreviations_set.intersection(payload):
payload[abbreviations[key]] = payload.pop(key)
@callback
def _replace_all_abbreviations(discovery_payload: Any | dict[str, Any]) -> None:
"""Replace all abbreviations in an MQTT discovery payload."""
_replace_abbreviations(discovery_payload, ABBREVIATIONS, ABBREVIATIONS_SET)
if CONF_ORIGIN in discovery_payload:
_replace_abbreviations(
discovery_payload[CONF_ORIGIN],
ORIGIN_ABBREVIATIONS,
ORIGIN_ABBREVIATIONS_SET,
)
if CONF_DEVICE in discovery_payload:
_replace_abbreviations(
discovery_payload[CONF_DEVICE],
DEVICE_ABBREVIATIONS,
DEVICE_ABBREVIATIONS_SET,
)
if CONF_AVAILABILITY in discovery_payload:
for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]):
_replace_abbreviations(availability_conf, ABBREVIATIONS, ABBREVIATIONS_SET)
@callback
def _replace_topic_base(discovery_payload: dict[str, Any]) -> None:
"""Replace topic base in MQTT discovery data."""
base = discovery_payload.pop(TOPIC_BASE)
for key, value in discovery_payload.items():
if isinstance(value, str) and value:
if value[0] == TOPIC_BASE and key.endswith("topic"):
discovery_payload[key] = f"{base}{value[1:]}"
if value[-1] == TOPIC_BASE and key.endswith("topic"):
discovery_payload[key] = f"{value[:-1]}{base}"
if discovery_payload.get(CONF_AVAILABILITY):
for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]):
if not isinstance(availability_conf, dict):
continue
if topic := str(availability_conf.get(CONF_TOPIC)):
if topic[0] == TOPIC_BASE:
availability_conf[CONF_TOPIC] = f"{base}{topic[1:]}"
if topic[-1] == TOPIC_BASE:
availability_conf[CONF_TOPIC] = f"{topic[:-1]}{base}"
@callback
def _generate_device_cleanup_config(
hass: HomeAssistant, object_id: str, node_id: str | None
) -> dict[str, Any]:
"""Generate a cleanup message on device cleanup."""
mqtt_data = hass.data[DATA_MQTT]
device_node_id: str = f"{node_id} {object_id}" if node_id else object_id
config: dict[str, Any] = {CONF_DEVICE: {}, CONF_COMPONENTS: {}}
comp_config = config[CONF_COMPONENTS]
for platform, discover_id in mqtt_data.discovery_already_discovered:
ids = discover_id.split(" ")
component_node_id = ids.pop(0)
component_object_id = " ".join(ids)
if not ids:
continue
if device_node_id == component_node_id:
comp_config[component_object_id] = {CONF_PLATFORM: platform}
return config if comp_config else {}
@callback
def _parse_device_payload(
hass: HomeAssistant,
payload: ReceivePayloadType,
object_id: str,
node_id: str | None,
) -> dict[str, Any]:
"""Parse a device discovery payload."""
device_payload: dict[str, Any] = {}
if payload == "":
if not (
device_payload := _generate_device_cleanup_config(hass, object_id, node_id)
):
_LOGGER.warning(
"No device components to cleanup for %s, node_id '%s'",
object_id,
node_id,
)
return device_payload
try:
device_payload = MQTTDiscoveryPayload(json_loads_object(payload))
except ValueError:
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
return {}
_replace_all_abbreviations(device_payload)
try:
DEVICE_DISCOVERY_SCHEMA(device_payload)
except vol.Invalid as exc:
_LOGGER.warning(
"Invalid MQTT device discovery payload for %s, %s: '%s'",
object_id,
exc,
payload,
)
return {}
return device_payload
@callback
def _valid_origin_info(discovery_payload: MQTTDiscoveryPayload) -> bool:
"""Parse and validate origin info from a single component discovery payload."""
if CONF_ORIGIN not in discovery_payload:
return True
try:
MQTT_ORIGIN_INFO_SCHEMA(discovery_payload[CONF_ORIGIN])
except Exception as exc: # noqa:BLE001
_LOGGER.warning(
"Unable to parse origin information from discovery message: %s, got %s",
exc,
discovery_payload[CONF_ORIGIN],
)
return False
return True
@callback
def _merge_common_options(
component_config: MQTTDiscoveryPayload, device_config: dict[str, Any]
) -> None:
"""Merge common options with the component config options."""
for option in SHARED_OPTIONS:
if option in device_config and option not in component_config:
component_config[option] = device_config.get(option)
async def async_start( # noqa: C901 async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic: str, config_entry: ConfigEntry hass: HomeAssistant, discovery_topic: str, config_entry: ConfigEntry
) -> None: ) -> None:
@ -145,8 +306,7 @@ async def async_start( # noqa: C901
_LOGGER.warning( _LOGGER.warning(
( (
"Received message on illegal discovery topic '%s'. The topic" "Received message on illegal discovery topic '%s'. The topic"
" contains " " contains not allowed characters. For more information see "
"not allowed characters. For more information see "
"https://www.home-assistant.io/integrations/mqtt/#discovery-topic" "https://www.home-assistant.io/integrations/mqtt/#discovery-topic"
), ),
topic, topic,
@ -155,88 +315,96 @@ async def async_start( # noqa: C901
component, node_id, object_id = match.groups() component, node_id, object_id = match.groups()
discovered_components: list[MqttComponentConfig] = []
if component == CONF_DEVICE:
# Process device based discovery message
# and regenate cleanup config.
device_discovery_payload = _parse_device_payload(
hass, payload, object_id, node_id
)
if not device_discovery_payload:
return
device_config: dict[str, Any]
origin_config: dict[str, Any] | None
component_configs: dict[str, dict[str, Any]]
device_config = device_discovery_payload[CONF_DEVICE]
origin_config = device_discovery_payload.get(CONF_ORIGIN)
component_configs = device_discovery_payload[CONF_COMPONENTS]
for component_id, config in component_configs.items():
component = config.pop(CONF_PLATFORM)
# The object_id in the device discovery topic is the unique identifier.
# It is used as node_id for the components it contains.
component_node_id = object_id
# The component_id in the discovery playload is used as object_id
# If we have an additional node_id in the discovery topic,
# we extend the component_id with it.
component_object_id = (
f"{node_id} {component_id}" if node_id else component_id
)
_replace_all_abbreviations(config)
# We add wrapper to the discovery payload with the discovery data.
# If the dict is empty after removing the platform, the payload is
# assumed to remove the existing config and we do not want to add
# device or orig or shared availability attributes.
if discovery_payload := MQTTDiscoveryPayload(config):
discovery_payload.device_discovery = True
discovery_payload[CONF_DEVICE] = device_config
discovery_payload[CONF_ORIGIN] = origin_config
# Only assign shared config options
# when they are not set at entity level
_merge_common_options(discovery_payload, device_discovery_payload)
discovered_components.append(
MqttComponentConfig(
component,
component_object_id,
component_node_id,
discovery_payload,
)
)
_LOGGER.debug(
"Process device discovery payload %s", device_discovery_payload
)
device_discovery_id = f"{node_id} {object_id}" if node_id else object_id
message = f"Processing device discovery for '{device_discovery_id}'"
async_log_discovery_origin_info(
message, MQTTDiscoveryPayload(device_discovery_payload)
)
else:
# Process component based discovery message
try:
discovery_payload = MQTTDiscoveryPayload(
json_loads_object(payload) if payload else {}
)
except ValueError:
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
return
_replace_all_abbreviations(discovery_payload)
if not _valid_origin_info(discovery_payload):
return
discovered_components.append(
MqttComponentConfig(component, object_id, node_id, discovery_payload)
)
discovery_pending_discovered = mqtt_data.discovery_pending_discovered
for component_config in discovered_components:
component = component_config.component
node_id = component_config.node_id
object_id = component_config.object_id
discovery_payload = component_config.discovery_payload
if component not in SUPPORTED_COMPONENTS: if component not in SUPPORTED_COMPONENTS:
_LOGGER.warning("Integration %s is not supported", component) _LOGGER.warning("Integration %s is not supported", component)
return return
if payload:
try:
discovery_payload = MQTTDiscoveryPayload(json_loads_object(payload))
except ValueError:
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
return
else:
discovery_payload = MQTTDiscoveryPayload({})
for key in list(discovery_payload):
abbreviated_key = key
key = ABBREVIATIONS.get(key, key)
discovery_payload[key] = discovery_payload.pop(abbreviated_key)
if CONF_DEVICE in discovery_payload:
device = discovery_payload[CONF_DEVICE]
for key in list(device):
abbreviated_key = key
key = DEVICE_ABBREVIATIONS.get(key, key)
device[key] = device.pop(abbreviated_key)
if CONF_ORIGIN in discovery_payload:
origin_info: dict[str, Any] = discovery_payload[CONF_ORIGIN]
try:
for key in list(origin_info):
abbreviated_key = key
key = ORIGIN_ABBREVIATIONS.get(key, key)
origin_info[key] = origin_info.pop(abbreviated_key)
MQTT_ORIGIN_INFO_SCHEMA(discovery_payload[CONF_ORIGIN])
except Exception: # noqa: BLE001
_LOGGER.warning(
"Unable to parse origin information "
"from discovery message, got %s",
discovery_payload[CONF_ORIGIN],
)
return
if CONF_AVAILABILITY in discovery_payload:
for availability_conf in cv.ensure_list(
discovery_payload[CONF_AVAILABILITY]
):
if isinstance(availability_conf, dict):
for key in list(availability_conf):
abbreviated_key = key
key = ABBREVIATIONS.get(key, key)
availability_conf[key] = availability_conf.pop(abbreviated_key)
if TOPIC_BASE in discovery_payload: if TOPIC_BASE in discovery_payload:
base = discovery_payload.pop(TOPIC_BASE) _replace_topic_base(discovery_payload)
for key, value in discovery_payload.items():
if isinstance(value, str) and value:
if value[0] == TOPIC_BASE and key.endswith("topic"):
discovery_payload[key] = f"{base}{value[1:]}"
if value[-1] == TOPIC_BASE and key.endswith("topic"):
discovery_payload[key] = f"{value[:-1]}{base}"
if discovery_payload.get(CONF_AVAILABILITY):
for availability_conf in cv.ensure_list(
discovery_payload[CONF_AVAILABILITY]
):
if not isinstance(availability_conf, dict):
continue
if topic := str(availability_conf.get(CONF_TOPIC)):
if topic[0] == TOPIC_BASE:
availability_conf[CONF_TOPIC] = f"{base}{topic[1:]}"
if topic[-1] == TOPIC_BASE:
availability_conf[CONF_TOPIC] = f"{topic[:-1]}{base}"
# If present, the node_id will be included in the discovered object id # If present, the node_id will be included in the discovery_id.
discovery_id = f"{node_id} {object_id}" if node_id else object_id discovery_id = f"{node_id} {object_id}" if node_id else object_id
discovery_hash = (component, discovery_id) discovery_hash = (component, discovery_id)
if discovery_payload: if discovery_payload:
# Attach MQTT topic to the payload, used for debug prints # Attach MQTT topic to the payload, used for debug prints
setattr(
discovery_payload,
"__configuration_source__",
f"MQTT (topic: '{topic}')",
)
discovery_data = { discovery_data = {
ATTR_DISCOVERY_HASH: discovery_hash, ATTR_DISCOVERY_HASH: discovery_hash,
ATTR_DISCOVERY_PAYLOAD: discovery_payload, ATTR_DISCOVERY_PAYLOAD: discovery_payload,
@ -244,10 +412,8 @@ async def async_start( # noqa: C901
} }
setattr(discovery_payload, "discovery_data", discovery_data) setattr(discovery_payload, "discovery_data", discovery_data)
discovery_payload[CONF_PLATFORM] = "mqtt" if discovery_hash in discovery_pending_discovered:
pending = discovery_pending_discovered[discovery_hash]["pending"]
if discovery_hash in mqtt_data.discovery_pending_discovered:
pending = mqtt_data.discovery_pending_discovered[discovery_hash]["pending"]
pending.appendleft(discovery_payload) pending.appendleft(discovery_payload)
_LOGGER.debug( _LOGGER.debug(
"Component has already been discovered: %s %s, queuing update", "Component has already been discovered: %s %s, queuing update",
@ -264,7 +430,7 @@ async def async_start( # noqa: C901
) -> None: ) -> None:
"""Process the payload of a new discovery.""" """Process the payload of a new discovery."""
_LOGGER.debug("Process discovery payload %s", payload) _LOGGER.debug("Process component discovery payload %s", payload)
discovery_hash = (component, discovery_id) discovery_hash = (component, discovery_id)
already_discovered = discovery_hash in mqtt_data.discovery_already_discovered already_discovered = discovery_hash in mqtt_data.discovery_already_discovered

View File

@ -682,6 +682,7 @@ class MqttDiscoveryDeviceUpdateMixin(ABC):
self._config_entry = config_entry self._config_entry = config_entry
self._config_entry_id = config_entry.entry_id self._config_entry_id = config_entry.entry_id
self._skip_device_removal: bool = False self._skip_device_removal: bool = False
self._migrate_discovery: str | None = None
discovery_hash = get_discovery_hash(discovery_data) discovery_hash = get_discovery_hash(discovery_data)
self._remove_discovery_updated = async_dispatcher_connect( self._remove_discovery_updated = async_dispatcher_connect(
@ -720,6 +721,24 @@ class MqttDiscoveryDeviceUpdateMixin(ABC):
discovery_hash, discovery_hash,
discovery_payload, discovery_payload,
) )
if not discovery_payload and self._migrate_discovery is not None:
# Ignore empty update from migrated and removed discovery config.
self._discovery_data[ATTR_DISCOVERY_TOPIC] = self._migrate_discovery
self._migrate_discovery = None
_LOGGER.info("Component successfully migrated: %s", discovery_hash)
send_discovery_done(self.hass, self._discovery_data)
return
if discovery_payload and (
(discovery_topic := discovery_payload.discovery_data[ATTR_DISCOVERY_TOPIC])
!= self._discovery_data[ATTR_DISCOVERY_TOPIC]
):
# Make sure the migrated discovery topic is removed.
self._migrate_discovery = discovery_topic
_LOGGER.debug("Migrating component: %s", discovery_hash)
self.hass.async_create_task(
async_remove_discovery_payload(self.hass, self._discovery_data)
)
if ( if (
discovery_payload discovery_payload
and discovery_payload != self._discovery_data[ATTR_DISCOVERY_PAYLOAD] and discovery_payload != self._discovery_data[ATTR_DISCOVERY_PAYLOAD]
@ -816,6 +835,7 @@ class MqttDiscoveryUpdateMixin(Entity):
mqtt_data = hass.data[DATA_MQTT] mqtt_data = hass.data[DATA_MQTT]
self._registry_hooks = mqtt_data.discovery_registry_hooks self._registry_hooks = mqtt_data.discovery_registry_hooks
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH] discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
self._migrate_discovery: str | None = None
if discovery_hash in self._registry_hooks: if discovery_hash in self._registry_hooks:
self._registry_hooks.pop(discovery_hash)() self._registry_hooks.pop(discovery_hash)()
@ -898,12 +918,27 @@ class MqttDiscoveryUpdateMixin(Entity):
old_payload = self._discovery_data[ATTR_DISCOVERY_PAYLOAD] old_payload = self._discovery_data[ATTR_DISCOVERY_PAYLOAD]
debug_info.update_entity_discovery_data(self.hass, payload, self.entity_id) debug_info.update_entity_discovery_data(self.hass, payload, self.entity_id)
if not payload: if not payload:
if self._migrate_discovery is not None:
# Ignore empty update of the migrated and removed discovery config.
self._discovery_data[ATTR_DISCOVERY_TOPIC] = self._migrate_discovery
self._migrate_discovery = None
_LOGGER.info("Component successfully migrated: %s", self.entity_id)
send_discovery_done(self.hass, self._discovery_data)
return
# Empty payload: Remove component # Empty payload: Remove component
_LOGGER.info("Removing component: %s", self.entity_id) _LOGGER.info("Removing component: %s", self.entity_id)
self.hass.async_create_task( self.hass.async_create_task(
self._async_process_discovery_update_and_remove() self._async_process_discovery_update_and_remove()
) )
elif self._discovery_update: elif self._discovery_update:
discovery_topic = payload.discovery_data[ATTR_DISCOVERY_TOPIC]
if discovery_topic != self._discovery_data[ATTR_DISCOVERY_TOPIC]:
# Make sure the migrated discovery topic is removed.
self._migrate_discovery = discovery_topic
_LOGGER.debug("Migrating component: %s", self.entity_id)
self.hass.async_create_task(
async_remove_discovery_payload(self.hass, self._discovery_data)
)
if old_payload != payload: if old_payload != payload:
# Non-empty, changed payload: Notify component # Non-empty, changed payload: Notify component
_LOGGER.info("Updating component: %s", self.entity_id) _LOGGER.info("Updating component: %s", self.entity_id)

View File

@ -424,5 +424,15 @@ class MqttData:
tags: dict[str, dict[str, MQTTTagScanner]] = field(default_factory=dict) tags: dict[str, dict[str, MQTTTagScanner]] = field(default_factory=dict)
@dataclass(slots=True)
class MqttComponentConfig:
"""(component, object_id, node_id, discovery_payload)."""
component: str
object_id: str
node_id: str | None
discovery_payload: MQTTDiscoveryPayload
DATA_MQTT: HassKey[MqttData] = HassKey("mqtt") DATA_MQTT: HassKey[MqttData] = HassKey("mqtt")
DATA_MQTT_AVAILABLE: HassKey[asyncio.Future[bool]] = HassKey("mqtt_client_available") DATA_MQTT_AVAILABLE: HassKey[asyncio.Future[bool]] = HassKey("mqtt_client_available")

View File

@ -2,6 +2,8 @@
from __future__ import annotations from __future__ import annotations
import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
@ -10,6 +12,7 @@ from homeassistant.const import (
CONF_ICON, CONF_ICON,
CONF_MODEL, CONF_MODEL,
CONF_NAME, CONF_NAME,
CONF_PLATFORM,
CONF_UNIQUE_ID, CONF_UNIQUE_ID,
CONF_VALUE_TEMPLATE, CONF_VALUE_TEMPLATE,
) )
@ -24,10 +27,13 @@ from .const import (
CONF_AVAILABILITY_MODE, CONF_AVAILABILITY_MODE,
CONF_AVAILABILITY_TEMPLATE, CONF_AVAILABILITY_TEMPLATE,
CONF_AVAILABILITY_TOPIC, CONF_AVAILABILITY_TOPIC,
CONF_COMMAND_TOPIC,
CONF_COMPONENTS,
CONF_CONFIGURATION_URL, CONF_CONFIGURATION_URL,
CONF_CONNECTIONS, CONF_CONNECTIONS,
CONF_DEPRECATED_VIA_HUB, CONF_DEPRECATED_VIA_HUB,
CONF_ENABLED_BY_DEFAULT, CONF_ENABLED_BY_DEFAULT,
CONF_ENCODING,
CONF_HW_VERSION, CONF_HW_VERSION,
CONF_IDENTIFIERS, CONF_IDENTIFIERS,
CONF_JSON_ATTRS_TEMPLATE, CONF_JSON_ATTRS_TEMPLATE,
@ -37,7 +43,9 @@ from .const import (
CONF_ORIGIN, CONF_ORIGIN,
CONF_PAYLOAD_AVAILABLE, CONF_PAYLOAD_AVAILABLE,
CONF_PAYLOAD_NOT_AVAILABLE, CONF_PAYLOAD_NOT_AVAILABLE,
CONF_QOS,
CONF_SERIAL_NUMBER, CONF_SERIAL_NUMBER,
CONF_STATE_TOPIC,
CONF_SUGGESTED_AREA, CONF_SUGGESTED_AREA,
CONF_SUPPORT_URL, CONF_SUPPORT_URL,
CONF_SW_VERSION, CONF_SW_VERSION,
@ -45,8 +53,33 @@ from .const import (
CONF_VIA_DEVICE, CONF_VIA_DEVICE,
DEFAULT_PAYLOAD_AVAILABLE, DEFAULT_PAYLOAD_AVAILABLE,
DEFAULT_PAYLOAD_NOT_AVAILABLE, DEFAULT_PAYLOAD_NOT_AVAILABLE,
SUPPORTED_COMPONENTS,
)
from .util import valid_publish_topic, valid_qos_schema, valid_subscribe_topic
_LOGGER = logging.getLogger(__name__)
# Device discovery options that are also available at entity component level
SHARED_OPTIONS = [
CONF_AVAILABILITY,
CONF_AVAILABILITY_MODE,
CONF_AVAILABILITY_TEMPLATE,
CONF_AVAILABILITY_TOPIC,
CONF_COMMAND_TOPIC,
CONF_PAYLOAD_AVAILABLE,
CONF_PAYLOAD_NOT_AVAILABLE,
CONF_STATE_TOPIC,
]
MQTT_ORIGIN_INFO_SCHEMA = vol.All(
vol.Schema(
{
vol.Required(CONF_NAME): cv.string,
vol.Optional(CONF_SW_VERSION): cv.string,
vol.Optional(CONF_SUPPORT_URL): cv.configuration_url,
}
),
) )
from .util import valid_subscribe_topic
MQTT_AVAILABILITY_SINGLE_SCHEMA = vol.Schema( MQTT_AVAILABILITY_SINGLE_SCHEMA = vol.Schema(
{ {
@ -148,3 +181,19 @@ MQTT_ENTITY_COMMON_SCHEMA = MQTT_AVAILABILITY_SCHEMA.extend(
vol.Optional(CONF_UNIQUE_ID): cv.string, vol.Optional(CONF_UNIQUE_ID): cv.string,
} }
) )
COMPONENT_CONFIG_SCHEMA = vol.Schema(
{vol.Required(CONF_PLATFORM): vol.In(SUPPORTED_COMPONENTS)}
).extend({}, extra=True)
DEVICE_DISCOVERY_SCHEMA = MQTT_AVAILABILITY_SCHEMA.extend(
{
vol.Required(CONF_DEVICE): MQTT_ENTITY_DEVICE_INFO_SCHEMA,
vol.Required(CONF_COMPONENTS): vol.Schema({str: COMPONENT_CONFIG_SCHEMA}),
vol.Required(CONF_ORIGIN): MQTT_ORIGIN_INFO_SCHEMA,
vol.Optional(CONF_STATE_TOPIC): valid_subscribe_topic,
vol.Optional(CONF_COMMAND_TOPIC): valid_publish_topic,
vol.Optional(CONF_QOS): valid_qos_schema,
vol.Optional(CONF_ENCODING): cv.string,
}
)

View File

@ -2,7 +2,7 @@
from collections.abc import Generator from collections.abc import Generator
from random import getrandbits from random import getrandbits
from unittest.mock import patch from unittest.mock import AsyncMock, patch
import pytest import pytest
@ -29,3 +29,10 @@ def mock_temp_dir(temp_dir_prefix: str) -> Generator[None, None, str]:
f"home-assistant-mqtt-{temp_dir_prefix}-{getrandbits(10):03x}", f"home-assistant-mqtt-{temp_dir_prefix}-{getrandbits(10):03x}",
) as mocked_temp_dir: ) as mocked_temp_dir:
yield mocked_temp_dir yield mocked_temp_dir
@pytest.fixture
def tag_mock() -> Generator[AsyncMock, None, None]:
"""Fixture to mock tag."""
with patch("homeassistant.components.tag.async_scan_tag") as mock_tag:
yield mock_tag

View File

@ -35,22 +35,42 @@ def calls(hass: HomeAssistant) -> list[ServiceCall]:
return async_mock_service(hass, "test", "automation") return async_mock_service(hass, "test", "automation")
async def test_get_triggers( @pytest.mark.parametrize(
hass: HomeAssistant, ("discovery_topic", "data"),
device_registry: dr.DeviceRegistry, [
mqtt_mock_entry: MqttMockHAClientGenerator, (
) -> None: "homeassistant/device_automation/0AFFD2/bla/config",
"""Test we get the expected triggers from a discovered mqtt device."""
await mqtt_mock_entry()
data1 = (
'{ "automation_type":"trigger",' '{ "automation_type":"trigger",'
' "device":{"identifiers":["0AFFD2"]},' ' "device":{"identifiers":["0AFFD2"]},'
' "payload": "short_press",' ' "payload": "short_press",'
' "topic": "foobar/triggers/button1",' ' "topic": "foobar/triggers/button1",'
' "type": "button_short_press",' ' "type": "button_short_press",'
' "subtype": "button_1" }' ' "subtype": "button_1" }',
) ),
async_fire_mqtt_message(hass, "homeassistant/device_automation/bla/config", data1) (
"homeassistant/device/0AFFD2/config",
'{ "device":{"identifiers":["0AFFD2"]},'
' "o": {"name": "foobar"}, "cmp": '
'{ "bla": {'
' "automation_type":"trigger", '
' "payload": "short_press",'
' "topic": "foobar/triggers/button1",'
' "type": "button_short_press",'
' "subtype": "button_1",'
' "platform":"device_automation"}}}',
),
],
)
async def test_get_triggers(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
mqtt_mock_entry: MqttMockHAClientGenerator,
discovery_topic: str,
data: str,
) -> None:
"""Test we get the expected triggers from a discovered mqtt device."""
await mqtt_mock_entry()
async_fire_mqtt_message(hass, discovery_topic, data)
await hass.async_block_till_done() await hass.async_block_till_done()
device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")}) device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")})

View File

@ -5,12 +5,14 @@ import copy
import json import json
from pathlib import Path from pathlib import Path
import re import re
from unittest.mock import AsyncMock, call, patch from typing import Any
from unittest.mock import ANY, AsyncMock, MagicMock, call, patch
import pytest import pytest
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components import mqtt from homeassistant.components import mqtt
from homeassistant.components.device_automation import DeviceAutomationType
from homeassistant.components.mqtt.abbreviations import ( from homeassistant.components.mqtt.abbreviations import (
ABBREVIATIONS, ABBREVIATIONS,
DEVICE_ABBREVIATIONS, DEVICE_ABBREVIATIONS,
@ -41,11 +43,13 @@ from homeassistant.setup import async_setup_component
from homeassistant.util.signal_type import SignalTypeFormat from homeassistant.util.signal_type import SignalTypeFormat
from .test_common import help_all_subscribe_calls, help_test_unload_config_entry from .test_common import help_all_subscribe_calls, help_test_unload_config_entry
from .test_tag import DEFAULT_TAG_ID, DEFAULT_TAG_SCAN
from tests.common import ( from tests.common import (
MockConfigEntry, MockConfigEntry,
async_capture_events, async_capture_events,
async_fire_mqtt_message, async_fire_mqtt_message,
async_get_device_automations,
mock_config_flow, mock_config_flow,
mock_platform, mock_platform,
) )
@ -85,6 +89,8 @@ async def test_subscribing_config_topic(
[ [
("homeassistant/binary_sensor/bla/not_config", False), ("homeassistant/binary_sensor/bla/not_config", False),
("homeassistant/binary_sensor/rörkrökare/config", True), ("homeassistant/binary_sensor/rörkrökare/config", True),
("homeassistant/device/bla/not_config", False),
("homeassistant/device/rörkrökare/config", True),
], ],
) )
async def test_invalid_topic( async def test_invalid_topic(
@ -113,10 +119,15 @@ async def test_invalid_topic(
caplog.clear() caplog.clear()
@pytest.mark.parametrize(
"discovery_topic",
["homeassistant/binary_sensor/bla/config", "homeassistant/device/bla/config"],
)
async def test_invalid_json( async def test_invalid_json(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator, mqtt_mock_entry: MqttMockHAClientGenerator,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
discovery_topic: str,
) -> None: ) -> None:
"""Test sending in invalid JSON.""" """Test sending in invalid JSON."""
await mqtt_mock_entry() await mqtt_mock_entry()
@ -125,9 +136,7 @@ async def test_invalid_json(
) as mock_dispatcher_send: ) as mock_dispatcher_send:
mock_dispatcher_send = AsyncMock(return_value=None) mock_dispatcher_send = AsyncMock(return_value=None)
async_fire_mqtt_message( async_fire_mqtt_message(hass, discovery_topic, "not json")
hass, "homeassistant/binary_sensor/bla/config", "not json"
)
await hass.async_block_till_done() await hass.async_block_till_done()
assert "Unable to parse JSON" in caplog.text assert "Unable to parse JSON" in caplog.text
assert not mock_dispatcher_send.called assert not mock_dispatcher_send.called
@ -176,6 +185,43 @@ async def test_invalid_config(
assert "Error 'expected int for dictionary value @ data['qos']'" in caplog.text assert "Error 'expected int for dictionary value @ data['qos']'" in caplog.text
async def test_invalid_device_discovery_config(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test sending in JSON that violates the discovery schema if device or platform key is missing."""
await mqtt_mock_entry()
async_fire_mqtt_message(
hass,
"homeassistant/device/bla/config",
'{ "o": {"name": "foobar"}, "cmp": '
'{ "acp1": {"name": "abc", "state_topic": "home/alarm", '
'"command_topic": "home/alarm/set", '
'"platform":"alarm_control_panel"}}}',
)
await hass.async_block_till_done()
assert (
"Invalid MQTT device discovery payload for bla, "
"required key not provided @ data['device']" in caplog.text
)
caplog.clear()
async_fire_mqtt_message(
hass,
"homeassistant/device/bla/config",
'{ "o": {"name": "foobar"}, "dev": {"identifiers": ["ABDE03"]}, '
'"cmp": { "acp1": {"name": "abc", "state_topic": "home/alarm", '
'"command_topic": "home/alarm/set" }}}',
)
await hass.async_block_till_done()
assert (
"Invalid MQTT device discovery payload for bla, "
"required key not provided @ data['components']['acp1']['platform']"
in caplog.text
)
async def test_only_valid_components( async def test_only_valid_components(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator, mqtt_mock_entry: MqttMockHAClientGenerator,
@ -221,17 +267,51 @@ async def test_correct_config_discovery(
assert ("binary_sensor", "bla") in hass.data["mqtt"].discovery_already_discovered assert ("binary_sensor", "bla") in hass.data["mqtt"].discovery_already_discovered
@pytest.mark.parametrize(
("discovery_topic", "payloads", "discovery_id"),
[
(
"homeassistant/binary_sensor/bla/config",
(
'{"name":"Beer","state_topic": "test-topic",'
'"o":{"name":"bla2mqtt","sw":"1.0"},"dev":{"identifiers":["bla"]}}',
'{"name":"Milk","state_topic": "test-topic",'
'"o":{"name":"bla2mqtt","sw":"1.1",'
'"url":"https://bla2mqtt.example.com/support"},'
'"dev":{"identifiers":["bla"]}}',
),
"bla",
),
(
"homeassistant/device/bla/config",
(
'{"cmp":{"bin_sens1":{"platform":"binary_sensor",'
'"name":"Beer","state_topic": "test-topic"}},'
'"o":{"name":"bla2mqtt","sw":"1.0"},"dev":{"identifiers":["bla"]}}',
'{"cmp":{"bin_sens1":{"platform":"binary_sensor",'
'"name":"Milk","state_topic": "test-topic"}},'
'"o":{"name":"bla2mqtt","sw":"1.1",'
'"url":"https://bla2mqtt.example.com/support"},'
'"dev":{"identifiers":["bla"]}}',
),
"bla bin_sens1",
),
],
)
async def test_discovery_integration_info( async def test_discovery_integration_info(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator, mqtt_mock_entry: MqttMockHAClientGenerator,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
discovery_topic: str,
payloads: tuple[str, str],
discovery_id: str,
) -> None: ) -> None:
"""Test logging discovery of new and updated items.""" """Test discovery of integration info."""
await mqtt_mock_entry() await mqtt_mock_entry()
async_fire_mqtt_message( async_fire_mqtt_message(
hass, hass,
"homeassistant/binary_sensor/bla/config", discovery_topic,
'{ "name": "Beer", "state_topic": "test-topic", "o": {"name": "bla2mqtt", "sw": "1.0" } }', payloads[0],
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -241,7 +321,10 @@ async def test_discovery_integration_info(
assert state.name == "Beer" assert state.name == "Beer"
assert ( assert (
"Found new component: binary_sensor bla from external application bla2mqtt, version: 1.0" "Processing device discovery for 'bla' from external "
"application bla2mqtt, version: 1.0"
in caplog.text
or f"Found new component: binary_sensor {discovery_id} from external application bla2mqtt, version: 1.0"
in caplog.text in caplog.text
) )
caplog.clear() caplog.clear()
@ -249,8 +332,8 @@ async def test_discovery_integration_info(
# Send an update and add support url # Send an update and add support url
async_fire_mqtt_message( async_fire_mqtt_message(
hass, hass,
"homeassistant/binary_sensor/bla/config", discovery_topic,
'{ "name": "Milk", "state_topic": "test-topic", "o": {"name": "bla2mqtt", "sw": "1.1", "url": "https://bla2mqtt.example.com/support" } }', payloads[1],
) )
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer") state = hass.states.get("binary_sensor.beer")
@ -259,31 +342,343 @@ async def test_discovery_integration_info(
assert state.name == "Milk" assert state.name == "Milk"
assert ( assert (
"Component has already been discovered: binary_sensor bla, sending update from external application bla2mqtt, version: 1.1, support URL: https://bla2mqtt.example.com/support" f"Component has already been discovered: binary_sensor {discovery_id}"
in caplog.text in caplog.text
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"config_message", ("single_configs", "device_discovery_topic", "device_config"),
[ [
(
[
(
"homeassistant/device_automation/0AFFD2/bla1/config",
{
"device": {"identifiers": ["0AFFD2"]},
"automation_type": "trigger",
"payload": "short_press",
"topic": "foobar/triggers/button1",
"type": "button_short_press",
"subtype": "button_1",
},
),
(
"homeassistant/sensor/0AFFD2/bla2/config",
{
"device": {"identifiers": ["0AFFD2"]},
"state_topic": "foobar/sensors/bla2/state",
},
),
(
"homeassistant/tag/0AFFD2/bla3/config",
{
"device": {"identifiers": ["0AFFD2"]},
"topic": "foobar/tags/bla3/see",
},
),
],
"homeassistant/device/0AFFD2/config",
{
"device": {"identifiers": ["0AFFD2"]},
"o": {"name": "foobar"},
"cmp": {
"bla1": {
"platform": "device_automation",
"automation_type": "trigger",
"payload": "short_press",
"topic": "foobar/triggers/button1",
"type": "button_short_press",
"subtype": "button_1",
},
"bla2": {
"platform": "sensor",
"state_topic": "foobar/sensors/bla2/state",
},
"bla3": {
"platform": "tag",
"topic": "foobar/tags/bla3/see",
},
},
},
)
],
)
async def test_discovery_migration(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
mqtt_mock_entry: MqttMockHAClientGenerator,
tag_mock: AsyncMock,
single_configs: list[tuple[str, dict[str, Any]]],
device_discovery_topic: str,
device_config: dict[str, Any],
) -> None:
"""Test the migration of single discovery to device discovery."""
mock_mqtt = await mqtt_mock_entry()
publish_mock: MagicMock = mock_mqtt._mqttc.publish
# Discovery single config schema
for discovery_topic, config in single_configs:
payload = json.dumps(config)
async_fire_mqtt_message(
hass,
discovery_topic,
payload,
)
await hass.async_block_till_done()
await hass.async_block_till_done()
async def check_discovered_items():
# Check the device_trigger was discovered
device_entry = device_registry.async_get_device(
identifiers={("mqtt", "0AFFD2")}
)
assert device_entry is not None
triggers = await async_get_device_automations(
hass, DeviceAutomationType.TRIGGER, device_entry.id
)
assert len(triggers) == 1
# Check the sensor was discovered
state = hass.states.get("sensor.mqtt_sensor")
assert state is not None
# Check the tag works
async_fire_mqtt_message(hass, "foobar/tags/bla3/see", DEFAULT_TAG_SCAN)
await hass.async_block_till_done()
tag_mock.assert_called_once_with(ANY, DEFAULT_TAG_ID, device_entry.id)
tag_mock.reset_mock()
await check_discovered_items()
# Migrate to device based discovery
payload = json.dumps(device_config)
async_fire_mqtt_message(
hass,
device_discovery_topic,
payload,
)
await hass.async_block_till_done()
# Test the single discovery topics are reset and `None` is published
await check_discovered_items()
assert len(publish_mock.mock_calls) == len(single_configs)
published_topics = {call[1][0] for call in publish_mock.mock_calls}
expected_topics = {item[0] for item in single_configs}
assert published_topics == expected_topics
published_payloads = [call[1][1] for call in publish_mock.mock_calls]
assert published_payloads == [None, None, None]
@pytest.mark.parametrize(
("discovery_topic", "payload", "discovery_id"),
[
(
"homeassistant/binary_sensor/bla/config",
'{"name":"Beer","state_topic": "test-topic",'
'"avty": {"topic": "avty-topic"},'
'"o":{"name":"bla2mqtt","sw":"1.0"},"dev":{"identifiers":["bla"]}}',
"bla",
),
(
"homeassistant/device/bla/config",
'{"cmp":{"bin_sens1":{"platform":"binary_sensor",'
'"name":"Beer","state_topic": "test-topic"}},'
'"avty": {"topic": "avty-topic"},'
'"o":{"name":"bla2mqtt","sw":"1.0"},"dev":{"identifiers":["bla"]}}',
"bin_sens1 bla",
),
],
)
async def test_discovery_availability(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
caplog: pytest.LogCaptureFixture,
discovery_topic: str,
payload: str,
discovery_id: str,
) -> None:
"""Test device discovery with shared availability mapping."""
await mqtt_mock_entry()
async_fire_mqtt_message(
hass,
discovery_topic,
payload,
)
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is not None
assert state.name == "Beer"
assert state.state == STATE_UNAVAILABLE
async_fire_mqtt_message(
hass,
"avty-topic",
"online",
)
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is not None
assert state.state == STATE_UNKNOWN
async_fire_mqtt_message(
hass,
"test-topic",
"ON",
)
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is not None
assert state.state == STATE_ON
@pytest.mark.parametrize(
("discovery_topic", "payload", "discovery_id"),
[
(
"homeassistant/device/bla/config",
'{"cmp":{"bin_sens1":{"platform":"binary_sensor",'
'"avty": {"topic": "avty-topic-component"},'
'"name":"Beer","state_topic": "test-topic"}},'
'"avty": {"topic": "avty-topic-device"},'
'"o":{"name":"bla2mqtt","sw":"1.0"},"dev":{"identifiers":["bla"]}}',
"bin_sens1 bla",
),
(
"homeassistant/device/bla/config",
'{"cmp":{"bin_sens1":{"platform":"binary_sensor",'
'"availability_topic": "avty-topic-component",'
'"name":"Beer","state_topic": "test-topic"}},'
'"availability_topic": "avty-topic-device",'
'"o":{"name":"bla2mqtt","sw":"1.0"},"dev":{"identifiers":["bla"]}}',
"bin_sens1 bla",
),
],
)
async def test_discovery_component_availability_overridden(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
caplog: pytest.LogCaptureFixture,
discovery_topic: str,
payload: str,
discovery_id: str,
) -> None:
"""Test device discovery with overridden shared availability mapping."""
await mqtt_mock_entry()
async_fire_mqtt_message(
hass,
discovery_topic,
payload,
)
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is not None
assert state.name == "Beer"
assert state.state == STATE_UNAVAILABLE
async_fire_mqtt_message(
hass,
"avty-topic-device",
"online",
)
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is not None
assert state.state == STATE_UNAVAILABLE
async_fire_mqtt_message(
hass,
"avty-topic-component",
"online",
)
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is not None
assert state.state == STATE_UNKNOWN
async_fire_mqtt_message(
hass,
"test-topic",
"ON",
)
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is not None
assert state.state == STATE_ON
@pytest.mark.parametrize(
("discovery_topic", "config_message", "error_message"),
[
(
"homeassistant/binary_sensor/bla/config",
'{ "name": "Beer", "state_topic": "test-topic", "o": "bla2mqtt" }', '{ "name": "Beer", "state_topic": "test-topic", "o": "bla2mqtt" }',
"Unable to parse origin information from discovery message",
),
(
"homeassistant/binary_sensor/bla/config",
'{ "name": "Beer", "state_topic": "test-topic", "o": 2.0 }', '{ "name": "Beer", "state_topic": "test-topic", "o": 2.0 }',
"Unable to parse origin information from discovery message",
),
(
"homeassistant/binary_sensor/bla/config",
'{ "name": "Beer", "state_topic": "test-topic", "o": null }', '{ "name": "Beer", "state_topic": "test-topic", "o": null }',
"Unable to parse origin information from discovery message",
),
(
"homeassistant/binary_sensor/bla/config",
'{ "name": "Beer", "state_topic": "test-topic", "o": {"sw": "bla2mqtt"} }', '{ "name": "Beer", "state_topic": "test-topic", "o": {"sw": "bla2mqtt"} }',
"Unable to parse origin information from discovery message",
),
(
"homeassistant/device/bla/config",
'{"dev":{"identifiers":["bs1"]},"cmp":{"bs1":'
'{"platform":"binary_sensor","name":"Beer","state_topic":"test-topic"}'
'},"o": "bla2mqtt"'
"}",
"Invalid MQTT device discovery payload for bla, "
"expected a dictionary for dictionary value @ data['origin']",
),
(
"homeassistant/device/bla/config",
'{"dev":{"identifiers":["bs1"]},"cmp":{"bs1":'
'{"platform":"binary_sensor","name":"Beer","state_topic":"test-topic"}'
'},"o": 2.0'
"}",
"Invalid MQTT device discovery payload for bla, "
"expected a dictionary for dictionary value @ data['origin']",
),
(
"homeassistant/device/bla/config",
'{"dev":{"identifiers":["bs1"]},"cmp":{"bs1":'
'{"platform":"binary_sensor","name":"Beer","state_topic":"test-topic"}'
'},"o": null'
"}",
"Invalid MQTT device discovery payload for bla, "
"expected a dictionary for dictionary value @ data['origin']",
),
(
"homeassistant/device/bla/config",
'{"dev":{"identifiers":["bs1"]},"cmp":{"bs1":'
'{"platform":"binary_sensor","name":"Beer","state_topic":"test-topic"}'
'},"o": {"sw": "bla2mqtt"}'
"}",
"Invalid MQTT device discovery payload for bla, "
"required key not provided @ data['origin']['name']",
),
], ],
) )
async def test_discovery_with_invalid_integration_info( async def test_discovery_with_invalid_integration_info(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator, mqtt_mock_entry: MqttMockHAClientGenerator,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
discovery_topic: str,
config_message: str, config_message: str,
error_message: str,
) -> None: ) -> None:
"""Test sending in correct JSON.""" """Test sending in correct JSON."""
await mqtt_mock_entry() await mqtt_mock_entry()
async_fire_mqtt_message( async_fire_mqtt_message(
hass, hass,
"homeassistant/binary_sensor/bla/config", discovery_topic,
config_message, config_message,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -291,9 +686,7 @@ async def test_discovery_with_invalid_integration_info(
state = hass.states.get("binary_sensor.beer") state = hass.states.get("binary_sensor.beer")
assert state is None assert state is None
assert ( assert error_message in caplog.text
"Unable to parse origin information from discovery message, got" in caplog.text
)
async def test_discover_fan( async def test_discover_fan(
@ -822,34 +1215,62 @@ async def test_duplicate_removal(
assert "Component has already been discovered: binary_sensor bla" not in caplog.text assert "Component has already been discovered: binary_sensor bla" not in caplog.text
@pytest.mark.parametrize(
("discovery_topic", "discovery_payload", "entity_ids"),
[
(
"homeassistant/sensor/bla/config",
'{ "device":{"identifiers":["0AFFD2"]},'
' "state_topic": "foobar/sensor",'
' "unique_id": "unique" }',
["sensor.none_mqtt_sensor"],
),
(
"homeassistant/device/bla/config",
'{ "device":{"identifiers":["0AFFD2"]},'
' "o": {"name": "foobar"},'
' "cmp": {"sens1": {'
' "platform": "sensor",'
' "name": "sensor1",'
' "state_topic": "foobar/sensor1",'
' "unique_id": "unique1"'
' },"sens2": {'
' "platform": "sensor",'
' "name": "sensor2",'
' "state_topic": "foobar/sensor2",'
' "unique_id": "unique2"'
"}}}",
["sensor.none_sensor1", "sensor.none_sensor2"],
),
],
)
async def test_cleanup_device( async def test_cleanup_device(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
mqtt_mock_entry: MqttMockHAClientGenerator, mqtt_mock_entry: MqttMockHAClientGenerator,
discovery_topic: str,
discovery_payload: str,
entity_ids: list[str],
) -> None: ) -> None:
"""Test discovered device is cleaned up when entry removed from device.""" """Test discovered device is cleaned up when entry removed from device."""
mqtt_mock = await mqtt_mock_entry() mqtt_mock = await mqtt_mock_entry()
assert await async_setup_component(hass, "config", {}) assert await async_setup_component(hass, "config", {})
ws_client = await hass_ws_client(hass) ws_client = await hass_ws_client(hass)
data = ( async_fire_mqtt_message(hass, discovery_topic, discovery_payload)
'{ "device":{"identifiers":["0AFFD2"]},'
' "state_topic": "foobar/sensor",'
' "unique_id": "unique" }'
)
async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data)
await hass.async_block_till_done() await hass.async_block_till_done()
# Verify device and registry entries are created # Verify device and registry entries are created
device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")}) device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")})
assert device_entry is not None assert device_entry is not None
entity_entry = entity_registry.async_get("sensor.none_mqtt_sensor")
for entity_id in entity_ids:
entity_entry = entity_registry.async_get(entity_id)
assert entity_entry is not None assert entity_entry is not None
state = hass.states.get("sensor.none_mqtt_sensor") state = hass.states.get(entity_id)
assert state is not None assert state is not None
# Remove MQTT from the device # Remove MQTT from the device
@ -868,60 +1289,221 @@ async def test_cleanup_device(
assert entity_entry is None assert entity_entry is None
# Verify state is removed # Verify state is removed
state = hass.states.get("sensor.none_mqtt_sensor") for entity_id in entity_ids:
state = hass.states.get(entity_id)
assert state is None assert state is None
await hass.async_block_till_done() await hass.async_block_till_done()
# Verify retained discovery topic has been cleared # Verify retained discovery topic has been cleared
mqtt_mock.async_publish.assert_called_once_with( mqtt_mock.async_publish.assert_called_with(discovery_topic, None, 0, True)
"homeassistant/sensor/bla/config", None, 0, True
)
@pytest.mark.parametrize(
("discovery_topic", "discovery_payload", "entity_ids"),
[
(
"homeassistant/sensor/bla/config",
'{ "device":{"identifiers":["0AFFD2"]},'
' "state_topic": "foobar/sensor",'
' "unique_id": "unique" }',
["sensor.none_mqtt_sensor"],
),
(
"homeassistant/device/bla/config",
'{ "device":{"identifiers":["0AFFD2"]},'
' "o": {"name": "foobar"},'
' "cmp": {"sens1": {'
' "platform": "sensor",'
' "name": "sensor1",'
' "state_topic": "foobar/sensor1",'
' "unique_id": "unique1"'
' },"sens2": {'
' "platform": "sensor",'
' "name": "sensor2",'
' "state_topic": "foobar/sensor2",'
' "unique_id": "unique2"'
"}}}",
["sensor.none_sensor1", "sensor.none_sensor2"],
),
],
)
async def test_cleanup_device_mqtt( async def test_cleanup_device_mqtt(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
mqtt_mock_entry: MqttMockHAClientGenerator, mqtt_mock_entry: MqttMockHAClientGenerator,
discovery_topic: str,
discovery_payload: str,
entity_ids: list[str],
) -> None: ) -> None:
"""Test discvered device is cleaned up when removed through MQTT.""" """Test discovered device is cleaned up when removed through MQTT."""
mqtt_mock = await mqtt_mock_entry() mqtt_mock = await mqtt_mock_entry()
data = (
'{ "device":{"identifiers":["0AFFD2"]},'
' "state_topic": "foobar/sensor",'
' "unique_id": "unique" }'
)
async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data) # set up an existing sensor first
data = (
'{ "device":{"identifiers":["0AFFD3"]},'
' "name": "sensor_base",'
' "state_topic": "foobar/sensor",'
' "unique_id": "unique_base" }'
)
base_discovery_topic = "homeassistant/sensor/bla_base/config"
base_entity_id = "sensor.none_sensor_base"
async_fire_mqtt_message(hass, base_discovery_topic, data)
await hass.async_block_till_done()
# Verify the base entity has been created and it has a state
base_device_entry = device_registry.async_get_device(
identifiers={("mqtt", "0AFFD3")}
)
assert base_device_entry is not None
entity_entry = entity_registry.async_get(base_entity_id)
assert entity_entry is not None
state = hass.states.get(base_entity_id)
assert state is not None
async_fire_mqtt_message(hass, discovery_topic, discovery_payload)
await hass.async_block_till_done() await hass.async_block_till_done()
# Verify device and registry entries are created # Verify device and registry entries are created
device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")}) device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")})
assert device_entry is not None assert device_entry is not None
entity_entry = entity_registry.async_get("sensor.none_mqtt_sensor") for entity_id in entity_ids:
entity_entry = entity_registry.async_get(entity_id)
assert entity_entry is not None assert entity_entry is not None
state = hass.states.get("sensor.none_mqtt_sensor") state = hass.states.get(entity_id)
assert state is not None assert state is not None
async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", "") async_fire_mqtt_message(hass, discovery_topic, "")
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
# Verify device and registry entries are cleared # Verify device and registry entries are cleared
device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")}) device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")})
assert device_entry is None assert device_entry is None
entity_entry = entity_registry.async_get("sensor.none_mqtt_sensor")
for entity_id in entity_ids:
entity_entry = entity_registry.async_get(entity_id)
assert entity_entry is None assert entity_entry is None
# Verify state is removed # Verify state is removed
state = hass.states.get("sensor.none_mqtt_sensor") state = hass.states.get(entity_id)
assert state is None assert state is None
await hass.async_block_till_done() await hass.async_block_till_done()
# Verify retained discovery topics have not been cleared again # Verify retained discovery topics have not been cleared again
mqtt_mock.async_publish.assert_not_called() mqtt_mock.async_publish.assert_not_called()
# Verify the base entity still exists and it has a state
base_device_entry = device_registry.async_get_device(
identifiers={("mqtt", "0AFFD3")}
)
assert base_device_entry is not None
entity_entry = entity_registry.async_get(base_entity_id)
assert entity_entry is not None
state = hass.states.get(base_entity_id)
assert state is not None
async def test_cleanup_device_mqtt_device_discovery(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
mqtt_mock_entry: MqttMockHAClientGenerator,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test discovered device is cleaned up partly when removed through MQTT."""
await mqtt_mock_entry()
discovery_topic = "homeassistant/device/bla/config"
discovery_payload = (
'{ "device":{"identifiers":["0AFFD2"]},'
' "o": {"name": "foobar"},'
' "cmp": {"sens1": {'
' "platform": "sensor",'
' "name": "sensor1",'
' "state_topic": "foobar/sensor1",'
' "unique_id": "unique1"'
' },"sens2": {'
' "platform": "sensor",'
' "name": "sensor2",'
' "state_topic": "foobar/sensor2",'
' "unique_id": "unique2"'
"}}}"
)
entity_ids = ["sensor.none_sensor1", "sensor.none_sensor2"]
async_fire_mqtt_message(hass, discovery_topic, discovery_payload)
await hass.async_block_till_done()
# Verify device and registry entries are created
device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")})
assert device_entry is not None
for entity_id in entity_ids:
entity_entry = entity_registry.async_get(entity_id)
assert entity_entry is not None
state = hass.states.get(entity_id)
assert state is not None
# Do update and remove sensor 2 from device
discovery_payload_update1 = (
'{ "device":{"identifiers":["0AFFD2"]},'
' "o": {"name": "foobar"},'
' "cmp": {"sens1": {'
' "platform": "sensor",'
' "name": "sensor1",'
' "state_topic": "foobar/sensor1",'
' "unique_id": "unique1"'
' },"sens2": {'
' "platform": "sensor"'
"}}}"
)
async_fire_mqtt_message(hass, discovery_topic, discovery_payload_update1)
await hass.async_block_till_done()
state = hass.states.get(entity_ids[0])
assert state is not None
state = hass.states.get(entity_ids[1])
assert state is None
# Repeating the update
async_fire_mqtt_message(hass, discovery_topic, discovery_payload_update1)
await hass.async_block_till_done()
state = hass.states.get(entity_ids[0])
assert state is not None
state = hass.states.get(entity_ids[1])
assert state is None
# Removing last sensor
discovery_payload_update2 = (
'{ "device":{"identifiers":["0AFFD2"]},'
' "o": {"name": "foobar"},'
' "cmp": {"sens1": {'
' "platform": "sensor"'
' },"sens2": {'
' "platform": "sensor"'
"}}}"
)
async_fire_mqtt_message(hass, discovery_topic, discovery_payload_update2)
await hass.async_block_till_done()
device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")})
# Verify the device entry was removed with the last sensor
assert device_entry is None
for entity_id in entity_ids:
entity_entry = entity_registry.async_get(entity_id)
assert entity_entry is None
state = hass.states.get(entity_id)
assert state is None
# Repeating the update
async_fire_mqtt_message(hass, discovery_topic, discovery_payload_update2)
await hass.async_block_till_done()
# Clear the empty discovery payload and verify there was nothing to cleanup
async_fire_mqtt_message(hass, discovery_topic, "")
await hass.async_block_till_done()
assert "No device components to cleanup" in caplog.text
async def test_cleanup_device_multiple_config_entries( async def test_cleanup_device_multiple_config_entries(
hass: HomeAssistant, hass: HomeAssistant,
@ -1806,3 +2388,77 @@ async def test_discovery_dispatcher_signal_type_messages(
assert len(calls) == 1 assert len(calls) == 1
assert calls[0] == test_data assert calls[0] == test_data
unsub() unsub()
@pytest.mark.parametrize(
("discovery_topic", "discovery_payload", "entity_ids"),
[
(
"homeassistant/device/bla/config",
'{ "device":{"identifiers":["0AFFD2"]},'
' "o": {"name": "foobar"},'
' "state_topic": "foobar/sensor-shared",'
' "cmp": {"sens1": {'
' "platform": "sensor",'
' "name": "sensor1",'
' "unique_id": "unique1"'
' },"sens2": {'
' "platform": "sensor",'
' "name": "sensor2",'
' "unique_id": "unique2"'
' },"sens3": {'
' "platform": "sensor",'
' "name": "sensor3",'
' "state_topic": "foobar/sensor3",'
' "unique_id": "unique3"'
"}}}",
["sensor.none_sensor1", "sensor.none_sensor2", "sensor.none_sensor3"],
),
],
)
async def test_shared_state_topic(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
mqtt_mock_entry: MqttMockHAClientGenerator,
discovery_topic: str,
discovery_payload: str,
entity_ids: list[str],
) -> None:
"""Test a shared state_topic can be used."""
await mqtt_mock_entry()
async_fire_mqtt_message(hass, discovery_topic, discovery_payload)
await hass.async_block_till_done()
# Verify device and registry entries are created
device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")})
assert device_entry is not None
for entity_id in entity_ids:
entity_entry = entity_registry.async_get(entity_id)
assert entity_entry is not None
state = hass.states.get(entity_id)
assert state is not None
assert state.state == STATE_UNKNOWN
async_fire_mqtt_message(hass, "foobar/sensor-shared", "New state")
entity_id = entity_ids[0]
state = hass.states.get(entity_id)
assert state is not None
assert state.state == "New state"
entity_id = entity_ids[1]
state = hass.states.get(entity_id)
assert state is not None
assert state.state == "New state"
entity_id = entity_ids[2]
state = hass.states.get(entity_id)
assert state is not None
assert state.state == STATE_UNKNOWN
async_fire_mqtt_message(hass, "foobar/sensor3", "New state3")
entity_id = entity_ids[2]
state = hass.states.get(entity_id)
assert state is not None
assert state.state == "New state3"

View File

@ -3162,7 +3162,6 @@ async def test_mqtt_ws_get_device_debug_info(
} }
data_sensor = json.dumps(config_sensor) data_sensor = json.dumps(config_sensor)
data_trigger = json.dumps(config_trigger) data_trigger = json.dumps(config_trigger)
config_sensor["platform"] = config_trigger["platform"] = mqtt.DOMAIN
async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data_sensor) async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data_sensor)
async_fire_mqtt_message( async_fire_mqtt_message(
@ -3219,7 +3218,6 @@ async def test_mqtt_ws_get_device_debug_info_binary(
"unique_id": "unique", "unique_id": "unique",
} }
data = json.dumps(config) data = json.dumps(config)
config["platform"] = mqtt.DOMAIN
async_fire_mqtt_message(hass, "homeassistant/camera/bla/config", data) async_fire_mqtt_message(hass, "homeassistant/camera/bla/config", data)
await hass.async_block_till_done() await hass.async_block_till_done()

View File

@ -1,9 +1,8 @@
"""The tests for MQTT tag scanner.""" """The tests for MQTT tag scanner."""
from collections.abc import Generator
import copy import copy
import json import json
from unittest.mock import ANY, AsyncMock, patch from unittest.mock import ANY, AsyncMock
import pytest import pytest
@ -46,13 +45,6 @@ DEFAULT_TAG_SCAN_JSON = (
) )
@pytest.fixture
def tag_mock() -> Generator[AsyncMock, None, None]:
"""Fixture to mock tag."""
with patch("homeassistant.components.tag.async_scan_tag") as mock_tag:
yield mock_tag
@pytest.mark.no_fail_on_log_exception @pytest.mark.no_fail_on_log_exception
async def test_discover_bad_tag( async def test_discover_bad_tag(
hass: HomeAssistant, hass: HomeAssistant,