"""Provides device automations for MQTT.""" from __future__ import annotations from collections.abc import Callable import logging from typing import Any import attr import voluptuous as vol from homeassistant.components.automation import ( AutomationActionType, AutomationTriggerInfo, ) from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA from homeassistant.const import ( CONF_DEVICE, CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM, CONF_TYPE, CONF_VALUE_TEMPLATE, ) from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_validation as cv from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, async_dispatcher_send, ) from homeassistant.helpers.typing import ConfigType from . import debug_info, trigger as mqtt_trigger from .. import mqtt from .const import ( ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_TOPIC, CONF_PAYLOAD, CONF_QOS, CONF_TOPIC, DOMAIN, ) from .discovery import MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_UPDATED, clear_discovery_hash from .mixins import ( CONF_CONNECTIONS, CONF_IDENTIFIERS, MQTT_ENTITY_DEVICE_INFO_SCHEMA, cleanup_device_registry, device_info_from_config, ) _LOGGER = logging.getLogger(__name__) CONF_AUTOMATION_TYPE = "automation_type" CONF_DISCOVERY_ID = "discovery_id" CONF_SUBTYPE = "subtype" DEFAULT_ENCODING = "utf-8" DEVICE = "device" MQTT_TRIGGER_BASE = { # Trigger when MQTT message is received CONF_PLATFORM: DEVICE, CONF_DOMAIN: DOMAIN, } TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend( { vol.Required(CONF_PLATFORM): DEVICE, vol.Required(CONF_DOMAIN): DOMAIN, vol.Required(CONF_DEVICE_ID): str, vol.Required(CONF_DISCOVERY_ID): str, vol.Required(CONF_TYPE): cv.string, vol.Required(CONF_SUBTYPE): cv.string, } ) TRIGGER_DISCOVERY_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend( { vol.Required(CONF_AUTOMATION_TYPE): str, vol.Required(CONF_DEVICE): MQTT_ENTITY_DEVICE_INFO_SCHEMA, vol.Optional(CONF_PAYLOAD, default=None): vol.Any(None, cv.string), vol.Required(CONF_SUBTYPE): cv.string, vol.Required(CONF_TOPIC): cv.string, vol.Required(CONF_TYPE): cv.string, vol.Optional(CONF_VALUE_TEMPLATE, default=None): vol.Any(None, cv.string), }, extra=vol.REMOVE_EXTRA, ) DEVICE_TRIGGERS = "mqtt_device_triggers" @attr.s(slots=True) class TriggerInstance: """Attached trigger settings.""" action: AutomationActionType = attr.ib() automation_info: AutomationTriggerInfo = attr.ib() trigger: Trigger = attr.ib() remove: CALLBACK_TYPE | None = attr.ib(default=None) async def async_attach_trigger(self): """Attach MQTT trigger.""" mqtt_config = { mqtt_trigger.CONF_PLATFORM: mqtt.DOMAIN, mqtt_trigger.CONF_TOPIC: self.trigger.topic, mqtt_trigger.CONF_ENCODING: DEFAULT_ENCODING, mqtt_trigger.CONF_QOS: self.trigger.qos, } if self.trigger.payload: mqtt_config[CONF_PAYLOAD] = self.trigger.payload if self.trigger.value_template: mqtt_config[CONF_VALUE_TEMPLATE] = self.trigger.value_template mqtt_config = mqtt_trigger.TRIGGER_SCHEMA(mqtt_config) if self.remove: self.remove() self.remove = await mqtt_trigger.async_attach_trigger( self.trigger.hass, mqtt_config, self.action, self.automation_info, ) @attr.s(slots=True) class Trigger: """Device trigger settings.""" device_id: str = attr.ib() discovery_data: dict | None = attr.ib() hass: HomeAssistant = attr.ib() payload: str | None = attr.ib() qos: int | None = attr.ib() remove_signal: Callable[[], None] | None = attr.ib() subtype: str = attr.ib() topic: str | None = attr.ib() type: str = attr.ib() value_template: str | None = attr.ib() trigger_instances: list[TriggerInstance] = attr.ib(factory=list) async def add_trigger(self, action, automation_info): """Add MQTT trigger.""" instance = TriggerInstance(action, automation_info, self) self.trigger_instances.append(instance) if self.topic is not None: # If we know about the trigger, subscribe to MQTT topic await instance.async_attach_trigger() @callback def async_remove() -> None: """Remove trigger.""" if instance not in self.trigger_instances: raise HomeAssistantError("Can't remove trigger twice") if instance.remove: instance.remove() self.trigger_instances.remove(instance) return async_remove async def update_trigger(self, config, discovery_hash, remove_signal): """Update MQTT device trigger.""" self.remove_signal = remove_signal self.type = config[CONF_TYPE] self.subtype = config[CONF_SUBTYPE] self.payload = config[CONF_PAYLOAD] self.qos = config[CONF_QOS] topic_changed = self.topic != config[CONF_TOPIC] self.topic = config[CONF_TOPIC] self.value_template = config[CONF_VALUE_TEMPLATE] # Unsubscribe+subscribe if this trigger is in use and topic has changed # If topic is same unsubscribe+subscribe will execute in the wrong order # because unsubscribe is done with help of async_create_task if topic_changed: for trig in self.trigger_instances: await trig.async_attach_trigger() def detach_trigger(self): """Remove MQTT device trigger.""" # Mark trigger as unknown self.topic = None # Unsubscribe if this trigger is in use for trig in self.trigger_instances: if trig.remove: trig.remove() trig.remove = None async def _update_device(hass, config_entry, config): """Update device registry.""" device_registry = await hass.helpers.device_registry.async_get_registry() config_entry_id = config_entry.entry_id device_info = device_info_from_config(config[CONF_DEVICE]) if config_entry_id is not None and device_info is not None: device_info["config_entry_id"] = config_entry_id device_registry.async_get_or_create(**device_info) async def async_setup_trigger(hass, config, config_entry, discovery_data): """Set up the MQTT device trigger.""" config = TRIGGER_DISCOVERY_SCHEMA(config) discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] discovery_id = discovery_hash[1] remove_signal = None async def discovery_update(payload): """Handle discovery update.""" _LOGGER.info( "Got update for trigger with hash: %s '%s'", discovery_hash, payload ) if not payload: # Empty payload: Remove trigger _LOGGER.info("Removing trigger: %s", discovery_hash) debug_info.remove_trigger_discovery_data(hass, discovery_hash) if discovery_id in hass.data[DEVICE_TRIGGERS]: device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id] device_trigger.detach_trigger() clear_discovery_hash(hass, discovery_hash) remove_signal() await cleanup_device_registry(hass, device.id) else: # Non-empty payload: Update trigger _LOGGER.info("Updating trigger: %s", discovery_hash) debug_info.update_trigger_discovery_data(hass, discovery_hash, payload) config = TRIGGER_DISCOVERY_SCHEMA(payload) await _update_device(hass, config_entry, config) device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id] await device_trigger.update_trigger(config, discovery_hash, remove_signal) async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) remove_signal = async_dispatcher_connect( hass, MQTT_DISCOVERY_UPDATED.format(discovery_hash), discovery_update ) await _update_device(hass, config_entry, config) device_registry = await hass.helpers.device_registry.async_get_registry() device = device_registry.async_get_device( {(DOMAIN, id_) for id_ in config[CONF_DEVICE][CONF_IDENTIFIERS]}, {tuple(x) for x in config[CONF_DEVICE][CONF_CONNECTIONS]}, ) if device is None: async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) return if DEVICE_TRIGGERS not in hass.data: hass.data[DEVICE_TRIGGERS] = {} if discovery_id not in hass.data[DEVICE_TRIGGERS]: hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger( hass=hass, device_id=device.id, discovery_data=discovery_data, type=config[CONF_TYPE], subtype=config[CONF_SUBTYPE], topic=config[CONF_TOPIC], payload=config[CONF_PAYLOAD], qos=config[CONF_QOS], remove_signal=remove_signal, value_template=config[CONF_VALUE_TEMPLATE], ) else: await hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger( config, discovery_hash, remove_signal ) debug_info.add_trigger_discovery_data( hass, discovery_hash, discovery_data, device.id ) async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) async def async_device_removed(hass: HomeAssistant, device_id: str): """Handle the removal of a device.""" triggers = await async_get_triggers(hass, device_id) for trig in triggers: device_trigger = hass.data[DEVICE_TRIGGERS].pop(trig[CONF_DISCOVERY_ID]) if device_trigger: discovery_hash = device_trigger.discovery_data[ATTR_DISCOVERY_HASH] discovery_topic = device_trigger.discovery_data[ATTR_DISCOVERY_TOPIC] debug_info.remove_trigger_discovery_data(hass, discovery_hash) device_trigger.detach_trigger() clear_discovery_hash(hass, discovery_hash) device_trigger.remove_signal() mqtt.publish( hass, discovery_topic, "", retain=True, ) async def async_get_triggers( hass: HomeAssistant, device_id: str ) -> list[dict[str, Any]]: """List device triggers for MQTT devices.""" triggers: list[dict] = [] if DEVICE_TRIGGERS not in hass.data: return triggers for discovery_id, trig in hass.data[DEVICE_TRIGGERS].items(): if trig.device_id != device_id or trig.topic is None: continue trigger = { **MQTT_TRIGGER_BASE, "device_id": device_id, "type": trig.type, "subtype": trig.subtype, "discovery_id": discovery_id, } triggers.append(trigger) return triggers async def async_attach_trigger( hass: HomeAssistant, config: ConfigType, action: AutomationActionType, automation_info: AutomationTriggerInfo, ) -> CALLBACK_TYPE: """Attach a trigger.""" if DEVICE_TRIGGERS not in hass.data: hass.data[DEVICE_TRIGGERS] = {} device_id = config[CONF_DEVICE_ID] discovery_id = config[CONF_DISCOVERY_ID] if discovery_id not in hass.data[DEVICE_TRIGGERS]: hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger( hass=hass, device_id=device_id, discovery_data=None, remove_signal=None, type=config[CONF_TYPE], subtype=config[CONF_SUBTYPE], topic=None, payload=None, qos=None, value_template=None, ) return await hass.data[DEVICE_TRIGGERS][discovery_id].add_trigger( action, automation_info )