Make hass.data["mqtt"] an instance of a DataClass (#77972)

* Use dataclass to reference hass.data globals

* Add discovery_registry_hooks to dataclass

* Move discovery registry hooks to dataclass

* Add device triggers to dataclass

* Cleanup DEVICE_TRIGGERS const

* Add last_discovery to data_class

* Simplify typing for class `Subscription`

* Follow up on comment

* Redo suggested typing change to sasisfy mypy

* Restore typing

* Add mypy version to CI check logging

* revert changes to ci.yaml

* Add docstr for protocol

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>

* Mypy update after merging #78399

* Remove mypy ignore

* Correct return type

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
pull/78916/head
Jan Bouwhuis 2022-09-17 21:43:42 +02:00 committed by Paulus Schoutsen
parent 4894e2e5a4
commit 3dd0dbf38f
12 changed files with 174 additions and 137 deletions

View File

@ -20,13 +20,7 @@ from homeassistant.const import (
CONF_USERNAME,
SERVICE_RELOAD,
)
from homeassistant.core import (
CALLBACK_TYPE,
HassJob,
HomeAssistant,
ServiceCall,
callback,
)
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import TemplateError, Unauthorized
from homeassistant.helpers import (
config_validation as cv,
@ -71,15 +65,7 @@ from .const import ( # noqa: F401
CONF_TLS_VERSION,
CONF_TOPIC,
CONF_WILL_MESSAGE,
CONFIG_ENTRY_IS_SETUP,
DATA_MQTT,
DATA_MQTT_CONFIG,
DATA_MQTT_DISCOVERY_REGISTRY_HOOKS,
DATA_MQTT_RELOAD_DISPATCHERS,
DATA_MQTT_RELOAD_ENTRY,
DATA_MQTT_RELOAD_NEEDED,
DATA_MQTT_SUBSCRIPTIONS_TO_RESTORE,
DATA_MQTT_UPDATED_CONFIG,
DEFAULT_ENCODING,
DEFAULT_QOS,
DEFAULT_RETAIN,
@ -89,7 +75,7 @@ from .const import ( # noqa: F401
PLATFORMS,
RELOADABLE_PLATFORMS,
)
from .mixins import async_discover_yaml_entities
from .mixins import MqttData, async_discover_yaml_entities
from .models import ( # noqa: F401
MqttCommandTemplate,
MqttValueTemplate,
@ -177,6 +163,8 @@ async def _async_setup_discovery(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Start the MQTT protocol service."""
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
conf: ConfigType | None = config.get(DOMAIN)
websocket_api.async_register_command(hass, websocket_subscribe)
@ -185,7 +173,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
if conf:
conf = dict(conf)
hass.data[DATA_MQTT_CONFIG] = conf
mqtt_data.config = conf
if (mqtt_entry_status := mqtt_config_entry_enabled(hass)) is None:
# Create an import flow if the user has yaml configured entities etc.
@ -197,12 +185,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
data={},
)
hass.data[DATA_MQTT_RELOAD_NEEDED] = True
mqtt_data.reload_needed = True
elif mqtt_entry_status is False:
_LOGGER.info(
"MQTT will be not available until the config entry is enabled",
)
hass.data[DATA_MQTT_RELOAD_NEEDED] = True
mqtt_data.reload_needed = True
return True
@ -260,33 +248,34 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -
Causes for this is config entry options changing.
"""
mqtt_client = hass.data[DATA_MQTT]
mqtt_data: MqttData = hass.data[DATA_MQTT]
assert (client := mqtt_data.client) is not None
if (conf := hass.data.get(DATA_MQTT_CONFIG)) is None:
if (conf := mqtt_data.config) is None:
conf = CONFIG_SCHEMA_BASE(dict(entry.data))
mqtt_client.conf = _merge_extended_config(entry, conf)
await mqtt_client.async_disconnect()
mqtt_client.init_client()
await mqtt_client.async_connect()
mqtt_data.config = _merge_extended_config(entry, conf)
await client.async_disconnect()
client.init_client()
await client.async_connect()
await discovery.async_stop(hass)
if mqtt_client.conf.get(CONF_DISCOVERY):
await _async_setup_discovery(hass, mqtt_client.conf, entry)
if client.conf.get(CONF_DISCOVERY):
await _async_setup_discovery(hass, cast(ConfigType, mqtt_data.config), entry)
async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict | None:
"""Fetch fresh MQTT yaml config from the hass config when (re)loading the entry."""
if DATA_MQTT_RELOAD_ENTRY in hass.data:
mqtt_data: MqttData = hass.data[DATA_MQTT]
if mqtt_data.reload_entry:
hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_config = CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))
hass.data[DATA_MQTT_CONFIG] = mqtt_config
mqtt_data.config = CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))
# Remove unknown keys from config entry data
_filter_entry_config(hass, entry)
# Merge basic configuration, and add missing defaults for basic options
_merge_basic_config(hass, entry, hass.data.get(DATA_MQTT_CONFIG, {}))
_merge_basic_config(hass, entry, mqtt_data.config or {})
# Bail out if broker setting is missing
if CONF_BROKER not in entry.data:
_LOGGER.error("MQTT broker is not configured, please configure it")
@ -294,7 +283,7 @@ async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict |
# If user doesn't have configuration.yaml config, generate default values
# for options not in config entry data
if (conf := hass.data.get(DATA_MQTT_CONFIG)) is None:
if (conf := mqtt_data.config) is None:
conf = CONFIG_SCHEMA_BASE(dict(entry.data))
# User has configuration.yaml config, warn about config entry overrides
@ -317,21 +306,20 @@ async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict |
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Load a config entry."""
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
# Merge basic configuration, and add missing defaults for basic options
if (conf := await async_fetch_config(hass, entry)) is None:
# Bail out
return False
hass.data[DATA_MQTT_DISCOVERY_REGISTRY_HOOKS] = {}
hass.data[DATA_MQTT] = MQTT(hass, entry, conf)
mqtt_data.client = MQTT(hass, entry, conf)
# Restore saved subscriptions
if DATA_MQTT_SUBSCRIPTIONS_TO_RESTORE in hass.data:
hass.data[DATA_MQTT].subscriptions = hass.data.pop(
DATA_MQTT_SUBSCRIPTIONS_TO_RESTORE
)
if mqtt_data.subscriptions_to_restore:
mqtt_data.client.subscriptions = mqtt_data.subscriptions_to_restore
mqtt_data.subscriptions_to_restore = []
entry.add_update_listener(_async_config_entry_updated)
await hass.data[DATA_MQTT].async_connect()
await mqtt_data.client.async_connect()
async def async_publish_service(call: ServiceCall) -> None:
"""Handle MQTT publish service calls."""
@ -380,7 +368,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
)
return
await hass.data[DATA_MQTT].async_publish(msg_topic, payload, qos, retain)
assert mqtt_data.client is not None and msg_topic is not None
await mqtt_data.client.async_publish(msg_topic, payload, qos, retain)
hass.services.async_register(
DOMAIN, SERVICE_PUBLISH, async_publish_service, schema=MQTT_PUBLISH_SCHEMA
@ -421,7 +410,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
)
# setup platforms and discovery
hass.data[CONFIG_ENTRY_IS_SETUP] = set()
async def async_setup_reload_service() -> None:
"""Create the reload service for the MQTT domain."""
@ -435,7 +423,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# Reload the modern yaml platforms
config_yaml = await async_integration_yaml_config(hass, DOMAIN) or {}
hass.data[DATA_MQTT_UPDATED_CONFIG] = config_yaml.get(DOMAIN, {})
mqtt_data.updated_config = config_yaml.get(DOMAIN, {})
await asyncio.gather(
*(
[
@ -476,13 +464,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# Setup reload service after all platforms have loaded
await async_setup_reload_service()
# When the entry is reloaded, also reload manual set up items to enable MQTT
if DATA_MQTT_RELOAD_ENTRY in hass.data:
hass.data.pop(DATA_MQTT_RELOAD_ENTRY)
if mqtt_data.reload_entry:
mqtt_data.reload_entry = False
reload_manual_setup = True
# When the entry was disabled before, reload manual set up items to enable MQTT again
if DATA_MQTT_RELOAD_NEEDED in hass.data:
hass.data.pop(DATA_MQTT_RELOAD_NEEDED)
if mqtt_data.reload_needed:
mqtt_data.reload_needed = False
reload_manual_setup = True
if reload_manual_setup:
@ -592,7 +580,9 @@ def async_subscribe_connection_status(
def is_connected(hass: HomeAssistant) -> bool:
"""Return if MQTT client is connected."""
return hass.data[DATA_MQTT].connected
mqtt_data: MqttData = hass.data[DATA_MQTT]
assert mqtt_data.client is not None
return mqtt_data.client.connected
async def async_remove_config_entry_device(
@ -608,6 +598,10 @@ async def async_remove_config_entry_device(
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload MQTT dump and publish service when the config entry is unloaded."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
assert mqtt_data.client is not None
mqtt_client = mqtt_data.client
# Unload publish and dump services.
hass.services.async_remove(
DOMAIN,
@ -620,7 +614,6 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# Stop the discovery
await discovery.async_stop(hass)
mqtt_client: MQTT = hass.data[DATA_MQTT]
# Unload the platforms
await asyncio.gather(
*(
@ -630,26 +623,23 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
)
await hass.async_block_till_done()
# Unsubscribe reload dispatchers
while reload_dispatchers := hass.data.setdefault(DATA_MQTT_RELOAD_DISPATCHERS, []):
while reload_dispatchers := mqtt_data.reload_dispatchers:
reload_dispatchers.pop()()
hass.data[CONFIG_ENTRY_IS_SETUP] = set()
# Cleanup listeners
mqtt_client.cleanup()
# Trigger reload manual MQTT items at entry setup
if (mqtt_entry_status := mqtt_config_entry_enabled(hass)) is False:
# The entry is disabled reload legacy manual items when the entry is enabled again
hass.data[DATA_MQTT_RELOAD_NEEDED] = True
mqtt_data.reload_needed = True
elif mqtt_entry_status is True:
# The entry is reloaded:
# Trigger re-fetching the yaml config at entry setup
hass.data[DATA_MQTT_RELOAD_ENTRY] = True
mqtt_data.reload_entry = True
# Reload the legacy yaml platform to make entities unavailable
await async_reload_integration_platforms(hass, DOMAIN, RELOADABLE_PLATFORMS)
# Cleanup entity registry hooks
registry_hooks: dict[tuple, CALLBACK_TYPE] = hass.data[
DATA_MQTT_DISCOVERY_REGISTRY_HOOKS
]
registry_hooks = mqtt_data.discovery_registry_hooks
while registry_hooks:
registry_hooks.popitem()[1]()
# Wait for all ACKs and stop the loop
@ -657,6 +647,6 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# Store remaining subscriptions to be able to restore or reload them
# when the entry is set up again
if mqtt_client.subscriptions:
hass.data[DATA_MQTT_SUBSCRIPTIONS_TO_RESTORE] = mqtt_client.subscriptions
mqtt_data.subscriptions_to_restore = mqtt_client.subscriptions
return True

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable, Coroutine, Iterable
from collections.abc import Callable, Coroutine, Iterable
from functools import lru_cache, partial, wraps
import inspect
from itertools import groupby
@ -17,6 +17,7 @@ import attr
import certifi
from paho.mqtt.client import MQTTMessage
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
CONF_CLIENT_ID,
CONF_PASSWORD,
@ -52,7 +53,6 @@ from .const import (
MQTT_DISCONNECTED,
PROTOCOL_31,
)
from .discovery import LAST_DISCOVERY
from .models import (
AsyncMessageCallbackType,
MessageCallbackType,
@ -68,6 +68,9 @@ if TYPE_CHECKING:
# because integrations should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt
from .mixins import MqttData
_LOGGER = logging.getLogger(__name__)
DISCOVERY_COOLDOWN = 2
@ -97,8 +100,12 @@ async def async_publish(
encoding: str | None = DEFAULT_ENCODING,
) -> None:
"""Publish message to a MQTT topic."""
# Local import to avoid circular dependencies
# pylint: disable-next=import-outside-toplevel
from .mixins import MqttData
if DATA_MQTT not in hass.data or not mqtt_config_entry_enabled(hass):
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot publish to topic '{topic}', MQTT is not enabled"
)
@ -126,11 +133,13 @@ async def async_publish(
)
return
await hass.data[DATA_MQTT].async_publish(topic, outgoing_payload, qos, retain)
await mqtt_data.client.async_publish(
topic, outgoing_payload, qos or 0, retain or False
)
AsyncDeprecatedMessageCallbackType = Callable[
[str, ReceivePayloadType, int], Awaitable[None]
[str, ReceivePayloadType, int], Coroutine[Any, Any, None]
]
DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None]
@ -175,13 +184,18 @@ async def async_subscribe(
| DeprecatedMessageCallbackType
| AsyncDeprecatedMessageCallbackType,
qos: int = DEFAULT_QOS,
encoding: str | None = "utf-8",
encoding: str | None = DEFAULT_ENCODING,
):
"""Subscribe to an MQTT topic.
Call the return value to unsubscribe.
"""
if DATA_MQTT not in hass.data or not mqtt_config_entry_enabled(hass):
# Local import to avoid circular dependencies
# pylint: disable-next=import-outside-toplevel
from .mixins import MqttData
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled"
)
@ -206,7 +220,7 @@ async def async_subscribe(
cast(DeprecatedMessageCallbackType, msg_callback)
)
async_remove = await hass.data[DATA_MQTT].async_subscribe(
async_remove = await mqtt_data.client.async_subscribe(
topic,
catch_log_exception(
wrapped_msg_callback,
@ -309,15 +323,17 @@ class MQTT:
def __init__(
self,
hass,
config_entry,
conf,
hass: HomeAssistant,
config_entry: ConfigEntry,
conf: ConfigType,
) -> None:
"""Initialize Home Assistant MQTT client."""
# We don't import on the top because some integrations
# should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
self._mqtt_data: MqttData = hass.data[DATA_MQTT]
self.hass = hass
self.config_entry = config_entry
self.conf = conf
@ -634,7 +650,6 @@ class MQTT:
subscription.job,
)
continue
self.hass.async_run_hass_job(
subscription.job,
ReceiveMessage(
@ -694,10 +709,10 @@ class MQTT:
async def _discovery_cooldown(self):
now = time.time()
# Reset discovery and subscribe cooldowns
self.hass.data[LAST_DISCOVERY] = now
self._mqtt_data.last_discovery = now
self._last_subscribe = now
last_discovery = self.hass.data[LAST_DISCOVERY]
last_discovery = self._mqtt_data.last_discovery
last_subscribe = self._last_subscribe
wait_until = max(
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
@ -705,7 +720,7 @@ class MQTT:
while now < wait_until:
await asyncio.sleep(wait_until - now)
now = time.time()
last_discovery = self.hass.data[LAST_DISCOVERY]
last_discovery = self._mqtt_data.last_discovery
last_subscribe = self._last_subscribe
wait_until = max(
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN

View File

@ -18,7 +18,7 @@ from homeassistant.const import (
CONF_PROTOCOL,
CONF_USERNAME,
)
from homeassistant.core import callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from .client import MqttClientSetup
@ -30,12 +30,13 @@ from .const import (
CONF_BIRTH_MESSAGE,
CONF_BROKER,
CONF_WILL_MESSAGE,
DATA_MQTT_CONFIG,
DATA_MQTT,
DEFAULT_BIRTH,
DEFAULT_DISCOVERY,
DEFAULT_WILL,
DOMAIN,
)
from .mixins import MqttData
from .util import MQTT_WILL_BIRTH_SCHEMA
MQTT_TIMEOUT = 5
@ -164,9 +165,10 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Manage the MQTT broker configuration."""
mqtt_data: MqttData = self.hass.data.setdefault(DATA_MQTT, MqttData())
errors = {}
current_config = self.config_entry.data
yaml_config = self.hass.data.get(DATA_MQTT_CONFIG, {})
yaml_config = mqtt_data.config or {}
if user_input is not None:
can_connect = await self.hass.async_add_executor_job(
try_connection,
@ -214,9 +216,10 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Manage the MQTT options."""
mqtt_data: MqttData = self.hass.data.setdefault(DATA_MQTT, MqttData())
errors = {}
current_config = self.config_entry.data
yaml_config = self.hass.data.get(DATA_MQTT_CONFIG, {})
yaml_config = mqtt_data.config or {}
options_config: dict[str, Any] = {}
if user_input is not None:
bad_birth = False
@ -334,14 +337,22 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
)
def try_connection(hass, broker, port, username, password, protocol="3.1"):
def try_connection(
hass: HomeAssistant,
broker: str,
port: int,
username: str | None,
password: str | None,
protocol: str = "3.1",
) -> bool:
"""Test if we can connect to an MQTT broker."""
# We don't import on the top because some integrations
# should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
# Get the config from configuration.yaml
yaml_config = hass.data.get(DATA_MQTT_CONFIG, {})
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
yaml_config = mqtt_data.config or {}
entry_config = {
CONF_BROKER: broker,
CONF_PORT: port,
@ -351,7 +362,7 @@ def try_connection(hass, broker, port, username, password, protocol="3.1"):
}
client = MqttClientSetup({**yaml_config, **entry_config}).client
result = queue.Queue(maxsize=1)
result: queue.Queue[bool] = queue.Queue(maxsize=1)
def on_connect(client_, userdata, flags, result_code):
"""Handle connection result."""

View File

@ -30,16 +30,8 @@ CONF_CLIENT_CERT = "client_cert"
CONF_TLS_INSECURE = "tls_insecure"
CONF_TLS_VERSION = "tls_version"
CONFIG_ENTRY_IS_SETUP = "mqtt_config_entry_is_setup"
DATA_MQTT = "mqtt"
DATA_MQTT_SUBSCRIPTIONS_TO_RESTORE = "mqtt_client_subscriptions"
DATA_MQTT_DISCOVERY_REGISTRY_HOOKS = "mqtt_discovery_registry_hooks"
DATA_MQTT_CONFIG = "mqtt_config"
MQTT_DATA_DEVICE_TRACKER_LEGACY = "mqtt_device_tracker_legacy"
DATA_MQTT_RELOAD_DISPATCHERS = "mqtt_reload_dispatchers"
DATA_MQTT_RELOAD_ENTRY = "mqtt_reload_entry"
DATA_MQTT_RELOAD_NEEDED = "mqtt_reload_needed"
DATA_MQTT_UPDATED_CONFIG = "mqtt_updated_config"
DEFAULT_PREFIX = "homeassistant"
DEFAULT_BIRTH_WILL_TOPIC = DEFAULT_PREFIX + "/status"

View File

@ -33,11 +33,13 @@ from .const import (
CONF_PAYLOAD,
CONF_QOS,
CONF_TOPIC,
DATA_MQTT,
DOMAIN,
)
from .discovery import MQTT_DISCOVERY_DONE
from .mixins import (
MQTT_ENTITY_DEVICE_INFO_SCHEMA,
MqttData,
MqttDiscoveryDeviceUpdate,
send_discovery_done,
update_device,
@ -81,8 +83,6 @@ TRIGGER_DISCOVERY_SCHEMA = MQTT_BASE_SCHEMA.extend(
extra=vol.REMOVE_EXTRA,
)
DEVICE_TRIGGERS = "mqtt_device_triggers"
LOG_NAME = "Device trigger"
@ -203,6 +203,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
self.device_id = device_id
self.discovery_data = discovery_data
self.hass = hass
self._mqtt_data: MqttData = hass.data[DATA_MQTT]
MqttDiscoveryDeviceUpdate.__init__(
self,
@ -217,8 +218,8 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
"""Initialize the device trigger."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1]
if discovery_id not in self.hass.data.setdefault(DEVICE_TRIGGERS, {}):
self.hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger(
if discovery_id not in self._mqtt_data.device_triggers:
self._mqtt_data.device_triggers[discovery_id] = Trigger(
hass=self.hass,
device_id=self.device_id,
discovery_data=self.discovery_data,
@ -230,7 +231,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
value_template=self._config[CONF_VALUE_TEMPLATE],
)
else:
await self.hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger(
await self._mqtt_data.device_triggers[discovery_id].update_trigger(
self._config
)
debug_info.add_trigger_discovery_data(
@ -246,16 +247,16 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
)
config = TRIGGER_DISCOVERY_SCHEMA(discovery_data)
update_device(self.hass, self._config_entry, config)
device_trigger: Trigger = self.hass.data[DEVICE_TRIGGERS][discovery_id]
device_trigger: Trigger = self._mqtt_data.device_triggers[discovery_id]
await device_trigger.update_trigger(config)
async def async_tear_down(self) -> None:
"""Cleanup device trigger."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1]
if discovery_id in self.hass.data[DEVICE_TRIGGERS]:
if discovery_id in self._mqtt_data.device_triggers:
_LOGGER.info("Removing trigger: %s", discovery_hash)
trigger: Trigger = self.hass.data[DEVICE_TRIGGERS][discovery_id]
trigger: Trigger = self._mqtt_data.device_triggers[discovery_id]
trigger.detach_trigger()
debug_info.remove_trigger_discovery_data(self.hass, discovery_hash)
@ -280,11 +281,10 @@ async def async_setup_trigger(
async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None:
"""Handle Mqtt removed from a device."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
triggers = await async_get_triggers(hass, device_id)
for trig in triggers:
device_trigger: Trigger = hass.data[DEVICE_TRIGGERS].pop(
trig[CONF_DISCOVERY_ID]
)
device_trigger: Trigger = mqtt_data.device_triggers.pop(trig[CONF_DISCOVERY_ID])
if device_trigger:
device_trigger.detach_trigger()
discovery_data = cast(dict, device_trigger.discovery_data)
@ -296,12 +296,13 @@ async def async_get_triggers(
hass: HomeAssistant, device_id: str
) -> list[dict[str, str]]:
"""List device triggers for MQTT devices."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
triggers: list[dict[str, str]] = []
if DEVICE_TRIGGERS not in hass.data:
if not mqtt_data.device_triggers:
return triggers
for discovery_id, trig in hass.data[DEVICE_TRIGGERS].items():
for discovery_id, trig in mqtt_data.device_triggers.items():
if trig.device_id != device_id or trig.topic is None:
continue
@ -324,12 +325,12 @@ async def async_attach_trigger(
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""
hass.data.setdefault(DEVICE_TRIGGERS, {})
mqtt_data: MqttData = hass.data[DATA_MQTT]
device_id = config[CONF_DEVICE_ID]
discovery_id = config[CONF_DISCOVERY_ID]
if discovery_id not in hass.data[DEVICE_TRIGGERS]:
hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger(
if discovery_id not in mqtt_data.device_triggers:
mqtt_data.device_triggers[discovery_id] = Trigger(
hass=hass,
device_id=device_id,
discovery_data=None,
@ -340,6 +341,6 @@ async def async_attach_trigger(
qos=None,
value_template=None,
)
return await hass.data[DEVICE_TRIGGERS][discovery_id].add_trigger(
return await mqtt_data.device_triggers[discovery_id].add_trigger(
action, trigger_info
)

View File

@ -43,7 +43,7 @@ def _async_get_diagnostics(
device: DeviceEntry | None = None,
) -> dict[str, Any]:
"""Return diagnostics for a config entry."""
mqtt_instance: MQTT = hass.data[DATA_MQTT]
mqtt_instance: MQTT = hass.data[DATA_MQTT].client
redacted_config = async_redact_data(mqtt_instance.conf, REDACT_CONFIG)

View File

@ -7,6 +7,7 @@ import functools
import logging
import re
import time
from typing import TYPE_CHECKING
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
from homeassistant.core import HomeAssistant
@ -28,9 +29,13 @@ from .const import (
ATTR_DISCOVERY_TOPIC,
CONF_AVAILABILITY,
CONF_TOPIC,
DATA_MQTT,
DOMAIN,
)
if TYPE_CHECKING:
from .mixins import MqttData
_LOGGER = logging.getLogger(__name__)
TOPIC_MATCHER = re.compile(
@ -69,7 +74,6 @@ INTEGRATION_UNSUBSCRIBE = "mqtt_integration_discovery_unsubscribe"
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
MQTT_DISCOVERY_DONE = "mqtt_discovery_done_{}"
LAST_DISCOVERY = "mqtt_last_discovery"
TOPIC_BASE = "~"
@ -80,12 +84,12 @@ class MQTTConfig(dict):
discovery_data: dict
def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple) -> None:
def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None:
"""Clear entry in ALREADY_DISCOVERED list."""
del hass.data[ALREADY_DISCOVERED][discovery_hash]
def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple):
def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]):
"""Clear entry in ALREADY_DISCOVERED list."""
hass.data[ALREADY_DISCOVERED][discovery_hash] = {}
@ -94,11 +98,12 @@ async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic, config_entry=None
) -> None:
"""Start MQTT Discovery."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
mqtt_integrations = {}
async def async_discovery_message_received(msg):
"""Process the received message."""
hass.data[LAST_DISCOVERY] = time.time()
mqtt_data.last_discovery = time.time()
payload = msg.payload
topic = msg.topic
topic_trimmed = topic.replace(f"{discovery_topic}/", "", 1)
@ -253,7 +258,7 @@ async def async_start( # noqa: C901
)
)
hass.data[LAST_DISCOVERY] = time.time()
mqtt_data.last_discovery = time.time()
mqtt_integrations = await async_get_mqtt(hass)
hass.data[INTEGRATION_UNSUBSCRIBE] = {}

View File

@ -4,9 +4,10 @@ from __future__ import annotations
from abc import abstractmethod
import asyncio
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from functools import partial
import logging
from typing import Any, Protocol, cast, final
from typing import TYPE_CHECKING, Any, Protocol, cast, final
import voluptuous as vol
@ -60,7 +61,7 @@ from homeassistant.helpers.json import json_loads
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import debug_info, subscription
from .client import async_publish
from .client import MQTT, Subscription, async_publish
from .const import (
ATTR_DISCOVERY_HASH,
ATTR_DISCOVERY_PAYLOAD,
@ -70,11 +71,6 @@ from .const import (
CONF_QOS,
CONF_TOPIC,
DATA_MQTT,
DATA_MQTT_CONFIG,
DATA_MQTT_DISCOVERY_REGISTRY_HOOKS,
DATA_MQTT_RELOAD_DISPATCHERS,
DATA_MQTT_RELOAD_ENTRY,
DATA_MQTT_UPDATED_CONFIG,
DEFAULT_ENCODING,
DEFAULT_PAYLOAD_AVAILABLE,
DEFAULT_PAYLOAD_NOT_AVAILABLE,
@ -98,6 +94,9 @@ from .subscription import (
)
from .util import mqtt_config_entry_enabled, valid_subscribe_topic
if TYPE_CHECKING:
from .device_trigger import Trigger
_LOGGER = logging.getLogger(__name__)
AVAILABILITY_ALL = "all"
@ -274,6 +273,24 @@ def warn_for_legacy_schema(domain: str) -> Callable:
return validator
@dataclass
class MqttData:
"""Keep the MQTT entry data."""
client: MQTT | None = None
config: ConfigType | None = None
device_triggers: dict[str, Trigger] = field(default_factory=dict)
discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field(
default_factory=dict
)
last_discovery: float = 0.0
reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list)
reload_entry: bool = False
reload_needed: bool = False
subscriptions_to_restore: list[Subscription] = field(default_factory=list)
updated_config: ConfigType = field(default_factory=dict)
class SetupEntity(Protocol):
"""Protocol type for async_setup_entities."""
@ -292,11 +309,12 @@ async def async_discover_yaml_entities(
hass: HomeAssistant, platform_domain: str
) -> None:
"""Discover entities for a platform."""
if DATA_MQTT_UPDATED_CONFIG in hass.data:
mqtt_data: MqttData = hass.data[DATA_MQTT]
if mqtt_data.updated_config:
# The platform has been reloaded
config_yaml = hass.data[DATA_MQTT_UPDATED_CONFIG]
config_yaml = mqtt_data.updated_config
else:
config_yaml = hass.data.get(DATA_MQTT_CONFIG, {})
config_yaml = mqtt_data.config or {}
if not config_yaml:
return
if platform_domain not in config_yaml:
@ -318,8 +336,9 @@ async def async_get_platform_config_from_yaml(
) -> list[ConfigType]:
"""Return a list of validated configurations for the domain."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
if config_yaml is None:
config_yaml = hass.data.get(DATA_MQTT_CONFIG)
config_yaml = mqtt_data.config
if not config_yaml:
return []
if not (platform_configs := config_yaml.get(platform_domain)):
@ -334,6 +353,7 @@ async def async_setup_entry_helper(
schema: vol.Schema,
) -> None:
"""Set up entity, automation or tag creation dynamically through MQTT discovery."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
async def async_discover(discovery_payload):
"""Discover and add an MQTT entity, automation or tag."""
@ -357,7 +377,7 @@ async def async_setup_entry_helper(
)
raise
hass.data.setdefault(DATA_MQTT_RELOAD_DISPATCHERS, []).append(
mqtt_data.reload_dispatchers.append(
async_dispatcher_connect(
hass, MQTT_DISCOVERY_NEW.format(domain, "mqtt"), async_discover
)
@ -372,7 +392,8 @@ async def async_setup_platform_helper(
async_setup_entities: SetupEntity,
) -> None:
"""Help to set up the platform for manual configured MQTT entities."""
if DATA_MQTT_RELOAD_ENTRY in hass.data:
mqtt_data: MqttData = hass.data[DATA_MQTT]
if mqtt_data.reload_entry:
_LOGGER.debug(
"MQTT integration is %s, skipping setup of manually configured MQTT items while unloading the config entry",
platform_domain,
@ -597,7 +618,10 @@ class MqttAvailability(Entity):
@property
def available(self) -> bool:
"""Return if the device is available."""
if not self.hass.data[DATA_MQTT].connected and not self.hass.is_stopping:
mqtt_data: MqttData = self.hass.data[DATA_MQTT]
assert mqtt_data.client is not None
client = mqtt_data.client
if not client.connected and not self.hass.is_stopping:
return False
if not self._avail_topics:
return True
@ -632,7 +656,7 @@ async def cleanup_device_registry(
)
def get_discovery_hash(discovery_data: dict) -> tuple:
def get_discovery_hash(discovery_data: dict) -> tuple[str, str]:
"""Get the discovery hash from the discovery data."""
return discovery_data[ATTR_DISCOVERY_HASH]
@ -817,9 +841,8 @@ class MqttDiscoveryUpdate(Entity):
self._removed_from_hass = False
if discovery_data is None:
return
self._registry_hooks: dict[tuple, CALLBACK_TYPE] = hass.data[
DATA_MQTT_DISCOVERY_REGISTRY_HOOKS
]
mqtt_data: MqttData = hass.data[DATA_MQTT]
self._registry_hooks = mqtt_data.discovery_registry_hooks
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
if discovery_hash in self._registry_hooks:
self._registry_hooks.pop(discovery_hash)()
@ -897,7 +920,7 @@ class MqttDiscoveryUpdate(Entity):
def add_to_platform_abort(self) -> None:
"""Abort adding an entity to a platform."""
if self._discovery_data is not None:
discovery_hash: tuple = self._discovery_data[ATTR_DISCOVERY_HASH]
discovery_hash: tuple[str, str] = self._discovery_data[ATTR_DISCOVERY_HASH]
if self.registry_entry is not None:
self._registry_hooks[
discovery_hash

View File

@ -369,7 +369,7 @@ def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
if isinstance(payload, str):
payload = payload.encode("utf-8")
msg = ReceiveMessage(topic, payload, qos, retain)
hass.data["mqtt"]._mqtt_handle_message(msg)
hass.data["mqtt"].client._mqtt_handle_message(msg)
fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message)

View File

@ -155,7 +155,7 @@ async def test_manual_config_set(
assert await async_setup_component(hass, "mqtt", {"mqtt": {"broker": "bla"}})
await hass.async_block_till_done()
# do not try to reload
del hass.data["mqtt_reload_needed"]
hass.data["mqtt"].reload_needed = False
assert len(mock_finish_setup.mock_calls) == 0
mock_try_connection.return_value = True

View File

@ -1438,7 +1438,7 @@ async def test_clean_up_registry_monitoring(
):
"""Test registry monitoring hook is removed after a reload."""
await mqtt_mock_entry_no_yaml_config()
hooks: dict = hass.data[mqtt.const.DATA_MQTT_DISCOVERY_REGISTRY_HOOKS]
hooks: dict = hass.data["mqtt"].discovery_registry_hooks
# discover an entity that is not enabled by default
config1 = {
"name": "sbfspot_12345",

View File

@ -1776,14 +1776,14 @@ async def test_delayed_birth_message(
await hass.async_block_till_done()
mqtt_component_mock = MagicMock(
return_value=hass.data["mqtt"],
spec_set=hass.data["mqtt"],
wraps=hass.data["mqtt"],
return_value=hass.data["mqtt"].client,
spec_set=hass.data["mqtt"].client,
wraps=hass.data["mqtt"].client,
)
mqtt_component_mock._mqttc = mqtt_client_mock
hass.data["mqtt"] = mqtt_component_mock
mqtt_mock = hass.data["mqtt"]
hass.data["mqtt"].client = mqtt_component_mock
mqtt_mock = hass.data["mqtt"].client
mqtt_mock.reset_mock()
async def wait_birth(topic, payload, qos):