Add MqttData helper to mqtt (#78825)

* Add MqttData helper to mqtt

* Adjust client for circular dependencies

* Move MqttData to models.py

* Move get_mqtt_data to util.py
pull/78848/head
epenet 2022-09-20 19:40:06 +02:00 committed by GitHub
parent 6b3c91bd6a
commit e58531f118
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 74 additions and 88 deletions

View File

@ -76,7 +76,6 @@ from .const import ( # noqa: F401
PLATFORMS, PLATFORMS,
RELOADABLE_PLATFORMS, RELOADABLE_PLATFORMS,
) )
from .mixins import MqttData
from .models import ( # noqa: F401 from .models import ( # noqa: F401
MqttCommandTemplate, MqttCommandTemplate,
MqttValueTemplate, MqttValueTemplate,
@ -86,6 +85,7 @@ from .models import ( # noqa: F401
) )
from .util import ( from .util import (
_VALID_QOS_SCHEMA, _VALID_QOS_SCHEMA,
get_mqtt_data,
mqtt_config_entry_enabled, mqtt_config_entry_enabled,
valid_publish_topic, valid_publish_topic,
valid_subscribe_topic, valid_subscribe_topic,
@ -164,7 +164,7 @@ async def _async_setup_discovery(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Start the MQTT protocol service.""" """Start the MQTT protocol service."""
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData()) mqtt_data = get_mqtt_data(hass, True)
conf: ConfigType | None = config.get(DOMAIN) conf: ConfigType | None = config.get(DOMAIN)
@ -249,7 +249,7 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -
Causes for this is config entry options changing. Causes for this is config entry options changing.
""" """
mqtt_data: MqttData = hass.data[DATA_MQTT] mqtt_data = get_mqtt_data(hass)
assert (client := mqtt_data.client) is not None assert (client := mqtt_data.client) is not None
if (conf := mqtt_data.config) is None: if (conf := mqtt_data.config) is None:
@ -267,7 +267,7 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -
async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict | None: async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict | None:
"""Fetch fresh MQTT yaml config from the hass config when (re)loading the entry.""" """Fetch fresh MQTT yaml config from the hass config when (re)loading the entry."""
mqtt_data: MqttData = hass.data[DATA_MQTT] mqtt_data = get_mqtt_data(hass)
if mqtt_data.reload_entry: if mqtt_data.reload_entry:
hass_config = await conf_util.async_hass_config_yaml(hass) hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_data.config = CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {})) mqtt_data.config = CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))
@ -307,7 +307,7 @@ async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict |
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Load a config entry.""" """Load a config entry."""
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData()) mqtt_data = get_mqtt_data(hass, True)
# Merge basic configuration, and add missing defaults for basic options # Merge basic configuration, and add missing defaults for basic options
if (conf := await async_fetch_config(hass, entry)) is None: if (conf := await async_fetch_config(hass, entry)) is None:
@ -593,7 +593,7 @@ def async_subscribe_connection_status(
def is_connected(hass: HomeAssistant) -> bool: def is_connected(hass: HomeAssistant) -> bool:
"""Return if MQTT client is connected.""" """Return if MQTT client is connected."""
mqtt_data: MqttData = hass.data[DATA_MQTT] mqtt_data = get_mqtt_data(hass)
assert mqtt_data.client is not None assert mqtt_data.client is not None
return mqtt_data.client.connected return mqtt_data.client.connected
@ -611,7 +611,7 @@ async def async_remove_config_entry_device(
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload MQTT dump and publish service when the config entry is unloaded.""" """Unload MQTT dump and publish service when the config entry is unloaded."""
mqtt_data: MqttData = hass.data[DATA_MQTT] mqtt_data = get_mqtt_data(hass)
assert mqtt_data.client is not None assert mqtt_data.client is not None
mqtt_client = mqtt_data.client mqtt_client = mqtt_data.client

View File

@ -46,7 +46,6 @@ from .const import (
CONF_KEEPALIVE, CONF_KEEPALIVE,
CONF_TLS_INSECURE, CONF_TLS_INSECURE,
CONF_WILL_MESSAGE, CONF_WILL_MESSAGE,
DATA_MQTT,
DEFAULT_ENCODING, DEFAULT_ENCODING,
DEFAULT_QOS, DEFAULT_QOS,
MQTT_CONNECTED, MQTT_CONNECTED,
@ -61,15 +60,13 @@ from .models import (
ReceiveMessage, ReceiveMessage,
ReceivePayloadType, ReceivePayloadType,
) )
from .util import mqtt_config_entry_enabled from .util import get_mqtt_data, mqtt_config_entry_enabled
if TYPE_CHECKING: if TYPE_CHECKING:
# Only import for paho-mqtt type checking here, imports are done locally # Only import for paho-mqtt type checking here, imports are done locally
# because integrations should be able to optionally rely on MQTT. # because integrations should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
from .mixins import MqttData
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -100,11 +97,7 @@ async def async_publish(
encoding: str | None = DEFAULT_ENCODING, encoding: str | None = DEFAULT_ENCODING,
) -> None: ) -> None:
"""Publish message to a MQTT topic.""" """Publish message to a MQTT topic."""
# Local import to avoid circular dependencies mqtt_data = get_mqtt_data(hass, True)
# pylint: disable-next=import-outside-toplevel
from .mixins import MqttData
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass): if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
raise HomeAssistantError( raise HomeAssistantError(
f"Cannot publish to topic '{topic}', MQTT is not enabled" f"Cannot publish to topic '{topic}', MQTT is not enabled"
@ -190,11 +183,7 @@ async def async_subscribe(
Call the return value to unsubscribe. Call the return value to unsubscribe.
""" """
# Local import to avoid circular dependencies mqtt_data = get_mqtt_data(hass, True)
# pylint: disable-next=import-outside-toplevel
from .mixins import MqttData
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass): if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
raise HomeAssistantError( raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled" f"Cannot subscribe to topic '{topic}', MQTT is not enabled"
@ -332,7 +321,7 @@ class MQTT:
# should be able to optionally rely on MQTT. # should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
self._mqtt_data: MqttData = hass.data[DATA_MQTT] self._mqtt_data = get_mqtt_data(hass)
self.hass = hass self.hass = hass
self.config_entry = config_entry self.config_entry = config_entry

View File

@ -30,14 +30,12 @@ from .const import (
CONF_BIRTH_MESSAGE, CONF_BIRTH_MESSAGE,
CONF_BROKER, CONF_BROKER,
CONF_WILL_MESSAGE, CONF_WILL_MESSAGE,
DATA_MQTT,
DEFAULT_BIRTH, DEFAULT_BIRTH,
DEFAULT_DISCOVERY, DEFAULT_DISCOVERY,
DEFAULT_WILL, DEFAULT_WILL,
DOMAIN, DOMAIN,
) )
from .mixins import MqttData from .util import MQTT_WILL_BIRTH_SCHEMA, get_mqtt_data
from .util import MQTT_WILL_BIRTH_SCHEMA
MQTT_TIMEOUT = 5 MQTT_TIMEOUT = 5
@ -165,7 +163,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
"""Manage the MQTT broker configuration.""" """Manage the MQTT broker configuration."""
mqtt_data: MqttData = self.hass.data.setdefault(DATA_MQTT, MqttData()) mqtt_data = get_mqtt_data(self.hass, True)
errors = {} errors = {}
current_config = self.config_entry.data current_config = self.config_entry.data
yaml_config = mqtt_data.config or {} yaml_config = mqtt_data.config or {}
@ -216,7 +214,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
"""Manage the MQTT options.""" """Manage the MQTT options."""
mqtt_data: MqttData = self.hass.data.setdefault(DATA_MQTT, MqttData()) mqtt_data = get_mqtt_data(self.hass, True)
errors = {} errors = {}
current_config = self.config_entry.data current_config = self.config_entry.data
yaml_config = mqtt_data.config or {} yaml_config = mqtt_data.config or {}
@ -351,7 +349,7 @@ def try_connection(
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
# Get the config from configuration.yaml # Get the config from configuration.yaml
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData()) mqtt_data = get_mqtt_data(hass, True)
yaml_config = mqtt_data.config or {} yaml_config = mqtt_data.config or {}
entry_config = { entry_config = {
CONF_BROKER: broker, CONF_BROKER: broker,

View File

@ -33,17 +33,16 @@ from .const import (
CONF_PAYLOAD, CONF_PAYLOAD,
CONF_QOS, CONF_QOS,
CONF_TOPIC, CONF_TOPIC,
DATA_MQTT,
DOMAIN, DOMAIN,
) )
from .discovery import MQTT_DISCOVERY_DONE from .discovery import MQTT_DISCOVERY_DONE
from .mixins import ( from .mixins import (
MQTT_ENTITY_DEVICE_INFO_SCHEMA, MQTT_ENTITY_DEVICE_INFO_SCHEMA,
MqttData,
MqttDiscoveryDeviceUpdate, MqttDiscoveryDeviceUpdate,
send_discovery_done, send_discovery_done,
update_device, update_device,
) )
from .util import get_mqtt_data
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -203,7 +202,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
self.device_id = device_id self.device_id = device_id
self.discovery_data = discovery_data self.discovery_data = discovery_data
self.hass = hass self.hass = hass
self._mqtt_data: MqttData = hass.data[DATA_MQTT] self._mqtt_data = get_mqtt_data(hass)
MqttDiscoveryDeviceUpdate.__init__( MqttDiscoveryDeviceUpdate.__init__(
self, self,
@ -281,7 +280,7 @@ async def async_setup_trigger(
async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None: async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None:
"""Handle Mqtt removed from a device.""" """Handle Mqtt removed from a device."""
mqtt_data: MqttData = hass.data[DATA_MQTT] mqtt_data = get_mqtt_data(hass)
triggers = await async_get_triggers(hass, device_id) triggers = await async_get_triggers(hass, device_id)
for trig in triggers: for trig in triggers:
device_trigger: Trigger = mqtt_data.device_triggers.pop(trig[CONF_DISCOVERY_ID]) device_trigger: Trigger = mqtt_data.device_triggers.pop(trig[CONF_DISCOVERY_ID])
@ -296,7 +295,7 @@ async def async_get_triggers(
hass: HomeAssistant, device_id: str hass: HomeAssistant, device_id: str
) -> list[dict[str, str]]: ) -> list[dict[str, str]]:
"""List device triggers for MQTT devices.""" """List device triggers for MQTT devices."""
mqtt_data: MqttData = hass.data[DATA_MQTT] mqtt_data = get_mqtt_data(hass)
triggers: list[dict[str, str]] = [] triggers: list[dict[str, str]] = []
if not mqtt_data.device_triggers: if not mqtt_data.device_triggers:
@ -325,7 +324,7 @@ async def async_attach_trigger(
trigger_info: TriggerInfo, trigger_info: TriggerInfo,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Attach a trigger.""" """Attach a trigger."""
mqtt_data: MqttData = hass.data[DATA_MQTT] mqtt_data = get_mqtt_data(hass)
device_id = config[CONF_DEVICE_ID] device_id = config[CONF_DEVICE_ID]
discovery_id = config[CONF_DISCOVERY_ID] discovery_id = config[CONF_DISCOVERY_ID]

View File

@ -16,7 +16,8 @@ from homeassistant.core import HomeAssistant, callback, split_entity_id
from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.device_registry import DeviceEntry from homeassistant.helpers.device_registry import DeviceEntry
from . import DATA_MQTT, MQTT, debug_info, is_connected from . import debug_info, is_connected
from .util import get_mqtt_data
REDACT_CONFIG = {CONF_PASSWORD, CONF_USERNAME} REDACT_CONFIG = {CONF_PASSWORD, CONF_USERNAME}
REDACT_STATE_DEVICE_TRACKER = {ATTR_LATITUDE, ATTR_LONGITUDE} REDACT_STATE_DEVICE_TRACKER = {ATTR_LATITUDE, ATTR_LONGITUDE}
@ -43,7 +44,8 @@ def _async_get_diagnostics(
device: DeviceEntry | None = None, device: DeviceEntry | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return diagnostics for a config entry.""" """Return diagnostics for a config entry."""
mqtt_instance: MQTT = hass.data[DATA_MQTT].client mqtt_instance = get_mqtt_data(hass).client
assert mqtt_instance is not None
redacted_config = async_redact_data(mqtt_instance.conf, REDACT_CONFIG) redacted_config = async_redact_data(mqtt_instance.conf, REDACT_CONFIG)

View File

@ -7,7 +7,6 @@ import functools
import logging import logging
import re import re
import time import time
from typing import TYPE_CHECKING
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -29,12 +28,9 @@ from .const import (
ATTR_DISCOVERY_TOPIC, ATTR_DISCOVERY_TOPIC,
CONF_AVAILABILITY, CONF_AVAILABILITY,
CONF_TOPIC, CONF_TOPIC,
DATA_MQTT,
DOMAIN, DOMAIN,
) )
from .util import get_mqtt_data
if TYPE_CHECKING:
from .mixins import MqttData
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -98,7 +94,7 @@ async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic, config_entry=None hass: HomeAssistant, discovery_topic, config_entry=None
) -> None: ) -> None:
"""Start MQTT Discovery.""" """Start MQTT Discovery."""
mqtt_data: MqttData = hass.data[DATA_MQTT] mqtt_data = get_mqtt_data(hass)
mqtt_integrations = {} mqtt_integrations = {}
async def async_discovery_message_received(msg): async def async_discovery_message_received(msg):

View File

@ -4,10 +4,9 @@ from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
import asyncio import asyncio
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from functools import partial from functools import partial
import logging import logging
from typing import TYPE_CHECKING, Any, Protocol, cast, final from typing import Any, Protocol, cast, final
import voluptuous as vol import voluptuous as vol
@ -29,13 +28,7 @@ from homeassistant.const import (
CONF_UNIQUE_ID, CONF_UNIQUE_ID,
CONF_VALUE_TEMPLATE, CONF_VALUE_TEMPLATE,
) )
from homeassistant.core import ( from homeassistant.core import Event, HomeAssistant, async_get_hass, callback
CALLBACK_TYPE,
Event,
HomeAssistant,
async_get_hass,
callback,
)
from homeassistant.helpers import ( from homeassistant.helpers import (
config_validation as cv, config_validation as cv,
device_registry as dr, device_registry as dr,
@ -60,7 +53,7 @@ from homeassistant.helpers.json import json_loads
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import debug_info, subscription from . import debug_info, subscription
from .client import MQTT, Subscription, async_publish from .client import async_publish
from .const import ( from .const import (
ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_HASH,
ATTR_DISCOVERY_PAYLOAD, ATTR_DISCOVERY_PAYLOAD,
@ -69,7 +62,6 @@ from .const import (
CONF_ENCODING, CONF_ENCODING,
CONF_QOS, CONF_QOS,
CONF_TOPIC, CONF_TOPIC,
DATA_MQTT,
DEFAULT_ENCODING, DEFAULT_ENCODING,
DEFAULT_PAYLOAD_AVAILABLE, DEFAULT_PAYLOAD_AVAILABLE,
DEFAULT_PAYLOAD_NOT_AVAILABLE, DEFAULT_PAYLOAD_NOT_AVAILABLE,
@ -91,10 +83,7 @@ from .subscription import (
async_subscribe_topics, async_subscribe_topics,
async_unsubscribe_topics, async_unsubscribe_topics,
) )
from .util import mqtt_config_entry_enabled, valid_subscribe_topic from .util import get_mqtt_data, mqtt_config_entry_enabled, valid_subscribe_topic
if TYPE_CHECKING:
from .device_trigger import Trigger
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -272,27 +261,6 @@ def warn_for_legacy_schema(domain: str) -> Callable:
return validator return validator
@dataclass
class MqttData:
"""Keep the MQTT entry data."""
client: MQTT | None = None
config: ConfigType | None = None
device_triggers: dict[str, Trigger] = field(default_factory=dict)
discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field(
default_factory=dict
)
last_discovery: float = 0.0
reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list)
reload_entry: bool = False
reload_handlers: dict[str, Callable[[], Coroutine[Any, Any, None]]] = field(
default_factory=dict
)
reload_needed: bool = False
subscriptions_to_restore: list[Subscription] = field(default_factory=list)
updated_config: ConfigType = field(default_factory=dict)
class SetupEntity(Protocol): class SetupEntity(Protocol):
"""Protocol type for async_setup_entities.""" """Protocol type for async_setup_entities."""
@ -313,8 +281,7 @@ async def async_get_platform_config_from_yaml(
config_yaml: ConfigType | None = None, config_yaml: ConfigType | None = None,
) -> list[ConfigType]: ) -> list[ConfigType]:
"""Return a list of validated configurations for the domain.""" """Return a list of validated configurations for the domain."""
mqtt_data = get_mqtt_data(hass)
mqtt_data: MqttData = hass.data[DATA_MQTT]
if config_yaml is None: if config_yaml is None:
config_yaml = mqtt_data.config config_yaml = mqtt_data.config
if not config_yaml: if not config_yaml:
@ -331,7 +298,7 @@ async def async_setup_entry_helper(
discovery_schema: vol.Schema, discovery_schema: vol.Schema,
) -> None: ) -> None:
"""Set up entity, automation or tag creation dynamically through MQTT discovery.""" """Set up entity, automation or tag creation dynamically through MQTT discovery."""
mqtt_data: MqttData = hass.data[DATA_MQTT] mqtt_data = get_mqtt_data(hass)
async def async_discover(discovery_payload): async def async_discover(discovery_payload):
"""Discover and add an MQTT entity, automation or tag.""" """Discover and add an MQTT entity, automation or tag."""
@ -363,7 +330,7 @@ async def async_setup_entry_helper(
async def _async_setup_entities() -> None: async def _async_setup_entities() -> None:
"""Set up MQTT items from configuration.yaml.""" """Set up MQTT items from configuration.yaml."""
mqtt_data: MqttData = hass.data[DATA_MQTT] mqtt_data = get_mqtt_data(hass)
if mqtt_data.updated_config: if mqtt_data.updated_config:
# The platform has been reloaded # The platform has been reloaded
config_yaml = mqtt_data.updated_config config_yaml = mqtt_data.updated_config
@ -395,7 +362,7 @@ async def async_setup_platform_helper(
async_setup_entities: SetupEntity, async_setup_entities: SetupEntity,
) -> None: ) -> None:
"""Help to set up the platform for manual configured MQTT entities.""" """Help to set up the platform for manual configured MQTT entities."""
mqtt_data: MqttData = hass.data[DATA_MQTT] mqtt_data = get_mqtt_data(hass)
if mqtt_data.reload_entry: if mqtt_data.reload_entry:
_LOGGER.debug( _LOGGER.debug(
"MQTT integration is %s, skipping setup of manually configured MQTT items while unloading the config entry", "MQTT integration is %s, skipping setup of manually configured MQTT items while unloading the config entry",
@ -621,7 +588,7 @@ class MqttAvailability(Entity):
@property @property
def available(self) -> bool: def available(self) -> bool:
"""Return if the device is available.""" """Return if the device is available."""
mqtt_data: MqttData = self.hass.data[DATA_MQTT] mqtt_data = get_mqtt_data(self.hass)
assert mqtt_data.client is not None assert mqtt_data.client is not None
client = mqtt_data.client client = mqtt_data.client
if not client.connected and not self.hass.is_stopping: if not client.connected and not self.hass.is_stopping:
@ -844,7 +811,7 @@ class MqttDiscoveryUpdate(Entity):
self._removed_from_hass = False self._removed_from_hass = False
if discovery_data is None: if discovery_data is None:
return return
mqtt_data: MqttData = hass.data[DATA_MQTT] mqtt_data = get_mqtt_data(hass)
self._registry_hooks = mqtt_data.discovery_registry_hooks self._registry_hooks = mqtt_data.discovery_registry_hooks
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH] discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
if discovery_hash in self._registry_hooks: if discovery_hash in self._registry_hooks:

View File

@ -3,17 +3,22 @@ from __future__ import annotations
from ast import literal_eval from ast import literal_eval
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
import datetime as dt import datetime as dt
from typing import Any, Union from typing import TYPE_CHECKING, Any, Union
import attr import attr
from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME
from homeassistant.core import HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers import template from homeassistant.helpers import template
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
from homeassistant.helpers.typing import TemplateVarsType from homeassistant.helpers.typing import ConfigType, TemplateVarsType
if TYPE_CHECKING:
from .client import MQTT, Subscription
from .device_trigger import Trigger
_SENTINEL = object() _SENTINEL = object()
@ -174,3 +179,24 @@ class MqttValueTemplate:
return self._value_template.async_render_with_possible_json_value( return self._value_template.async_render_with_possible_json_value(
payload, default, variables=values payload, default, variables=values
) )
@dataclass
class MqttData:
"""Keep the MQTT entry data."""
client: MQTT | None = None
config: ConfigType | None = None
device_triggers: dict[str, Trigger] = field(default_factory=dict)
discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field(
default_factory=dict
)
last_discovery: float = 0.0
reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list)
reload_entry: bool = False
reload_handlers: dict[str, Callable[[], Coroutine[Any, Any, None]]] = field(
default_factory=dict
)
reload_needed: bool = False
subscriptions_to_restore: list[Subscription] = field(default_factory=list)
updated_config: ConfigType = field(default_factory=dict)

View File

@ -15,10 +15,12 @@ from .const import (
ATTR_QOS, ATTR_QOS,
ATTR_RETAIN, ATTR_RETAIN,
ATTR_TOPIC, ATTR_TOPIC,
DATA_MQTT,
DEFAULT_QOS, DEFAULT_QOS,
DEFAULT_RETAIN, DEFAULT_RETAIN,
DOMAIN, DOMAIN,
) )
from .models import MqttData
def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None: def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None:
@ -111,3 +113,10 @@ MQTT_WILL_BIRTH_SCHEMA = vol.Schema(
}, },
required=True, required=True,
) )
def get_mqtt_data(hass: HomeAssistant, ensure_exists: bool = False) -> MqttData:
"""Return typed MqttData from hass.data[DATA_MQTT]."""
if ensure_exists:
return hass.data.setdefault(DATA_MQTT, MqttData())
return hass.data[DATA_MQTT]