Add MqttData helper to mqtt (#78825)
* Add MqttData helper to mqtt * Adjust client for circular dependencies * Move MqttData to models.py * Move get_mqtt_data to util.pypull/78848/head
parent
6b3c91bd6a
commit
e58531f118
|
@ -76,7 +76,6 @@ from .const import ( # noqa: F401
|
||||||
PLATFORMS,
|
PLATFORMS,
|
||||||
RELOADABLE_PLATFORMS,
|
RELOADABLE_PLATFORMS,
|
||||||
)
|
)
|
||||||
from .mixins import MqttData
|
|
||||||
from .models import ( # noqa: F401
|
from .models import ( # noqa: F401
|
||||||
MqttCommandTemplate,
|
MqttCommandTemplate,
|
||||||
MqttValueTemplate,
|
MqttValueTemplate,
|
||||||
|
@ -86,6 +85,7 @@ from .models import ( # noqa: F401
|
||||||
)
|
)
|
||||||
from .util import (
|
from .util import (
|
||||||
_VALID_QOS_SCHEMA,
|
_VALID_QOS_SCHEMA,
|
||||||
|
get_mqtt_data,
|
||||||
mqtt_config_entry_enabled,
|
mqtt_config_entry_enabled,
|
||||||
valid_publish_topic,
|
valid_publish_topic,
|
||||||
valid_subscribe_topic,
|
valid_subscribe_topic,
|
||||||
|
@ -164,7 +164,7 @@ async def _async_setup_discovery(
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Start the MQTT protocol service."""
|
"""Start the MQTT protocol service."""
|
||||||
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
|
mqtt_data = get_mqtt_data(hass, True)
|
||||||
|
|
||||||
conf: ConfigType | None = config.get(DOMAIN)
|
conf: ConfigType | None = config.get(DOMAIN)
|
||||||
|
|
||||||
|
@ -249,7 +249,7 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -
|
||||||
|
|
||||||
Causes for this is config entry options changing.
|
Causes for this is config entry options changing.
|
||||||
"""
|
"""
|
||||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
mqtt_data = get_mqtt_data(hass)
|
||||||
assert (client := mqtt_data.client) is not None
|
assert (client := mqtt_data.client) is not None
|
||||||
|
|
||||||
if (conf := mqtt_data.config) is None:
|
if (conf := mqtt_data.config) is None:
|
||||||
|
@ -267,7 +267,7 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -
|
||||||
|
|
||||||
async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict | None:
|
async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict | None:
|
||||||
"""Fetch fresh MQTT yaml config from the hass config when (re)loading the entry."""
|
"""Fetch fresh MQTT yaml config from the hass config when (re)loading the entry."""
|
||||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
mqtt_data = get_mqtt_data(hass)
|
||||||
if mqtt_data.reload_entry:
|
if mqtt_data.reload_entry:
|
||||||
hass_config = await conf_util.async_hass_config_yaml(hass)
|
hass_config = await conf_util.async_hass_config_yaml(hass)
|
||||||
mqtt_data.config = CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))
|
mqtt_data.config = CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))
|
||||||
|
@ -307,7 +307,7 @@ async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict |
|
||||||
|
|
||||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
"""Load a config entry."""
|
"""Load a config entry."""
|
||||||
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
|
mqtt_data = get_mqtt_data(hass, True)
|
||||||
|
|
||||||
# Merge basic configuration, and add missing defaults for basic options
|
# Merge basic configuration, and add missing defaults for basic options
|
||||||
if (conf := await async_fetch_config(hass, entry)) is None:
|
if (conf := await async_fetch_config(hass, entry)) is None:
|
||||||
|
@ -593,7 +593,7 @@ def async_subscribe_connection_status(
|
||||||
|
|
||||||
def is_connected(hass: HomeAssistant) -> bool:
|
def is_connected(hass: HomeAssistant) -> bool:
|
||||||
"""Return if MQTT client is connected."""
|
"""Return if MQTT client is connected."""
|
||||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
mqtt_data = get_mqtt_data(hass)
|
||||||
assert mqtt_data.client is not None
|
assert mqtt_data.client is not None
|
||||||
return mqtt_data.client.connected
|
return mqtt_data.client.connected
|
||||||
|
|
||||||
|
@ -611,7 +611,7 @@ async def async_remove_config_entry_device(
|
||||||
|
|
||||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
"""Unload MQTT dump and publish service when the config entry is unloaded."""
|
"""Unload MQTT dump and publish service when the config entry is unloaded."""
|
||||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
mqtt_data = get_mqtt_data(hass)
|
||||||
assert mqtt_data.client is not None
|
assert mqtt_data.client is not None
|
||||||
mqtt_client = mqtt_data.client
|
mqtt_client = mqtt_data.client
|
||||||
|
|
||||||
|
|
|
@ -46,7 +46,6 @@ from .const import (
|
||||||
CONF_KEEPALIVE,
|
CONF_KEEPALIVE,
|
||||||
CONF_TLS_INSECURE,
|
CONF_TLS_INSECURE,
|
||||||
CONF_WILL_MESSAGE,
|
CONF_WILL_MESSAGE,
|
||||||
DATA_MQTT,
|
|
||||||
DEFAULT_ENCODING,
|
DEFAULT_ENCODING,
|
||||||
DEFAULT_QOS,
|
DEFAULT_QOS,
|
||||||
MQTT_CONNECTED,
|
MQTT_CONNECTED,
|
||||||
|
@ -61,15 +60,13 @@ from .models import (
|
||||||
ReceiveMessage,
|
ReceiveMessage,
|
||||||
ReceivePayloadType,
|
ReceivePayloadType,
|
||||||
)
|
)
|
||||||
from .util import mqtt_config_entry_enabled
|
from .util import get_mqtt_data, mqtt_config_entry_enabled
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
# Only import for paho-mqtt type checking here, imports are done locally
|
# Only import for paho-mqtt type checking here, imports are done locally
|
||||||
# because integrations should be able to optionally rely on MQTT.
|
# because integrations should be able to optionally rely on MQTT.
|
||||||
import paho.mqtt.client as mqtt
|
import paho.mqtt.client as mqtt
|
||||||
|
|
||||||
from .mixins import MqttData
|
|
||||||
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -100,11 +97,7 @@ async def async_publish(
|
||||||
encoding: str | None = DEFAULT_ENCODING,
|
encoding: str | None = DEFAULT_ENCODING,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Publish message to a MQTT topic."""
|
"""Publish message to a MQTT topic."""
|
||||||
# Local import to avoid circular dependencies
|
mqtt_data = get_mqtt_data(hass, True)
|
||||||
# pylint: disable-next=import-outside-toplevel
|
|
||||||
from .mixins import MqttData
|
|
||||||
|
|
||||||
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
|
|
||||||
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
|
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
|
||||||
raise HomeAssistantError(
|
raise HomeAssistantError(
|
||||||
f"Cannot publish to topic '{topic}', MQTT is not enabled"
|
f"Cannot publish to topic '{topic}', MQTT is not enabled"
|
||||||
|
@ -190,11 +183,7 @@ async def async_subscribe(
|
||||||
|
|
||||||
Call the return value to unsubscribe.
|
Call the return value to unsubscribe.
|
||||||
"""
|
"""
|
||||||
# Local import to avoid circular dependencies
|
mqtt_data = get_mqtt_data(hass, True)
|
||||||
# pylint: disable-next=import-outside-toplevel
|
|
||||||
from .mixins import MqttData
|
|
||||||
|
|
||||||
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
|
|
||||||
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
|
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
|
||||||
raise HomeAssistantError(
|
raise HomeAssistantError(
|
||||||
f"Cannot subscribe to topic '{topic}', MQTT is not enabled"
|
f"Cannot subscribe to topic '{topic}', MQTT is not enabled"
|
||||||
|
@ -332,7 +321,7 @@ class MQTT:
|
||||||
# should be able to optionally rely on MQTT.
|
# should be able to optionally rely on MQTT.
|
||||||
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
self._mqtt_data: MqttData = hass.data[DATA_MQTT]
|
self._mqtt_data = get_mqtt_data(hass)
|
||||||
|
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.config_entry = config_entry
|
self.config_entry = config_entry
|
||||||
|
|
|
@ -30,14 +30,12 @@ from .const import (
|
||||||
CONF_BIRTH_MESSAGE,
|
CONF_BIRTH_MESSAGE,
|
||||||
CONF_BROKER,
|
CONF_BROKER,
|
||||||
CONF_WILL_MESSAGE,
|
CONF_WILL_MESSAGE,
|
||||||
DATA_MQTT,
|
|
||||||
DEFAULT_BIRTH,
|
DEFAULT_BIRTH,
|
||||||
DEFAULT_DISCOVERY,
|
DEFAULT_DISCOVERY,
|
||||||
DEFAULT_WILL,
|
DEFAULT_WILL,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
)
|
)
|
||||||
from .mixins import MqttData
|
from .util import MQTT_WILL_BIRTH_SCHEMA, get_mqtt_data
|
||||||
from .util import MQTT_WILL_BIRTH_SCHEMA
|
|
||||||
|
|
||||||
MQTT_TIMEOUT = 5
|
MQTT_TIMEOUT = 5
|
||||||
|
|
||||||
|
@ -165,7 +163,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
) -> FlowResult:
|
) -> FlowResult:
|
||||||
"""Manage the MQTT broker configuration."""
|
"""Manage the MQTT broker configuration."""
|
||||||
mqtt_data: MqttData = self.hass.data.setdefault(DATA_MQTT, MqttData())
|
mqtt_data = get_mqtt_data(self.hass, True)
|
||||||
errors = {}
|
errors = {}
|
||||||
current_config = self.config_entry.data
|
current_config = self.config_entry.data
|
||||||
yaml_config = mqtt_data.config or {}
|
yaml_config = mqtt_data.config or {}
|
||||||
|
@ -216,7 +214,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
) -> FlowResult:
|
) -> FlowResult:
|
||||||
"""Manage the MQTT options."""
|
"""Manage the MQTT options."""
|
||||||
mqtt_data: MqttData = self.hass.data.setdefault(DATA_MQTT, MqttData())
|
mqtt_data = get_mqtt_data(self.hass, True)
|
||||||
errors = {}
|
errors = {}
|
||||||
current_config = self.config_entry.data
|
current_config = self.config_entry.data
|
||||||
yaml_config = mqtt_data.config or {}
|
yaml_config = mqtt_data.config or {}
|
||||||
|
@ -351,7 +349,7 @@ def try_connection(
|
||||||
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
# Get the config from configuration.yaml
|
# Get the config from configuration.yaml
|
||||||
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
|
mqtt_data = get_mqtt_data(hass, True)
|
||||||
yaml_config = mqtt_data.config or {}
|
yaml_config = mqtt_data.config or {}
|
||||||
entry_config = {
|
entry_config = {
|
||||||
CONF_BROKER: broker,
|
CONF_BROKER: broker,
|
||||||
|
|
|
@ -33,17 +33,16 @@ from .const import (
|
||||||
CONF_PAYLOAD,
|
CONF_PAYLOAD,
|
||||||
CONF_QOS,
|
CONF_QOS,
|
||||||
CONF_TOPIC,
|
CONF_TOPIC,
|
||||||
DATA_MQTT,
|
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
)
|
)
|
||||||
from .discovery import MQTT_DISCOVERY_DONE
|
from .discovery import MQTT_DISCOVERY_DONE
|
||||||
from .mixins import (
|
from .mixins import (
|
||||||
MQTT_ENTITY_DEVICE_INFO_SCHEMA,
|
MQTT_ENTITY_DEVICE_INFO_SCHEMA,
|
||||||
MqttData,
|
|
||||||
MqttDiscoveryDeviceUpdate,
|
MqttDiscoveryDeviceUpdate,
|
||||||
send_discovery_done,
|
send_discovery_done,
|
||||||
update_device,
|
update_device,
|
||||||
)
|
)
|
||||||
|
from .util import get_mqtt_data
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -203,7 +202,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
|
||||||
self.device_id = device_id
|
self.device_id = device_id
|
||||||
self.discovery_data = discovery_data
|
self.discovery_data = discovery_data
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self._mqtt_data: MqttData = hass.data[DATA_MQTT]
|
self._mqtt_data = get_mqtt_data(hass)
|
||||||
|
|
||||||
MqttDiscoveryDeviceUpdate.__init__(
|
MqttDiscoveryDeviceUpdate.__init__(
|
||||||
self,
|
self,
|
||||||
|
@ -281,7 +280,7 @@ async def async_setup_trigger(
|
||||||
|
|
||||||
async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None:
|
async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None:
|
||||||
"""Handle Mqtt removed from a device."""
|
"""Handle Mqtt removed from a device."""
|
||||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
mqtt_data = get_mqtt_data(hass)
|
||||||
triggers = await async_get_triggers(hass, device_id)
|
triggers = await async_get_triggers(hass, device_id)
|
||||||
for trig in triggers:
|
for trig in triggers:
|
||||||
device_trigger: Trigger = mqtt_data.device_triggers.pop(trig[CONF_DISCOVERY_ID])
|
device_trigger: Trigger = mqtt_data.device_triggers.pop(trig[CONF_DISCOVERY_ID])
|
||||||
|
@ -296,7 +295,7 @@ async def async_get_triggers(
|
||||||
hass: HomeAssistant, device_id: str
|
hass: HomeAssistant, device_id: str
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
"""List device triggers for MQTT devices."""
|
"""List device triggers for MQTT devices."""
|
||||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
mqtt_data = get_mqtt_data(hass)
|
||||||
triggers: list[dict[str, str]] = []
|
triggers: list[dict[str, str]] = []
|
||||||
|
|
||||||
if not mqtt_data.device_triggers:
|
if not mqtt_data.device_triggers:
|
||||||
|
@ -325,7 +324,7 @@ async def async_attach_trigger(
|
||||||
trigger_info: TriggerInfo,
|
trigger_info: TriggerInfo,
|
||||||
) -> CALLBACK_TYPE:
|
) -> CALLBACK_TYPE:
|
||||||
"""Attach a trigger."""
|
"""Attach a trigger."""
|
||||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
mqtt_data = get_mqtt_data(hass)
|
||||||
device_id = config[CONF_DEVICE_ID]
|
device_id = config[CONF_DEVICE_ID]
|
||||||
discovery_id = config[CONF_DISCOVERY_ID]
|
discovery_id = config[CONF_DISCOVERY_ID]
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,8 @@ from homeassistant.core import HomeAssistant, callback, split_entity_id
|
||||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||||
from homeassistant.helpers.device_registry import DeviceEntry
|
from homeassistant.helpers.device_registry import DeviceEntry
|
||||||
|
|
||||||
from . import DATA_MQTT, MQTT, debug_info, is_connected
|
from . import debug_info, is_connected
|
||||||
|
from .util import get_mqtt_data
|
||||||
|
|
||||||
REDACT_CONFIG = {CONF_PASSWORD, CONF_USERNAME}
|
REDACT_CONFIG = {CONF_PASSWORD, CONF_USERNAME}
|
||||||
REDACT_STATE_DEVICE_TRACKER = {ATTR_LATITUDE, ATTR_LONGITUDE}
|
REDACT_STATE_DEVICE_TRACKER = {ATTR_LATITUDE, ATTR_LONGITUDE}
|
||||||
|
@ -43,7 +44,8 @@ def _async_get_diagnostics(
|
||||||
device: DeviceEntry | None = None,
|
device: DeviceEntry | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Return diagnostics for a config entry."""
|
"""Return diagnostics for a config entry."""
|
||||||
mqtt_instance: MQTT = hass.data[DATA_MQTT].client
|
mqtt_instance = get_mqtt_data(hass).client
|
||||||
|
assert mqtt_instance is not None
|
||||||
|
|
||||||
redacted_config = async_redact_data(mqtt_instance.conf, REDACT_CONFIG)
|
redacted_config = async_redact_data(mqtt_instance.conf, REDACT_CONFIG)
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,6 @@ import functools
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
|
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
@ -29,12 +28,9 @@ from .const import (
|
||||||
ATTR_DISCOVERY_TOPIC,
|
ATTR_DISCOVERY_TOPIC,
|
||||||
CONF_AVAILABILITY,
|
CONF_AVAILABILITY,
|
||||||
CONF_TOPIC,
|
CONF_TOPIC,
|
||||||
DATA_MQTT,
|
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
)
|
)
|
||||||
|
from .util import get_mqtt_data
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .mixins import MqttData
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -98,7 +94,7 @@ async def async_start( # noqa: C901
|
||||||
hass: HomeAssistant, discovery_topic, config_entry=None
|
hass: HomeAssistant, discovery_topic, config_entry=None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start MQTT Discovery."""
|
"""Start MQTT Discovery."""
|
||||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
mqtt_data = get_mqtt_data(hass)
|
||||||
mqtt_integrations = {}
|
mqtt_integrations = {}
|
||||||
|
|
||||||
async def async_discovery_message_received(msg):
|
async def async_discovery_message_received(msg):
|
||||||
|
|
|
@ -4,10 +4,9 @@ from __future__ import annotations
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable, Coroutine
|
from collections.abc import Callable, Coroutine
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Protocol, cast, final
|
from typing import Any, Protocol, cast, final
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -29,13 +28,7 @@ from homeassistant.const import (
|
||||||
CONF_UNIQUE_ID,
|
CONF_UNIQUE_ID,
|
||||||
CONF_VALUE_TEMPLATE,
|
CONF_VALUE_TEMPLATE,
|
||||||
)
|
)
|
||||||
from homeassistant.core import (
|
from homeassistant.core import Event, HomeAssistant, async_get_hass, callback
|
||||||
CALLBACK_TYPE,
|
|
||||||
Event,
|
|
||||||
HomeAssistant,
|
|
||||||
async_get_hass,
|
|
||||||
callback,
|
|
||||||
)
|
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
config_validation as cv,
|
config_validation as cv,
|
||||||
device_registry as dr,
|
device_registry as dr,
|
||||||
|
@ -60,7 +53,7 @@ from homeassistant.helpers.json import json_loads
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
|
||||||
from . import debug_info, subscription
|
from . import debug_info, subscription
|
||||||
from .client import MQTT, Subscription, async_publish
|
from .client import async_publish
|
||||||
from .const import (
|
from .const import (
|
||||||
ATTR_DISCOVERY_HASH,
|
ATTR_DISCOVERY_HASH,
|
||||||
ATTR_DISCOVERY_PAYLOAD,
|
ATTR_DISCOVERY_PAYLOAD,
|
||||||
|
@ -69,7 +62,6 @@ from .const import (
|
||||||
CONF_ENCODING,
|
CONF_ENCODING,
|
||||||
CONF_QOS,
|
CONF_QOS,
|
||||||
CONF_TOPIC,
|
CONF_TOPIC,
|
||||||
DATA_MQTT,
|
|
||||||
DEFAULT_ENCODING,
|
DEFAULT_ENCODING,
|
||||||
DEFAULT_PAYLOAD_AVAILABLE,
|
DEFAULT_PAYLOAD_AVAILABLE,
|
||||||
DEFAULT_PAYLOAD_NOT_AVAILABLE,
|
DEFAULT_PAYLOAD_NOT_AVAILABLE,
|
||||||
|
@ -91,10 +83,7 @@ from .subscription import (
|
||||||
async_subscribe_topics,
|
async_subscribe_topics,
|
||||||
async_unsubscribe_topics,
|
async_unsubscribe_topics,
|
||||||
)
|
)
|
||||||
from .util import mqtt_config_entry_enabled, valid_subscribe_topic
|
from .util import get_mqtt_data, mqtt_config_entry_enabled, valid_subscribe_topic
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .device_trigger import Trigger
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -272,27 +261,6 @@ def warn_for_legacy_schema(domain: str) -> Callable:
|
||||||
return validator
|
return validator
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MqttData:
|
|
||||||
"""Keep the MQTT entry data."""
|
|
||||||
|
|
||||||
client: MQTT | None = None
|
|
||||||
config: ConfigType | None = None
|
|
||||||
device_triggers: dict[str, Trigger] = field(default_factory=dict)
|
|
||||||
discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field(
|
|
||||||
default_factory=dict
|
|
||||||
)
|
|
||||||
last_discovery: float = 0.0
|
|
||||||
reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list)
|
|
||||||
reload_entry: bool = False
|
|
||||||
reload_handlers: dict[str, Callable[[], Coroutine[Any, Any, None]]] = field(
|
|
||||||
default_factory=dict
|
|
||||||
)
|
|
||||||
reload_needed: bool = False
|
|
||||||
subscriptions_to_restore: list[Subscription] = field(default_factory=list)
|
|
||||||
updated_config: ConfigType = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class SetupEntity(Protocol):
|
class SetupEntity(Protocol):
|
||||||
"""Protocol type for async_setup_entities."""
|
"""Protocol type for async_setup_entities."""
|
||||||
|
|
||||||
|
@ -313,8 +281,7 @@ async def async_get_platform_config_from_yaml(
|
||||||
config_yaml: ConfigType | None = None,
|
config_yaml: ConfigType | None = None,
|
||||||
) -> list[ConfigType]:
|
) -> list[ConfigType]:
|
||||||
"""Return a list of validated configurations for the domain."""
|
"""Return a list of validated configurations for the domain."""
|
||||||
|
mqtt_data = get_mqtt_data(hass)
|
||||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
|
||||||
if config_yaml is None:
|
if config_yaml is None:
|
||||||
config_yaml = mqtt_data.config
|
config_yaml = mqtt_data.config
|
||||||
if not config_yaml:
|
if not config_yaml:
|
||||||
|
@ -331,7 +298,7 @@ async def async_setup_entry_helper(
|
||||||
discovery_schema: vol.Schema,
|
discovery_schema: vol.Schema,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up entity, automation or tag creation dynamically through MQTT discovery."""
|
"""Set up entity, automation or tag creation dynamically through MQTT discovery."""
|
||||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
mqtt_data = get_mqtt_data(hass)
|
||||||
|
|
||||||
async def async_discover(discovery_payload):
|
async def async_discover(discovery_payload):
|
||||||
"""Discover and add an MQTT entity, automation or tag."""
|
"""Discover and add an MQTT entity, automation or tag."""
|
||||||
|
@ -363,7 +330,7 @@ async def async_setup_entry_helper(
|
||||||
|
|
||||||
async def _async_setup_entities() -> None:
|
async def _async_setup_entities() -> None:
|
||||||
"""Set up MQTT items from configuration.yaml."""
|
"""Set up MQTT items from configuration.yaml."""
|
||||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
mqtt_data = get_mqtt_data(hass)
|
||||||
if mqtt_data.updated_config:
|
if mqtt_data.updated_config:
|
||||||
# The platform has been reloaded
|
# The platform has been reloaded
|
||||||
config_yaml = mqtt_data.updated_config
|
config_yaml = mqtt_data.updated_config
|
||||||
|
@ -395,7 +362,7 @@ async def async_setup_platform_helper(
|
||||||
async_setup_entities: SetupEntity,
|
async_setup_entities: SetupEntity,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Help to set up the platform for manual configured MQTT entities."""
|
"""Help to set up the platform for manual configured MQTT entities."""
|
||||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
mqtt_data = get_mqtt_data(hass)
|
||||||
if mqtt_data.reload_entry:
|
if mqtt_data.reload_entry:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"MQTT integration is %s, skipping setup of manually configured MQTT items while unloading the config entry",
|
"MQTT integration is %s, skipping setup of manually configured MQTT items while unloading the config entry",
|
||||||
|
@ -621,7 +588,7 @@ class MqttAvailability(Entity):
|
||||||
@property
|
@property
|
||||||
def available(self) -> bool:
|
def available(self) -> bool:
|
||||||
"""Return if the device is available."""
|
"""Return if the device is available."""
|
||||||
mqtt_data: MqttData = self.hass.data[DATA_MQTT]
|
mqtt_data = get_mqtt_data(self.hass)
|
||||||
assert mqtt_data.client is not None
|
assert mqtt_data.client is not None
|
||||||
client = mqtt_data.client
|
client = mqtt_data.client
|
||||||
if not client.connected and not self.hass.is_stopping:
|
if not client.connected and not self.hass.is_stopping:
|
||||||
|
@ -844,7 +811,7 @@ class MqttDiscoveryUpdate(Entity):
|
||||||
self._removed_from_hass = False
|
self._removed_from_hass = False
|
||||||
if discovery_data is None:
|
if discovery_data is None:
|
||||||
return
|
return
|
||||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
mqtt_data = get_mqtt_data(hass)
|
||||||
self._registry_hooks = mqtt_data.discovery_registry_hooks
|
self._registry_hooks = mqtt_data.discovery_registry_hooks
|
||||||
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
|
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
|
||||||
if discovery_hash in self._registry_hooks:
|
if discovery_hash in self._registry_hooks:
|
||||||
|
|
|
@ -3,17 +3,22 @@ from __future__ import annotations
|
||||||
|
|
||||||
from ast import literal_eval
|
from ast import literal_eval
|
||||||
from collections.abc import Callable, Coroutine
|
from collections.abc import Callable, Coroutine
|
||||||
|
from dataclasses import dataclass, field
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
from typing import Any, Union
|
from typing import TYPE_CHECKING, Any, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME
|
from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
||||||
from homeassistant.helpers import template
|
from homeassistant.helpers import template
|
||||||
from homeassistant.helpers.entity import Entity
|
from homeassistant.helpers.entity import Entity
|
||||||
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
|
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
|
||||||
from homeassistant.helpers.typing import TemplateVarsType
|
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import MQTT, Subscription
|
||||||
|
from .device_trigger import Trigger
|
||||||
|
|
||||||
_SENTINEL = object()
|
_SENTINEL = object()
|
||||||
|
|
||||||
|
@ -174,3 +179,24 @@ class MqttValueTemplate:
|
||||||
return self._value_template.async_render_with_possible_json_value(
|
return self._value_template.async_render_with_possible_json_value(
|
||||||
payload, default, variables=values
|
payload, default, variables=values
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MqttData:
|
||||||
|
"""Keep the MQTT entry data."""
|
||||||
|
|
||||||
|
client: MQTT | None = None
|
||||||
|
config: ConfigType | None = None
|
||||||
|
device_triggers: dict[str, Trigger] = field(default_factory=dict)
|
||||||
|
discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field(
|
||||||
|
default_factory=dict
|
||||||
|
)
|
||||||
|
last_discovery: float = 0.0
|
||||||
|
reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list)
|
||||||
|
reload_entry: bool = False
|
||||||
|
reload_handlers: dict[str, Callable[[], Coroutine[Any, Any, None]]] = field(
|
||||||
|
default_factory=dict
|
||||||
|
)
|
||||||
|
reload_needed: bool = False
|
||||||
|
subscriptions_to_restore: list[Subscription] = field(default_factory=list)
|
||||||
|
updated_config: ConfigType = field(default_factory=dict)
|
||||||
|
|
|
@ -15,10 +15,12 @@ from .const import (
|
||||||
ATTR_QOS,
|
ATTR_QOS,
|
||||||
ATTR_RETAIN,
|
ATTR_RETAIN,
|
||||||
ATTR_TOPIC,
|
ATTR_TOPIC,
|
||||||
|
DATA_MQTT,
|
||||||
DEFAULT_QOS,
|
DEFAULT_QOS,
|
||||||
DEFAULT_RETAIN,
|
DEFAULT_RETAIN,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
)
|
)
|
||||||
|
from .models import MqttData
|
||||||
|
|
||||||
|
|
||||||
def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None:
|
def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None:
|
||||||
|
@ -111,3 +113,10 @@ MQTT_WILL_BIRTH_SCHEMA = vol.Schema(
|
||||||
},
|
},
|
||||||
required=True,
|
required=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_mqtt_data(hass: HomeAssistant, ensure_exists: bool = False) -> MqttData:
|
||||||
|
"""Return typed MqttData from hass.data[DATA_MQTT]."""
|
||||||
|
if ensure_exists:
|
||||||
|
return hass.data.setdefault(DATA_MQTT, MqttData())
|
||||||
|
return hass.data[DATA_MQTT]
|
||||||
|
|
Loading…
Reference in New Issue