Add MQTT integration discovery (#41332)
* Add MQTT integration discovery * Add script/hassfest/mqtt.py * Unsubscribe if config entry exists * Add homeassistant/generated/mqtt.py * Fix bad loop * Improve tests * Improve tests * Apply suggestions from code review Co-authored-by: Fabian Affolter <mail@fabian-affolter.ch> * Prevent initiating multiple config flows Co-authored-by: Fabian Affolter <mail@fabian-affolter.ch>pull/41414/head
parent
3f263d5cbe
commit
343e5d64b8
|
@ -69,6 +69,7 @@ from .const import (
|
|||
DEFAULT_QOS,
|
||||
DEFAULT_RETAIN,
|
||||
DEFAULT_WILL,
|
||||
DOMAIN,
|
||||
MQTT_CONNECTED,
|
||||
MQTT_DISCONNECTED,
|
||||
PROTOCOL_311,
|
||||
|
@ -86,8 +87,6 @@ from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
|
|||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DOMAIN = "mqtt"
|
||||
|
||||
DATA_MQTT = "mqtt"
|
||||
|
||||
SERVICE_PUBLISH = "publish"
|
||||
|
|
|
@ -41,6 +41,8 @@ DEFAULT_WILL = {
|
|||
ATTR_RETAIN: DEFAULT_RETAIN,
|
||||
}
|
||||
|
||||
DOMAIN = "mqtt"
|
||||
|
||||
MQTT_CONNECTED = "mqtt_connected"
|
||||
MQTT_DISCONNECTED = "mqtt_disconnected"
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Support for MQTT discovery."""
|
||||
import asyncio
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
@ -9,9 +10,15 @@ from homeassistant.components import mqtt
|
|||
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
from homeassistant.loader import async_get_mqtt
|
||||
|
||||
from .abbreviations import ABBREVIATIONS, DEVICE_ABBREVIATIONS
|
||||
from .const import ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_PAYLOAD, ATTR_DISCOVERY_TOPIC
|
||||
from .const import (
|
||||
ATTR_DISCOVERY_HASH,
|
||||
ATTR_DISCOVERY_PAYLOAD,
|
||||
ATTR_DISCOVERY_TOPIC,
|
||||
DOMAIN,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -39,7 +46,9 @@ SUPPORTED_COMPONENTS = [
|
|||
ALREADY_DISCOVERED = "mqtt_discovered_components"
|
||||
CONFIG_ENTRY_IS_SETUP = "mqtt_config_entry_is_setup"
|
||||
DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock"
|
||||
DATA_CONFIG_FLOW_LOCK = "mqtt_discovery_config_flow_lock"
|
||||
DISCOVERY_UNSUBSCRIBE = "mqtt_discovery_unsubscribe"
|
||||
INTEGRATION_UNSUBSCRIBE = "mqtt_integration_discovery_unsubscribe"
|
||||
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
|
||||
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
|
||||
LAST_DISCOVERY = "mqtt_last_discovery"
|
||||
|
@ -65,8 +74,9 @@ async def async_start(
|
|||
hass: HomeAssistantType, discovery_topic, config_entry=None
|
||||
) -> bool:
|
||||
"""Start MQTT Discovery."""
|
||||
mqtt_integrations = {}
|
||||
|
||||
async def async_device_message_received(msg):
|
||||
async def async_entity_message_received(msg):
|
||||
"""Process the received message."""
|
||||
hass.data[LAST_DISCOVERY] = time.time()
|
||||
payload = msg.payload
|
||||
|
@ -172,12 +182,52 @@ async def async_start(
|
|||
)
|
||||
|
||||
hass.data[DATA_CONFIG_ENTRY_LOCK] = asyncio.Lock()
|
||||
hass.data[DATA_CONFIG_FLOW_LOCK] = asyncio.Lock()
|
||||
hass.data[CONFIG_ENTRY_IS_SETUP] = set()
|
||||
|
||||
hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe(
|
||||
hass, f"{discovery_topic}/#", async_device_message_received, 0
|
||||
hass, f"{discovery_topic}/#", async_entity_message_received, 0
|
||||
)
|
||||
hass.data[LAST_DISCOVERY] = time.time()
|
||||
mqtt_integrations = await async_get_mqtt(hass)
|
||||
|
||||
hass.data[INTEGRATION_UNSUBSCRIBE] = {}
|
||||
|
||||
for (integration, topics) in mqtt_integrations.items():
|
||||
|
||||
async def async_integration_message_received(integration, msg):
|
||||
"""Process the received message."""
|
||||
key = f"{integration}_{msg.subscribed_topic}"
|
||||
|
||||
# Lock to prevent initiating many parallel config flows.
|
||||
# Note: The lock is not intended to prevent a race, only for performance
|
||||
async with hass.data[DATA_CONFIG_FLOW_LOCK]:
|
||||
# Already unsubscribed
|
||||
if key not in hass.data[INTEGRATION_UNSUBSCRIBE]:
|
||||
return
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
integration, context={"source": DOMAIN}, data=msg
|
||||
)
|
||||
if (
|
||||
result
|
||||
and result["type"] == "abort"
|
||||
and result["reason"]
|
||||
in ["already_configured", "single_instance_allowed"]
|
||||
):
|
||||
unsub = hass.data[INTEGRATION_UNSUBSCRIBE].pop(key, None)
|
||||
if unsub is None:
|
||||
return
|
||||
unsub()
|
||||
|
||||
for topic in topics:
|
||||
key = f"{integration}_{topic}"
|
||||
hass.data[INTEGRATION_UNSUBSCRIBE][key] = await mqtt.async_subscribe(
|
||||
hass,
|
||||
topic,
|
||||
functools.partial(async_integration_message_received, integration),
|
||||
0,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
@ -187,3 +237,7 @@ async def async_stop(hass: HomeAssistantType) -> bool:
|
|||
if DISCOVERY_UNSUBSCRIBE in hass.data and hass.data[DISCOVERY_UNSUBSCRIBE]:
|
||||
hass.data[DISCOVERY_UNSUBSCRIBE]()
|
||||
hass.data[DISCOVERY_UNSUBSCRIBE] = None
|
||||
if INTEGRATION_UNSUBSCRIBE in hass.data:
|
||||
for key, unsub in list(hass.data[INTEGRATION_UNSUBSCRIBE].items()):
|
||||
unsub()
|
||||
hass.data[INTEGRATION_UNSUBSCRIBE].pop(key)
|
||||
|
|
|
@ -21,36 +21,72 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
VERSION = 1
|
||||
CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_PUSH
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize flow."""
|
||||
self._prefix = DEFAULT_PREFIX
|
||||
|
||||
async def async_step_mqtt(self, discovery_info=None):
|
||||
"""Handle a flow initialized by MQTT discovery."""
|
||||
if self._async_in_progress() or self._async_current_entries():
|
||||
return self.async_abort(reason="single_instance_allowed")
|
||||
|
||||
await self.async_set_unique_id(DOMAIN)
|
||||
|
||||
# Validate the topic, will throw if it fails
|
||||
prefix = discovery_info.subscribed_topic
|
||||
if prefix.endswith("/#"):
|
||||
prefix = prefix[:-2]
|
||||
try:
|
||||
valid_subscribe_topic(f"{prefix}/#")
|
||||
except vol.Invalid:
|
||||
return self.async_abort(reason="invalid_discovery_info")
|
||||
|
||||
self._prefix = prefix
|
||||
|
||||
return await self.async_step_confirm()
|
||||
|
||||
async def async_step_user(self, user_input=None):
|
||||
"""Handle a flow initialized by the user."""
|
||||
if self._async_current_entries():
|
||||
return self.async_abort(reason="single_instance_allowed")
|
||||
|
||||
return await self.async_step_config()
|
||||
if self.show_advanced_options:
|
||||
return await self.async_step_config()
|
||||
return await self.async_step_confirm()
|
||||
|
||||
async def async_step_config(self, user_input=None):
|
||||
"""Confirm the setup."""
|
||||
errors = {}
|
||||
data = {CONF_DISCOVERY_PREFIX: DEFAULT_PREFIX}
|
||||
data = {CONF_DISCOVERY_PREFIX: self._prefix}
|
||||
|
||||
if user_input is not None:
|
||||
bad_prefix = False
|
||||
if self.show_advanced_options:
|
||||
prefix = user_input[CONF_DISCOVERY_PREFIX]
|
||||
try:
|
||||
valid_subscribe_topic(f"{prefix}/#")
|
||||
except vol.Invalid:
|
||||
errors["base"] = "invalid_discovery_topic"
|
||||
bad_prefix = True
|
||||
else:
|
||||
data = user_input
|
||||
prefix = user_input[CONF_DISCOVERY_PREFIX]
|
||||
if prefix.endswith("/#"):
|
||||
prefix = prefix[:-2]
|
||||
try:
|
||||
valid_subscribe_topic(f"{prefix}/#")
|
||||
except vol.Invalid:
|
||||
errors["base"] = "invalid_discovery_topic"
|
||||
bad_prefix = True
|
||||
else:
|
||||
data[CONF_DISCOVERY_PREFIX] = prefix
|
||||
if not bad_prefix:
|
||||
return self.async_create_entry(title="Tasmota", data=data)
|
||||
|
||||
fields = {}
|
||||
if self.show_advanced_options:
|
||||
fields[vol.Optional(CONF_DISCOVERY_PREFIX, default=DEFAULT_PREFIX)] = str
|
||||
fields[vol.Optional(CONF_DISCOVERY_PREFIX, default=self._prefix)] = str
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="config", data_schema=vol.Schema(fields), errors=errors
|
||||
)
|
||||
|
||||
async def async_step_confirm(self, user_input=None):
|
||||
"""Confirm the setup."""
|
||||
|
||||
data = {CONF_DISCOVERY_PREFIX: self._prefix}
|
||||
|
||||
if user_input is not None:
|
||||
return self.async_create_entry(title="Tasmota", data=data)
|
||||
|
||||
return self.async_show_form(step_id="confirm")
|
||||
|
|
|
@ -5,5 +5,6 @@
|
|||
"documentation": "https://www.home-assistant.io/integrations/tasmota",
|
||||
"requirements": ["hatasmota==0.0.10"],
|
||||
"dependencies": ["mqtt"],
|
||||
"mqtt": ["tasmota/discovery/#"],
|
||||
"codeowners": ["@emontnemery"]
|
||||
}
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
{
|
||||
"config": {
|
||||
"step": {
|
||||
"confirm": {
|
||||
"description": "Do you want to set up Tasmota?"
|
||||
},
|
||||
"config": {
|
||||
"title": "Tasmota",
|
||||
"description": "Please enter the Tasmota configuration.",
|
||||
|
|
|
@ -25,6 +25,7 @@ SOURCE_HASSIO = "hassio"
|
|||
SOURCE_HOMEKIT = "homekit"
|
||||
SOURCE_IMPORT = "import"
|
||||
SOURCE_INTEGRATION_DISCOVERY = "integration_discovery"
|
||||
SOURCE_MQTT = "mqtt"
|
||||
SOURCE_SSDP = "ssdp"
|
||||
SOURCE_USER = "user"
|
||||
SOURCE_ZEROCONF = "zeroconf"
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
"""Automatically generated by hassfest.
|
||||
|
||||
To update, run python3 -m script.hassfest
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
|
||||
MQTT = {
|
||||
"tasmota": [
|
||||
"tasmota/discovery/#"
|
||||
]
|
||||
}
|
|
@ -80,6 +80,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow):
|
|||
|
||||
async_step_zeroconf = async_step_discovery
|
||||
async_step_ssdp = async_step_discovery
|
||||
async_step_mqtt = async_step_discovery
|
||||
async_step_homekit = async_step_discovery
|
||||
|
||||
async def async_step_import(self, _: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
|
|
|
@ -302,6 +302,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
|
|||
return await self.async_step_pick_implementation()
|
||||
|
||||
async_step_user = async_step_pick_implementation
|
||||
async_step_mqtt = async_step_discovery
|
||||
async_step_ssdp = async_step_discovery
|
||||
async_step_zeroconf = async_step_discovery
|
||||
async_step_homekit = async_step_discovery
|
||||
|
|
|
@ -25,6 +25,7 @@ from typing import (
|
|||
cast,
|
||||
)
|
||||
|
||||
from homeassistant.generated.mqtt import MQTT
|
||||
from homeassistant.generated.ssdp import SSDP
|
||||
from homeassistant.generated.zeroconf import HOMEKIT, ZEROCONF
|
||||
|
||||
|
@ -202,6 +203,21 @@ async def async_get_ssdp(hass: "HomeAssistant") -> Dict[str, List]:
|
|||
return ssdp
|
||||
|
||||
|
||||
async def async_get_mqtt(hass: "HomeAssistant") -> Dict[str, List]:
|
||||
"""Return cached list of MQTT mappings."""
|
||||
|
||||
mqtt: Dict[str, List] = MQTT.copy()
|
||||
|
||||
integrations = await async_get_custom_components(hass)
|
||||
for integration in integrations.values():
|
||||
if not integration.mqtt:
|
||||
continue
|
||||
|
||||
mqtt[integration.domain] = integration.mqtt
|
||||
|
||||
return mqtt
|
||||
|
||||
|
||||
class Integration:
|
||||
"""An integration in Home Assistant."""
|
||||
|
||||
|
@ -323,6 +339,11 @@ class Integration:
|
|||
"""Return Integration Quality Scale."""
|
||||
return cast(str, self.manifest.get("quality_scale"))
|
||||
|
||||
@property
|
||||
def mqtt(self) -> Optional[list]:
|
||||
"""Return Integration MQTT entries."""
|
||||
return cast(List[dict], self.manifest.get("mqtt"))
|
||||
|
||||
@property
|
||||
def ssdp(self) -> Optional[list]:
|
||||
"""Return Integration SSDP entries."""
|
||||
|
|
|
@ -15,6 +15,7 @@ DATA_INTEGRATIONS_WITH_REQS = "integrations_with_reqs"
|
|||
CONSTRAINT_FILE = "package_constraints.txt"
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
DISCOVERY_INTEGRATIONS: Dict[str, Iterable[str]] = {
|
||||
"mqtt": ("mqtt",),
|
||||
"ssdp": ("ssdp",),
|
||||
"zeroconf": ("zeroconf", "homekit"),
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ from . import (
|
|||
dependencies,
|
||||
json,
|
||||
manifest,
|
||||
mqtt,
|
||||
requirements,
|
||||
services,
|
||||
ssdp,
|
||||
|
@ -25,6 +26,7 @@ INTEGRATION_PLUGINS = [
|
|||
config_flow,
|
||||
dependencies,
|
||||
manifest,
|
||||
mqtt,
|
||||
services,
|
||||
ssdp,
|
||||
translations,
|
||||
|
|
|
@ -33,6 +33,11 @@ def validate_integration(config: Config, integration: Integration):
|
|||
"config_flow",
|
||||
"HomeKit information in a manifest requires a config flow to exist",
|
||||
)
|
||||
if integration.manifest.get("mqtt"):
|
||||
integration.add_error(
|
||||
"config_flow",
|
||||
"MQTT information in a manifest requires a config flow to exist",
|
||||
)
|
||||
if integration.manifest.get("ssdp"):
|
||||
integration.add_error(
|
||||
"config_flow",
|
||||
|
@ -51,6 +56,7 @@ def validate_integration(config: Config, integration: Integration):
|
|||
"async_step_discovery" in config_flow
|
||||
or "async_step_hassio" in config_flow
|
||||
or "async_step_homekit" in config_flow
|
||||
or "async_step_mqtt" in config_flow
|
||||
or "async_step_ssdp" in config_flow
|
||||
or "async_step_zeroconf" in config_flow
|
||||
)
|
||||
|
@ -91,6 +97,7 @@ def generate_and_validate(integrations: Dict[str, Integration], config: Config):
|
|||
if not (
|
||||
integration.manifest.get("config_flow")
|
||||
or integration.manifest.get("homekit")
|
||||
or integration.manifest.get("mqtt")
|
||||
or integration.manifest.get("ssdp")
|
||||
or integration.manifest.get("zeroconf")
|
||||
):
|
||||
|
|
|
@ -38,6 +38,7 @@ MANIFEST_SCHEMA = vol.Schema(
|
|||
vol.Required("domain"): str,
|
||||
vol.Required("name"): str,
|
||||
vol.Optional("config_flow"): bool,
|
||||
vol.Optional("mqtt"): [str],
|
||||
vol.Optional("zeroconf"): [
|
||||
vol.Any(
|
||||
str,
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
"""Generate MQTT file."""
|
||||
from collections import defaultdict
|
||||
import json
|
||||
from typing import Dict
|
||||
|
||||
from .model import Config, Integration
|
||||
|
||||
BASE = """
|
||||
\"\"\"Automatically generated by hassfest.
|
||||
|
||||
To update, run python3 -m script.hassfest
|
||||
\"\"\"
|
||||
|
||||
# fmt: off
|
||||
|
||||
MQTT = {}
|
||||
""".strip()
|
||||
|
||||
|
||||
def generate_and_validate(integrations: Dict[str, Integration]):
|
||||
"""Validate and generate MQTT data."""
|
||||
|
||||
data = defaultdict(list)
|
||||
|
||||
for domain in sorted(integrations):
|
||||
integration = integrations[domain]
|
||||
|
||||
if not integration.manifest:
|
||||
continue
|
||||
|
||||
mqtt = integration.manifest.get("mqtt")
|
||||
|
||||
if not mqtt:
|
||||
continue
|
||||
|
||||
for topic in mqtt:
|
||||
data[domain].append(topic)
|
||||
|
||||
return BASE.format(json.dumps(data, indent=4))
|
||||
|
||||
|
||||
def validate(integrations: Dict[str, Integration], config: Config):
|
||||
"""Validate MQTT file."""
|
||||
mqtt_path = config.root / "homeassistant/generated/mqtt.py"
|
||||
config.cache["mqtt"] = content = generate_and_validate(integrations)
|
||||
|
||||
if config.specific_integrations:
|
||||
return
|
||||
|
||||
with open(str(mqtt_path)) as fp:
|
||||
if fp.read().strip() != content:
|
||||
config.add_error(
|
||||
"mqtt",
|
||||
"File mqtt.py is not up to date. Run python3 -m script.hassfest",
|
||||
fixable=True,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def generate(integrations: Dict[str, Integration], config: Config):
|
||||
"""Generate MQTT file."""
|
||||
mqtt_path = config.root / "homeassistant/generated/mqtt.py"
|
||||
with open(str(mqtt_path), "w") as fp:
|
||||
fp.write(f"{config.cache['mqtt']}\n")
|
|
@ -4,6 +4,7 @@ import re
|
|||
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components import mqtt
|
||||
from homeassistant.components.mqtt.abbreviations import (
|
||||
ABBREVIATIONS,
|
||||
|
@ -13,7 +14,12 @@ from homeassistant.components.mqtt.discovery import ALREADY_DISCOVERED, async_st
|
|||
from homeassistant.const import STATE_OFF, STATE_ON
|
||||
|
||||
from tests.async_mock import AsyncMock, patch
|
||||
from tests.common import async_fire_mqtt_message, mock_device_registry, mock_registry
|
||||
from tests.common import (
|
||||
async_fire_mqtt_message,
|
||||
mock_device_registry,
|
||||
mock_entity_platform,
|
||||
mock_registry,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -436,3 +442,75 @@ async def test_complex_discovery_topic_prefix(hass, mqtt_mock, caplog):
|
|||
assert state is not None
|
||||
assert state.name == "Beer"
|
||||
assert ("binary_sensor", "node1 object1") in hass.data[ALREADY_DISCOVERED]
|
||||
|
||||
|
||||
async def test_mqtt_integration_discovery_subscribe_unsubscribe(
|
||||
hass, mqtt_client_mock, mqtt_mock
|
||||
):
|
||||
"""Check MQTT integration discovery subscribe and unsubscribe."""
|
||||
mock_entity_platform(hass, "config_flow.comp", None)
|
||||
|
||||
entry = hass.config_entries.async_entries("mqtt")[0]
|
||||
mqtt_mock().connected = True
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.mqtt.discovery.async_get_mqtt",
|
||||
return_value={"comp": ["comp/discovery/#"]},
|
||||
):
|
||||
await async_start(hass, "homeassistant", entry)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0)
|
||||
assert not mqtt_client_mock.unsubscribe.called
|
||||
|
||||
class TestFlow(config_entries.ConfigFlow):
|
||||
"""Test flow."""
|
||||
|
||||
async def async_step_mqtt(self, discovery_info):
|
||||
"""Test mqtt step."""
|
||||
return self.async_abort(reason="already_configured")
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
||||
mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0)
|
||||
assert not mqtt_client_mock.unsubscribe.called
|
||||
|
||||
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
|
||||
await hass.async_block_till_done()
|
||||
mqtt_client_mock.unsubscribe.assert_called_once_with("comp/discovery/#")
|
||||
mqtt_client_mock.unsubscribe.reset_mock()
|
||||
|
||||
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
|
||||
await hass.async_block_till_done()
|
||||
assert not mqtt_client_mock.unsubscribe.called
|
||||
|
||||
|
||||
async def test_mqtt_discovery_unsubscribe_once(hass, mqtt_client_mock, mqtt_mock):
|
||||
"""Check MQTT integration discovery unsubscribe once."""
|
||||
mock_entity_platform(hass, "config_flow.comp", None)
|
||||
|
||||
entry = hass.config_entries.async_entries("mqtt")[0]
|
||||
mqtt_mock().connected = True
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.mqtt.discovery.async_get_mqtt",
|
||||
return_value={"comp": ["comp/discovery/#"]},
|
||||
):
|
||||
await async_start(hass, "homeassistant", entry)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0)
|
||||
assert not mqtt_client_mock.unsubscribe.called
|
||||
|
||||
class TestFlow(config_entries.ConfigFlow):
|
||||
"""Test flow."""
|
||||
|
||||
async def async_step_mqtt(self, discovery_info):
|
||||
"""Test mqtt step."""
|
||||
return self.async_abort(reason="already_configured")
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
||||
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
|
||||
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
mqtt_client_mock.unsubscribe.assert_called_once_with("comp/discovery/#")
|
||||
|
|
|
@ -1,8 +1,47 @@
|
|||
"""Test config flow."""
|
||||
from homeassistant.components.mqtt.models import Message
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
async def test_mqtt_abort_if_existing_entry(hass, mqtt_mock):
|
||||
"""Check MQTT flow aborts when an entry already exist."""
|
||||
MockConfigEntry(domain="tasmota").add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
"tasmota", context={"source": "mqtt"}
|
||||
)
|
||||
|
||||
assert result["type"] == "abort"
|
||||
assert result["reason"] == "single_instance_allowed"
|
||||
|
||||
|
||||
async def test_mqtt_abort_invalid_topic(hass, mqtt_mock):
|
||||
"""Check MQTT flow aborts if discovery topic is invalid."""
|
||||
discovery_info = Message("", "", 0, False, subscribed_topic="custom_prefix/##")
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
"tasmota", context={"source": "mqtt"}, data=discovery_info
|
||||
)
|
||||
assert result["type"] == "abort"
|
||||
assert result["reason"] == "invalid_discovery_info"
|
||||
|
||||
|
||||
async def test_mqtt_setup(hass, mqtt_mock) -> None:
|
||||
"""Test we can finish a config flow through MQTT with custom prefix."""
|
||||
discovery_info = Message("", "", 0, False, subscribed_topic="custom_prefix/123/#")
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
"tasmota", context={"source": "mqtt"}, data=discovery_info
|
||||
)
|
||||
assert result["type"] == "form"
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||
|
||||
assert result["type"] == "create_entry"
|
||||
assert result["result"].data == {
|
||||
"discovery_prefix": "custom_prefix/123",
|
||||
}
|
||||
|
||||
|
||||
async def test_user_setup(hass, mqtt_mock):
|
||||
"""Test we can finish a config flow."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
|
@ -35,15 +74,32 @@ async def test_user_setup_advanced(hass, mqtt_mock):
|
|||
}
|
||||
|
||||
|
||||
async def test_user_setup_invalid_topic_prefix(hass, mqtt_mock):
|
||||
"""Test if connection cannot be made."""
|
||||
async def test_user_setup_advanced_strip_wildcard(hass, mqtt_mock):
|
||||
"""Test we can finish a config flow."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
"tasmota", context={"source": "user", "show_advanced_options": True}
|
||||
)
|
||||
assert result["type"] == "form"
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {"discovery_prefix": "tasmota/config/#"}
|
||||
result["flow_id"], {"discovery_prefix": "test_tasmota/discovery/#"}
|
||||
)
|
||||
|
||||
assert result["type"] == "create_entry"
|
||||
assert result["result"].data == {
|
||||
"discovery_prefix": "test_tasmota/discovery",
|
||||
}
|
||||
|
||||
|
||||
async def test_user_setup_invalid_topic_prefix(hass, mqtt_mock):
|
||||
"""Test abort on invalid discovery topic."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
"tasmota", context={"source": "user", "show_advanced_options": True}
|
||||
)
|
||||
assert result["type"] == "form"
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {"discovery_prefix": "tasmota/config/##"}
|
||||
)
|
||||
|
||||
assert result["type"] == "form"
|
||||
|
|
|
@ -81,7 +81,7 @@ async def test_user_has_confirmation(hass, discovery_flow_conf):
|
|||
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
|
||||
|
||||
@pytest.mark.parametrize("source", ["discovery", "ssdp", "zeroconf"])
|
||||
@pytest.mark.parametrize("source", ["discovery", "mqtt", "ssdp", "zeroconf"])
|
||||
async def test_discovery_single_instance(hass, discovery_flow_conf, source):
|
||||
"""Test we not allow duplicates."""
|
||||
flow = config_entries.HANDLERS["test"]()
|
||||
|
@ -95,7 +95,7 @@ async def test_discovery_single_instance(hass, discovery_flow_conf, source):
|
|||
assert result["reason"] == "single_instance_allowed"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("source", ["discovery", "ssdp", "zeroconf"])
|
||||
@pytest.mark.parametrize("source", ["discovery", "mqtt", "ssdp", "zeroconf"])
|
||||
async def test_discovery_confirmation(hass, discovery_flow_conf, source):
|
||||
"""Test we ask for confirmation via discovery."""
|
||||
flow = config_entries.HANDLERS["test"]()
|
||||
|
|
|
@ -181,6 +181,7 @@ def test_integration_properties(hass):
|
|||
},
|
||||
{"manufacturer": "Signify", "modelName": "Philips hue bridge 2015"},
|
||||
],
|
||||
"mqtt": ["hue/discovery"],
|
||||
},
|
||||
)
|
||||
assert integration.name == "Philips Hue"
|
||||
|
@ -198,6 +199,7 @@ def test_integration_properties(hass):
|
|||
},
|
||||
{"manufacturer": "Signify", "modelName": "Philips hue bridge 2015"},
|
||||
]
|
||||
assert integration.mqtt == ["hue/discovery"]
|
||||
assert integration.dependencies == ["test-dep"]
|
||||
assert integration.requirements == ["test-req==1.0.0"]
|
||||
assert integration.is_built_in is True
|
||||
|
@ -217,6 +219,7 @@ def test_integration_properties(hass):
|
|||
assert integration.homekit is None
|
||||
assert integration.zeroconf is None
|
||||
assert integration.ssdp is None
|
||||
assert integration.mqtt is None
|
||||
|
||||
integration = loader.Integration(
|
||||
hass,
|
||||
|
@ -266,6 +269,7 @@ def _get_test_integration(hass, name, config_flow):
|
|||
"zeroconf": [f"_{name}._tcp.local."],
|
||||
"homekit": {"models": [name]},
|
||||
"ssdp": [{"manufacturer": name, "modelName": name}],
|
||||
"mqtt": [f"{name}/discovery"],
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -371,6 +375,21 @@ async def test_get_ssdp(hass):
|
|||
assert ssdp["test_2"] == [{"manufacturer": "test_2", "modelName": "test_2"}]
|
||||
|
||||
|
||||
async def test_get_mqtt(hass):
|
||||
"""Verify that custom components with MQTT are found."""
|
||||
test_1_integration = _get_test_integration(hass, "test_1", True)
|
||||
test_2_integration = _get_test_integration(hass, "test_2", True)
|
||||
|
||||
with patch("homeassistant.loader.async_get_custom_components") as mock_get:
|
||||
mock_get.return_value = {
|
||||
"test_1": test_1_integration,
|
||||
"test_2": test_2_integration,
|
||||
}
|
||||
mqtt = await loader.async_get_mqtt(hass)
|
||||
assert mqtt["test_1"] == ["test_1/discovery"]
|
||||
assert mqtt["test_2"] == ["test_2/discovery"]
|
||||
|
||||
|
||||
async def test_get_custom_components_safe_mode(hass):
|
||||
"""Test that we get empty custom components in safe mode."""
|
||||
hass.config.safe_mode = True
|
||||
|
|
|
@ -187,6 +187,23 @@ async def test_install_on_docker(hass):
|
|||
)
|
||||
|
||||
|
||||
async def test_discovery_requirements_mqtt(hass):
|
||||
"""Test that we load discovery requirements."""
|
||||
hass.config.skip_pip = False
|
||||
mqtt = await loader.async_get_integration(hass, "mqtt")
|
||||
|
||||
mock_integration(
|
||||
hass, MockModule("mqtt_comp", partial_manifest={"mqtt": ["foo/discovery"]})
|
||||
)
|
||||
with patch(
|
||||
"homeassistant.requirements.async_process_requirements",
|
||||
) as mock_process:
|
||||
await async_get_integration_with_requirements(hass, "mqtt_comp")
|
||||
|
||||
assert len(mock_process.mock_calls) == 2 # mqtt also depends on http
|
||||
assert mock_process.mock_calls[0][1][2] == mqtt.requirements
|
||||
|
||||
|
||||
async def test_discovery_requirements_ssdp(hass):
|
||||
"""Test that we load discovery requirements."""
|
||||
hass.config.skip_pip = False
|
||||
|
|
Loading…
Reference in New Issue