Add MqttData helper to mqtt (#78825)

* Add MqttData helper to mqtt

* Adjust client for circular dependencies

* Move MqttData to models.py

* Move get_mqtt_data to util.py
pull/78848/head
epenet 2022-09-20 19:40:06 +02:00 committed by GitHub
parent 6b3c91bd6a
commit e58531f118
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 74 additions and 88 deletions

View File

@ -76,7 +76,6 @@ from .const import ( # noqa: F401
PLATFORMS,
RELOADABLE_PLATFORMS,
)
from .mixins import MqttData
from .models import ( # noqa: F401
MqttCommandTemplate,
MqttValueTemplate,
@ -86,6 +85,7 @@ from .models import ( # noqa: F401
)
from .util import (
_VALID_QOS_SCHEMA,
get_mqtt_data,
mqtt_config_entry_enabled,
valid_publish_topic,
valid_subscribe_topic,
@ -164,7 +164,7 @@ async def _async_setup_discovery(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Start the MQTT protocol service."""
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
mqtt_data = get_mqtt_data(hass, True)
conf: ConfigType | None = config.get(DOMAIN)
@ -249,7 +249,7 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -
Causes for this is config entry options changing.
"""
mqtt_data: MqttData = hass.data[DATA_MQTT]
mqtt_data = get_mqtt_data(hass)
assert (client := mqtt_data.client) is not None
if (conf := mqtt_data.config) is None:
@ -267,7 +267,7 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -
async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict | None:
"""Fetch fresh MQTT yaml config from the hass config when (re)loading the entry."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
mqtt_data = get_mqtt_data(hass)
if mqtt_data.reload_entry:
hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_data.config = CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))
@ -307,7 +307,7 @@ async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict |
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Load a config entry."""
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
mqtt_data = get_mqtt_data(hass, True)
# Merge basic configuration, and add missing defaults for basic options
if (conf := await async_fetch_config(hass, entry)) is None:
@ -593,7 +593,7 @@ def async_subscribe_connection_status(
def is_connected(hass: HomeAssistant) -> bool:
"""Return if MQTT client is connected."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
mqtt_data = get_mqtt_data(hass)
assert mqtt_data.client is not None
return mqtt_data.client.connected
@ -611,7 +611,7 @@ async def async_remove_config_entry_device(
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload MQTT dump and publish service when the config entry is unloaded."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
mqtt_data = get_mqtt_data(hass)
assert mqtt_data.client is not None
mqtt_client = mqtt_data.client

View File

@ -46,7 +46,6 @@ from .const import (
CONF_KEEPALIVE,
CONF_TLS_INSECURE,
CONF_WILL_MESSAGE,
DATA_MQTT,
DEFAULT_ENCODING,
DEFAULT_QOS,
MQTT_CONNECTED,
@ -61,15 +60,13 @@ from .models import (
ReceiveMessage,
ReceivePayloadType,
)
from .util import mqtt_config_entry_enabled
from .util import get_mqtt_data, mqtt_config_entry_enabled
if TYPE_CHECKING:
# Only import for paho-mqtt type checking here, imports are done locally
# because integrations should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt
from .mixins import MqttData
_LOGGER = logging.getLogger(__name__)
@ -100,11 +97,7 @@ async def async_publish(
encoding: str | None = DEFAULT_ENCODING,
) -> None:
"""Publish message to a MQTT topic."""
# Local import to avoid circular dependencies
# pylint: disable-next=import-outside-toplevel
from .mixins import MqttData
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
mqtt_data = get_mqtt_data(hass, True)
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot publish to topic '{topic}', MQTT is not enabled"
@ -190,11 +183,7 @@ async def async_subscribe(
Call the return value to unsubscribe.
"""
# Local import to avoid circular dependencies
# pylint: disable-next=import-outside-toplevel
from .mixins import MqttData
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
mqtt_data = get_mqtt_data(hass, True)
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled"
@ -332,7 +321,7 @@ class MQTT:
# should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
self._mqtt_data: MqttData = hass.data[DATA_MQTT]
self._mqtt_data = get_mqtt_data(hass)
self.hass = hass
self.config_entry = config_entry

View File

@ -30,14 +30,12 @@ from .const import (
CONF_BIRTH_MESSAGE,
CONF_BROKER,
CONF_WILL_MESSAGE,
DATA_MQTT,
DEFAULT_BIRTH,
DEFAULT_DISCOVERY,
DEFAULT_WILL,
DOMAIN,
)
from .mixins import MqttData
from .util import MQTT_WILL_BIRTH_SCHEMA
from .util import MQTT_WILL_BIRTH_SCHEMA, get_mqtt_data
MQTT_TIMEOUT = 5
@ -165,7 +163,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Manage the MQTT broker configuration."""
mqtt_data: MqttData = self.hass.data.setdefault(DATA_MQTT, MqttData())
mqtt_data = get_mqtt_data(self.hass, True)
errors = {}
current_config = self.config_entry.data
yaml_config = mqtt_data.config or {}
@ -216,7 +214,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Manage the MQTT options."""
mqtt_data: MqttData = self.hass.data.setdefault(DATA_MQTT, MqttData())
mqtt_data = get_mqtt_data(self.hass, True)
errors = {}
current_config = self.config_entry.data
yaml_config = mqtt_data.config or {}
@ -351,7 +349,7 @@ def try_connection(
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
# Get the config from configuration.yaml
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
mqtt_data = get_mqtt_data(hass, True)
yaml_config = mqtt_data.config or {}
entry_config = {
CONF_BROKER: broker,

View File

@ -33,17 +33,16 @@ from .const import (
CONF_PAYLOAD,
CONF_QOS,
CONF_TOPIC,
DATA_MQTT,
DOMAIN,
)
from .discovery import MQTT_DISCOVERY_DONE
from .mixins import (
MQTT_ENTITY_DEVICE_INFO_SCHEMA,
MqttData,
MqttDiscoveryDeviceUpdate,
send_discovery_done,
update_device,
)
from .util import get_mqtt_data
_LOGGER = logging.getLogger(__name__)
@ -203,7 +202,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
self.device_id = device_id
self.discovery_data = discovery_data
self.hass = hass
self._mqtt_data: MqttData = hass.data[DATA_MQTT]
self._mqtt_data = get_mqtt_data(hass)
MqttDiscoveryDeviceUpdate.__init__(
self,
@ -281,7 +280,7 @@ async def async_setup_trigger(
async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None:
"""Handle Mqtt removed from a device."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
mqtt_data = get_mqtt_data(hass)
triggers = await async_get_triggers(hass, device_id)
for trig in triggers:
device_trigger: Trigger = mqtt_data.device_triggers.pop(trig[CONF_DISCOVERY_ID])
@ -296,7 +295,7 @@ async def async_get_triggers(
hass: HomeAssistant, device_id: str
) -> list[dict[str, str]]:
"""List device triggers for MQTT devices."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
mqtt_data = get_mqtt_data(hass)
triggers: list[dict[str, str]] = []
if not mqtt_data.device_triggers:
@ -325,7 +324,7 @@ async def async_attach_trigger(
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
mqtt_data = get_mqtt_data(hass)
device_id = config[CONF_DEVICE_ID]
discovery_id = config[CONF_DISCOVERY_ID]

View File

@ -16,7 +16,8 @@ from homeassistant.core import HomeAssistant, callback, split_entity_id
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.device_registry import DeviceEntry
from . import DATA_MQTT, MQTT, debug_info, is_connected
from . import debug_info, is_connected
from .util import get_mqtt_data
REDACT_CONFIG = {CONF_PASSWORD, CONF_USERNAME}
REDACT_STATE_DEVICE_TRACKER = {ATTR_LATITUDE, ATTR_LONGITUDE}
@ -43,7 +44,8 @@ def _async_get_diagnostics(
device: DeviceEntry | None = None,
) -> dict[str, Any]:
"""Return diagnostics for a config entry."""
mqtt_instance: MQTT = hass.data[DATA_MQTT].client
mqtt_instance = get_mqtt_data(hass).client
assert mqtt_instance is not None
redacted_config = async_redact_data(mqtt_instance.conf, REDACT_CONFIG)

View File

@ -7,7 +7,6 @@ import functools
import logging
import re
import time
from typing import TYPE_CHECKING
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
from homeassistant.core import HomeAssistant
@ -29,12 +28,9 @@ from .const import (
ATTR_DISCOVERY_TOPIC,
CONF_AVAILABILITY,
CONF_TOPIC,
DATA_MQTT,
DOMAIN,
)
if TYPE_CHECKING:
from .mixins import MqttData
from .util import get_mqtt_data
_LOGGER = logging.getLogger(__name__)
@ -98,7 +94,7 @@ async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic, config_entry=None
) -> None:
"""Start MQTT Discovery."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
mqtt_data = get_mqtt_data(hass)
mqtt_integrations = {}
async def async_discovery_message_received(msg):

View File

@ -4,10 +4,9 @@ from __future__ import annotations
from abc import abstractmethod
import asyncio
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from functools import partial
import logging
from typing import TYPE_CHECKING, Any, Protocol, cast, final
from typing import Any, Protocol, cast, final
import voluptuous as vol
@ -29,13 +28,7 @@ from homeassistant.const import (
CONF_UNIQUE_ID,
CONF_VALUE_TEMPLATE,
)
from homeassistant.core import (
CALLBACK_TYPE,
Event,
HomeAssistant,
async_get_hass,
callback,
)
from homeassistant.core import Event, HomeAssistant, async_get_hass, callback
from homeassistant.helpers import (
config_validation as cv,
device_registry as dr,
@ -60,7 +53,7 @@ from homeassistant.helpers.json import json_loads
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import debug_info, subscription
from .client import MQTT, Subscription, async_publish
from .client import async_publish
from .const import (
ATTR_DISCOVERY_HASH,
ATTR_DISCOVERY_PAYLOAD,
@ -69,7 +62,6 @@ from .const import (
CONF_ENCODING,
CONF_QOS,
CONF_TOPIC,
DATA_MQTT,
DEFAULT_ENCODING,
DEFAULT_PAYLOAD_AVAILABLE,
DEFAULT_PAYLOAD_NOT_AVAILABLE,
@ -91,10 +83,7 @@ from .subscription import (
async_subscribe_topics,
async_unsubscribe_topics,
)
from .util import mqtt_config_entry_enabled, valid_subscribe_topic
if TYPE_CHECKING:
from .device_trigger import Trigger
from .util import get_mqtt_data, mqtt_config_entry_enabled, valid_subscribe_topic
_LOGGER = logging.getLogger(__name__)
@ -272,27 +261,6 @@ def warn_for_legacy_schema(domain: str) -> Callable:
return validator
@dataclass
class MqttData:
"""Keep the MQTT entry data."""
client: MQTT | None = None
config: ConfigType | None = None
device_triggers: dict[str, Trigger] = field(default_factory=dict)
discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field(
default_factory=dict
)
last_discovery: float = 0.0
reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list)
reload_entry: bool = False
reload_handlers: dict[str, Callable[[], Coroutine[Any, Any, None]]] = field(
default_factory=dict
)
reload_needed: bool = False
subscriptions_to_restore: list[Subscription] = field(default_factory=list)
updated_config: ConfigType = field(default_factory=dict)
class SetupEntity(Protocol):
"""Protocol type for async_setup_entities."""
@ -313,8 +281,7 @@ async def async_get_platform_config_from_yaml(
config_yaml: ConfigType | None = None,
) -> list[ConfigType]:
"""Return a list of validated configurations for the domain."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
mqtt_data = get_mqtt_data(hass)
if config_yaml is None:
config_yaml = mqtt_data.config
if not config_yaml:
@ -331,7 +298,7 @@ async def async_setup_entry_helper(
discovery_schema: vol.Schema,
) -> None:
"""Set up entity, automation or tag creation dynamically through MQTT discovery."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
mqtt_data = get_mqtt_data(hass)
async def async_discover(discovery_payload):
"""Discover and add an MQTT entity, automation or tag."""
@ -363,7 +330,7 @@ async def async_setup_entry_helper(
async def _async_setup_entities() -> None:
"""Set up MQTT items from configuration.yaml."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
mqtt_data = get_mqtt_data(hass)
if mqtt_data.updated_config:
# The platform has been reloaded
config_yaml = mqtt_data.updated_config
@ -395,7 +362,7 @@ async def async_setup_platform_helper(
async_setup_entities: SetupEntity,
) -> None:
"""Help to set up the platform for manual configured MQTT entities."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
mqtt_data = get_mqtt_data(hass)
if mqtt_data.reload_entry:
_LOGGER.debug(
"MQTT integration is %s, skipping setup of manually configured MQTT items while unloading the config entry",
@ -621,7 +588,7 @@ class MqttAvailability(Entity):
@property
def available(self) -> bool:
"""Return if the device is available."""
mqtt_data: MqttData = self.hass.data[DATA_MQTT]
mqtt_data = get_mqtt_data(self.hass)
assert mqtt_data.client is not None
client = mqtt_data.client
if not client.connected and not self.hass.is_stopping:
@ -844,7 +811,7 @@ class MqttDiscoveryUpdate(Entity):
self._removed_from_hass = False
if discovery_data is None:
return
mqtt_data: MqttData = hass.data[DATA_MQTT]
mqtt_data = get_mqtt_data(hass)
self._registry_hooks = mqtt_data.discovery_registry_hooks
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
if discovery_hash in self._registry_hooks:

View File

@ -3,17 +3,22 @@ from __future__ import annotations
from ast import literal_eval
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
import datetime as dt
from typing import Any, Union
from typing import TYPE_CHECKING, Any, Union
import attr
from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers import template
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
from homeassistant.helpers.typing import TemplateVarsType
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
if TYPE_CHECKING:
from .client import MQTT, Subscription
from .device_trigger import Trigger
_SENTINEL = object()
@ -174,3 +179,24 @@ class MqttValueTemplate:
return self._value_template.async_render_with_possible_json_value(
payload, default, variables=values
)
@dataclass
class MqttData:
"""Keep the MQTT entry data."""
client: MQTT | None = None
config: ConfigType | None = None
device_triggers: dict[str, Trigger] = field(default_factory=dict)
discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field(
default_factory=dict
)
last_discovery: float = 0.0
reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list)
reload_entry: bool = False
reload_handlers: dict[str, Callable[[], Coroutine[Any, Any, None]]] = field(
default_factory=dict
)
reload_needed: bool = False
subscriptions_to_restore: list[Subscription] = field(default_factory=list)
updated_config: ConfigType = field(default_factory=dict)

View File

@ -15,10 +15,12 @@ from .const import (
ATTR_QOS,
ATTR_RETAIN,
ATTR_TOPIC,
DATA_MQTT,
DEFAULT_QOS,
DEFAULT_RETAIN,
DOMAIN,
)
from .models import MqttData
def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None:
@ -111,3 +113,10 @@ MQTT_WILL_BIRTH_SCHEMA = vol.Schema(
},
required=True,
)
def get_mqtt_data(hass: HomeAssistant, ensure_exists: bool = False) -> MqttData:
"""Return typed MqttData from hass.data[DATA_MQTT]."""
if ensure_exists:
return hass.data.setdefault(DATA_MQTT, MqttData())
return hass.data[DATA_MQTT]