diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index b42032c6a1c..8da9a642bf2 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -69,6 +69,7 @@ from .const import ( DEFAULT_QOS, DEFAULT_RETAIN, DEFAULT_WILL, + DOMAIN, MQTT_CONNECTED, MQTT_DISCONNECTED, PROTOCOL_311, @@ -86,8 +87,6 @@ from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic _LOGGER = logging.getLogger(__name__) -DOMAIN = "mqtt" - DATA_MQTT = "mqtt" SERVICE_PUBLISH = "publish" diff --git a/homeassistant/components/mqtt/const.py b/homeassistant/components/mqtt/const.py index c51cf84ef85..3e56ab6caf9 100644 --- a/homeassistant/components/mqtt/const.py +++ b/homeassistant/components/mqtt/const.py @@ -41,6 +41,8 @@ DEFAULT_WILL = { ATTR_RETAIN: DEFAULT_RETAIN, } +DOMAIN = "mqtt" + MQTT_CONNECTED = "mqtt_connected" MQTT_DISCONNECTED = "mqtt_disconnected" diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index 6c4cbfd212f..7a478733826 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -1,5 +1,6 @@ """Support for MQTT discovery.""" import asyncio +import functools import json import logging import re @@ -9,9 +10,15 @@ from homeassistant.components import mqtt from homeassistant.const import CONF_DEVICE, CONF_PLATFORM from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.typing import HomeAssistantType +from homeassistant.loader import async_get_mqtt from .abbreviations import ABBREVIATIONS, DEVICE_ABBREVIATIONS -from .const import ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_PAYLOAD, ATTR_DISCOVERY_TOPIC +from .const import ( + ATTR_DISCOVERY_HASH, + ATTR_DISCOVERY_PAYLOAD, + ATTR_DISCOVERY_TOPIC, + DOMAIN, +) _LOGGER = logging.getLogger(__name__) @@ -39,7 +46,9 @@ SUPPORTED_COMPONENTS = [ ALREADY_DISCOVERED = "mqtt_discovered_components" CONFIG_ENTRY_IS_SETUP = "mqtt_config_entry_is_setup" DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock" +DATA_CONFIG_FLOW_LOCK = "mqtt_discovery_config_flow_lock" DISCOVERY_UNSUBSCRIBE = "mqtt_discovery_unsubscribe" +INTEGRATION_UNSUBSCRIBE = "mqtt_integration_discovery_unsubscribe" MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}" MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}" LAST_DISCOVERY = "mqtt_last_discovery" @@ -65,8 +74,9 @@ async def async_start( hass: HomeAssistantType, discovery_topic, config_entry=None ) -> bool: """Start MQTT Discovery.""" + mqtt_integrations = {} - async def async_device_message_received(msg): + async def async_entity_message_received(msg): """Process the received message.""" hass.data[LAST_DISCOVERY] = time.time() payload = msg.payload @@ -172,12 +182,52 @@ async def async_start( ) hass.data[DATA_CONFIG_ENTRY_LOCK] = asyncio.Lock() + hass.data[DATA_CONFIG_FLOW_LOCK] = asyncio.Lock() hass.data[CONFIG_ENTRY_IS_SETUP] = set() hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe( - hass, f"{discovery_topic}/#", async_device_message_received, 0 + hass, f"{discovery_topic}/#", async_entity_message_received, 0 ) hass.data[LAST_DISCOVERY] = time.time() + mqtt_integrations = await async_get_mqtt(hass) + + hass.data[INTEGRATION_UNSUBSCRIBE] = {} + + for (integration, topics) in mqtt_integrations.items(): + + async def async_integration_message_received(integration, msg): + """Process the received message.""" + key = f"{integration}_{msg.subscribed_topic}" + + # Lock to prevent initiating many parallel config flows. + # Note: The lock is not intended to prevent a race, only for performance + async with hass.data[DATA_CONFIG_FLOW_LOCK]: + # Already unsubscribed + if key not in hass.data[INTEGRATION_UNSUBSCRIBE]: + return + + result = await hass.config_entries.flow.async_init( + integration, context={"source": DOMAIN}, data=msg + ) + if ( + result + and result["type"] == "abort" + and result["reason"] + in ["already_configured", "single_instance_allowed"] + ): + unsub = hass.data[INTEGRATION_UNSUBSCRIBE].pop(key, None) + if unsub is None: + return + unsub() + + for topic in topics: + key = f"{integration}_{topic}" + hass.data[INTEGRATION_UNSUBSCRIBE][key] = await mqtt.async_subscribe( + hass, + topic, + functools.partial(async_integration_message_received, integration), + 0, + ) return True @@ -187,3 +237,7 @@ async def async_stop(hass: HomeAssistantType) -> bool: if DISCOVERY_UNSUBSCRIBE in hass.data and hass.data[DISCOVERY_UNSUBSCRIBE]: hass.data[DISCOVERY_UNSUBSCRIBE]() hass.data[DISCOVERY_UNSUBSCRIBE] = None + if INTEGRATION_UNSUBSCRIBE in hass.data: + for key, unsub in list(hass.data[INTEGRATION_UNSUBSCRIBE].items()): + unsub() + hass.data[INTEGRATION_UNSUBSCRIBE].pop(key) diff --git a/homeassistant/components/tasmota/config_flow.py b/homeassistant/components/tasmota/config_flow.py index fbac4bd7dd2..397e735ae5d 100644 --- a/homeassistant/components/tasmota/config_flow.py +++ b/homeassistant/components/tasmota/config_flow.py @@ -21,36 +21,72 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): VERSION = 1 CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_PUSH + def __init__(self): + """Initialize flow.""" + self._prefix = DEFAULT_PREFIX + + async def async_step_mqtt(self, discovery_info=None): + """Handle a flow initialized by MQTT discovery.""" + if self._async_in_progress() or self._async_current_entries(): + return self.async_abort(reason="single_instance_allowed") + + await self.async_set_unique_id(DOMAIN) + + # Validate the topic, will throw if it fails + prefix = discovery_info.subscribed_topic + if prefix.endswith("/#"): + prefix = prefix[:-2] + try: + valid_subscribe_topic(f"{prefix}/#") + except vol.Invalid: + return self.async_abort(reason="invalid_discovery_info") + + self._prefix = prefix + + return await self.async_step_confirm() + async def async_step_user(self, user_input=None): """Handle a flow initialized by the user.""" if self._async_current_entries(): return self.async_abort(reason="single_instance_allowed") - return await self.async_step_config() + if self.show_advanced_options: + return await self.async_step_config() + return await self.async_step_confirm() async def async_step_config(self, user_input=None): """Confirm the setup.""" errors = {} - data = {CONF_DISCOVERY_PREFIX: DEFAULT_PREFIX} + data = {CONF_DISCOVERY_PREFIX: self._prefix} if user_input is not None: bad_prefix = False - if self.show_advanced_options: - prefix = user_input[CONF_DISCOVERY_PREFIX] - try: - valid_subscribe_topic(f"{prefix}/#") - except vol.Invalid: - errors["base"] = "invalid_discovery_topic" - bad_prefix = True - else: - data = user_input + prefix = user_input[CONF_DISCOVERY_PREFIX] + if prefix.endswith("/#"): + prefix = prefix[:-2] + try: + valid_subscribe_topic(f"{prefix}/#") + except vol.Invalid: + errors["base"] = "invalid_discovery_topic" + bad_prefix = True + else: + data[CONF_DISCOVERY_PREFIX] = prefix if not bad_prefix: return self.async_create_entry(title="Tasmota", data=data) fields = {} - if self.show_advanced_options: - fields[vol.Optional(CONF_DISCOVERY_PREFIX, default=DEFAULT_PREFIX)] = str + fields[vol.Optional(CONF_DISCOVERY_PREFIX, default=self._prefix)] = str return self.async_show_form( step_id="config", data_schema=vol.Schema(fields), errors=errors ) + + async def async_step_confirm(self, user_input=None): + """Confirm the setup.""" + + data = {CONF_DISCOVERY_PREFIX: self._prefix} + + if user_input is not None: + return self.async_create_entry(title="Tasmota", data=data) + + return self.async_show_form(step_id="confirm") diff --git a/homeassistant/components/tasmota/manifest.json b/homeassistant/components/tasmota/manifest.json index 5540988edcc..58c40209da5 100644 --- a/homeassistant/components/tasmota/manifest.json +++ b/homeassistant/components/tasmota/manifest.json @@ -5,5 +5,6 @@ "documentation": "https://www.home-assistant.io/integrations/tasmota", "requirements": ["hatasmota==0.0.10"], "dependencies": ["mqtt"], + "mqtt": ["tasmota/discovery/#"], "codeowners": ["@emontnemery"] } diff --git a/homeassistant/components/tasmota/strings.json b/homeassistant/components/tasmota/strings.json index d19bb093263..3d32b54b95d 100644 --- a/homeassistant/components/tasmota/strings.json +++ b/homeassistant/components/tasmota/strings.json @@ -1,6 +1,9 @@ { "config": { "step": { + "confirm": { + "description": "Do you want to set up Tasmota?" + }, "config": { "title": "Tasmota", "description": "Please enter the Tasmota configuration.", diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 38593badf2f..84676501990 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -25,6 +25,7 @@ SOURCE_HASSIO = "hassio" SOURCE_HOMEKIT = "homekit" SOURCE_IMPORT = "import" SOURCE_INTEGRATION_DISCOVERY = "integration_discovery" +SOURCE_MQTT = "mqtt" SOURCE_SSDP = "ssdp" SOURCE_USER = "user" SOURCE_ZEROCONF = "zeroconf" diff --git a/homeassistant/generated/mqtt.py b/homeassistant/generated/mqtt.py new file mode 100644 index 00000000000..41aac3e3a08 --- /dev/null +++ b/homeassistant/generated/mqtt.py @@ -0,0 +1,12 @@ +"""Automatically generated by hassfest. + +To update, run python3 -m script.hassfest +""" + +# fmt: off + +MQTT = { + "tasmota": [ + "tasmota/discovery/#" + ] +} diff --git a/homeassistant/helpers/config_entry_flow.py b/homeassistant/helpers/config_entry_flow.py index f957d884d8d..6b9df47c4d8 100644 --- a/homeassistant/helpers/config_entry_flow.py +++ b/homeassistant/helpers/config_entry_flow.py @@ -80,6 +80,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow): async_step_zeroconf = async_step_discovery async_step_ssdp = async_step_discovery + async_step_mqtt = async_step_discovery async_step_homekit = async_step_discovery async def async_step_import(self, _: Optional[Dict[str, Any]]) -> Dict[str, Any]: diff --git a/homeassistant/helpers/config_entry_oauth2_flow.py b/homeassistant/helpers/config_entry_oauth2_flow.py index 55ec3984b82..ace1365df1b 100644 --- a/homeassistant/helpers/config_entry_oauth2_flow.py +++ b/homeassistant/helpers/config_entry_oauth2_flow.py @@ -302,6 +302,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta): return await self.async_step_pick_implementation() async_step_user = async_step_pick_implementation + async_step_mqtt = async_step_discovery async_step_ssdp = async_step_discovery async_step_zeroconf = async_step_discovery async_step_homekit = async_step_discovery diff --git a/homeassistant/loader.py b/homeassistant/loader.py index 53f793678c0..4538e6f3669 100644 --- a/homeassistant/loader.py +++ b/homeassistant/loader.py @@ -25,6 +25,7 @@ from typing import ( cast, ) +from homeassistant.generated.mqtt import MQTT from homeassistant.generated.ssdp import SSDP from homeassistant.generated.zeroconf import HOMEKIT, ZEROCONF @@ -202,6 +203,21 @@ async def async_get_ssdp(hass: "HomeAssistant") -> Dict[str, List]: return ssdp +async def async_get_mqtt(hass: "HomeAssistant") -> Dict[str, List]: + """Return cached list of MQTT mappings.""" + + mqtt: Dict[str, List] = MQTT.copy() + + integrations = await async_get_custom_components(hass) + for integration in integrations.values(): + if not integration.mqtt: + continue + + mqtt[integration.domain] = integration.mqtt + + return mqtt + + class Integration: """An integration in Home Assistant.""" @@ -323,6 +339,11 @@ class Integration: """Return Integration Quality Scale.""" return cast(str, self.manifest.get("quality_scale")) + @property + def mqtt(self) -> Optional[list]: + """Return Integration MQTT entries.""" + return cast(List[dict], self.manifest.get("mqtt")) + @property def ssdp(self) -> Optional[list]: """Return Integration SSDP entries.""" diff --git a/homeassistant/requirements.py b/homeassistant/requirements.py index 303f6219cae..b490db20b8d 100644 --- a/homeassistant/requirements.py +++ b/homeassistant/requirements.py @@ -15,6 +15,7 @@ DATA_INTEGRATIONS_WITH_REQS = "integrations_with_reqs" CONSTRAINT_FILE = "package_constraints.txt" _LOGGER = logging.getLogger(__name__) DISCOVERY_INTEGRATIONS: Dict[str, Iterable[str]] = { + "mqtt": ("mqtt",), "ssdp": ("ssdp",), "zeroconf": ("zeroconf", "homekit"), } diff --git a/script/hassfest/__main__.py b/script/hassfest/__main__.py index e3e4fbf38c6..4b2e91524e2 100644 --- a/script/hassfest/__main__.py +++ b/script/hassfest/__main__.py @@ -11,6 +11,7 @@ from . import ( dependencies, json, manifest, + mqtt, requirements, services, ssdp, @@ -25,6 +26,7 @@ INTEGRATION_PLUGINS = [ config_flow, dependencies, manifest, + mqtt, services, ssdp, translations, diff --git a/script/hassfest/config_flow.py b/script/hassfest/config_flow.py index b51e3c43afb..d3402c3dc9a 100644 --- a/script/hassfest/config_flow.py +++ b/script/hassfest/config_flow.py @@ -33,6 +33,11 @@ def validate_integration(config: Config, integration: Integration): "config_flow", "HomeKit information in a manifest requires a config flow to exist", ) + if integration.manifest.get("mqtt"): + integration.add_error( + "config_flow", + "MQTT information in a manifest requires a config flow to exist", + ) if integration.manifest.get("ssdp"): integration.add_error( "config_flow", @@ -51,6 +56,7 @@ def validate_integration(config: Config, integration: Integration): "async_step_discovery" in config_flow or "async_step_hassio" in config_flow or "async_step_homekit" in config_flow + or "async_step_mqtt" in config_flow or "async_step_ssdp" in config_flow or "async_step_zeroconf" in config_flow ) @@ -91,6 +97,7 @@ def generate_and_validate(integrations: Dict[str, Integration], config: Config): if not ( integration.manifest.get("config_flow") or integration.manifest.get("homekit") + or integration.manifest.get("mqtt") or integration.manifest.get("ssdp") or integration.manifest.get("zeroconf") ): diff --git a/script/hassfest/manifest.py b/script/hassfest/manifest.py index b0148b0911a..389e380af85 100644 --- a/script/hassfest/manifest.py +++ b/script/hassfest/manifest.py @@ -38,6 +38,7 @@ MANIFEST_SCHEMA = vol.Schema( vol.Required("domain"): str, vol.Required("name"): str, vol.Optional("config_flow"): bool, + vol.Optional("mqtt"): [str], vol.Optional("zeroconf"): [ vol.Any( str, diff --git a/script/hassfest/mqtt.py b/script/hassfest/mqtt.py new file mode 100644 index 00000000000..fdc16895d8c --- /dev/null +++ b/script/hassfest/mqtt.py @@ -0,0 +1,64 @@ +"""Generate MQTT file.""" +from collections import defaultdict +import json +from typing import Dict + +from .model import Config, Integration + +BASE = """ +\"\"\"Automatically generated by hassfest. + +To update, run python3 -m script.hassfest +\"\"\" + +# fmt: off + +MQTT = {} +""".strip() + + +def generate_and_validate(integrations: Dict[str, Integration]): + """Validate and generate MQTT data.""" + + data = defaultdict(list) + + for domain in sorted(integrations): + integration = integrations[domain] + + if not integration.manifest: + continue + + mqtt = integration.manifest.get("mqtt") + + if not mqtt: + continue + + for topic in mqtt: + data[domain].append(topic) + + return BASE.format(json.dumps(data, indent=4)) + + +def validate(integrations: Dict[str, Integration], config: Config): + """Validate MQTT file.""" + mqtt_path = config.root / "homeassistant/generated/mqtt.py" + config.cache["mqtt"] = content = generate_and_validate(integrations) + + if config.specific_integrations: + return + + with open(str(mqtt_path)) as fp: + if fp.read().strip() != content: + config.add_error( + "mqtt", + "File mqtt.py is not up to date. Run python3 -m script.hassfest", + fixable=True, + ) + return + + +def generate(integrations: Dict[str, Integration], config: Config): + """Generate MQTT file.""" + mqtt_path = config.root / "homeassistant/generated/mqtt.py" + with open(str(mqtt_path), "w") as fp: + fp.write(f"{config.cache['mqtt']}\n") diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index 42229691a0f..e1365418483 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -4,6 +4,7 @@ import re import pytest +from homeassistant import config_entries from homeassistant.components import mqtt from homeassistant.components.mqtt.abbreviations import ( ABBREVIATIONS, @@ -13,7 +14,12 @@ from homeassistant.components.mqtt.discovery import ALREADY_DISCOVERED, async_st from homeassistant.const import STATE_OFF, STATE_ON from tests.async_mock import AsyncMock, patch -from tests.common import async_fire_mqtt_message, mock_device_registry, mock_registry +from tests.common import ( + async_fire_mqtt_message, + mock_device_registry, + mock_entity_platform, + mock_registry, +) @pytest.fixture @@ -436,3 +442,75 @@ async def test_complex_discovery_topic_prefix(hass, mqtt_mock, caplog): assert state is not None assert state.name == "Beer" assert ("binary_sensor", "node1 object1") in hass.data[ALREADY_DISCOVERED] + + +async def test_mqtt_integration_discovery_subscribe_unsubscribe( + hass, mqtt_client_mock, mqtt_mock +): + """Check MQTT integration discovery subscribe and unsubscribe.""" + mock_entity_platform(hass, "config_flow.comp", None) + + entry = hass.config_entries.async_entries("mqtt")[0] + mqtt_mock().connected = True + + with patch( + "homeassistant.components.mqtt.discovery.async_get_mqtt", + return_value={"comp": ["comp/discovery/#"]}, + ): + await async_start(hass, "homeassistant", entry) + await hass.async_block_till_done() + + mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0) + assert not mqtt_client_mock.unsubscribe.called + + class TestFlow(config_entries.ConfigFlow): + """Test flow.""" + + async def async_step_mqtt(self, discovery_info): + """Test mqtt step.""" + return self.async_abort(reason="already_configured") + + with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): + mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0) + assert not mqtt_client_mock.unsubscribe.called + + async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") + await hass.async_block_till_done() + mqtt_client_mock.unsubscribe.assert_called_once_with("comp/discovery/#") + mqtt_client_mock.unsubscribe.reset_mock() + + async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") + await hass.async_block_till_done() + assert not mqtt_client_mock.unsubscribe.called + + +async def test_mqtt_discovery_unsubscribe_once(hass, mqtt_client_mock, mqtt_mock): + """Check MQTT integration discovery unsubscribe once.""" + mock_entity_platform(hass, "config_flow.comp", None) + + entry = hass.config_entries.async_entries("mqtt")[0] + mqtt_mock().connected = True + + with patch( + "homeassistant.components.mqtt.discovery.async_get_mqtt", + return_value={"comp": ["comp/discovery/#"]}, + ): + await async_start(hass, "homeassistant", entry) + await hass.async_block_till_done() + + mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0) + assert not mqtt_client_mock.unsubscribe.called + + class TestFlow(config_entries.ConfigFlow): + """Test flow.""" + + async def async_step_mqtt(self, discovery_info): + """Test mqtt step.""" + return self.async_abort(reason="already_configured") + + with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): + async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") + async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") + await hass.async_block_till_done() + await hass.async_block_till_done() + mqtt_client_mock.unsubscribe.assert_called_once_with("comp/discovery/#") diff --git a/tests/components/tasmota/test_config_flow.py b/tests/components/tasmota/test_config_flow.py index ee66b486a5f..469e5e29812 100644 --- a/tests/components/tasmota/test_config_flow.py +++ b/tests/components/tasmota/test_config_flow.py @@ -1,8 +1,47 @@ """Test config flow.""" +from homeassistant.components.mqtt.models import Message from tests.common import MockConfigEntry +async def test_mqtt_abort_if_existing_entry(hass, mqtt_mock): + """Check MQTT flow aborts when an entry already exist.""" + MockConfigEntry(domain="tasmota").add_to_hass(hass) + + result = await hass.config_entries.flow.async_init( + "tasmota", context={"source": "mqtt"} + ) + + assert result["type"] == "abort" + assert result["reason"] == "single_instance_allowed" + + +async def test_mqtt_abort_invalid_topic(hass, mqtt_mock): + """Check MQTT flow aborts if discovery topic is invalid.""" + discovery_info = Message("", "", 0, False, subscribed_topic="custom_prefix/##") + result = await hass.config_entries.flow.async_init( + "tasmota", context={"source": "mqtt"}, data=discovery_info + ) + assert result["type"] == "abort" + assert result["reason"] == "invalid_discovery_info" + + +async def test_mqtt_setup(hass, mqtt_mock) -> None: + """Test we can finish a config flow through MQTT with custom prefix.""" + discovery_info = Message("", "", 0, False, subscribed_topic="custom_prefix/123/#") + result = await hass.config_entries.flow.async_init( + "tasmota", context={"source": "mqtt"}, data=discovery_info + ) + assert result["type"] == "form" + + result = await hass.config_entries.flow.async_configure(result["flow_id"], {}) + + assert result["type"] == "create_entry" + assert result["result"].data == { + "discovery_prefix": "custom_prefix/123", + } + + async def test_user_setup(hass, mqtt_mock): """Test we can finish a config flow.""" result = await hass.config_entries.flow.async_init( @@ -35,15 +74,32 @@ async def test_user_setup_advanced(hass, mqtt_mock): } -async def test_user_setup_invalid_topic_prefix(hass, mqtt_mock): - """Test if connection cannot be made.""" +async def test_user_setup_advanced_strip_wildcard(hass, mqtt_mock): + """Test we can finish a config flow.""" result = await hass.config_entries.flow.async_init( "tasmota", context={"source": "user", "show_advanced_options": True} ) assert result["type"] == "form" result = await hass.config_entries.flow.async_configure( - result["flow_id"], {"discovery_prefix": "tasmota/config/#"} + result["flow_id"], {"discovery_prefix": "test_tasmota/discovery/#"} + ) + + assert result["type"] == "create_entry" + assert result["result"].data == { + "discovery_prefix": "test_tasmota/discovery", + } + + +async def test_user_setup_invalid_topic_prefix(hass, mqtt_mock): + """Test abort on invalid discovery topic.""" + result = await hass.config_entries.flow.async_init( + "tasmota", context={"source": "user", "show_advanced_options": True} + ) + assert result["type"] == "form" + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], {"discovery_prefix": "tasmota/config/##"} ) assert result["type"] == "form" diff --git a/tests/helpers/test_config_entry_flow.py b/tests/helpers/test_config_entry_flow.py index 2bb993f1197..926aa98e308 100644 --- a/tests/helpers/test_config_entry_flow.py +++ b/tests/helpers/test_config_entry_flow.py @@ -81,7 +81,7 @@ async def test_user_has_confirmation(hass, discovery_flow_conf): assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY -@pytest.mark.parametrize("source", ["discovery", "ssdp", "zeroconf"]) +@pytest.mark.parametrize("source", ["discovery", "mqtt", "ssdp", "zeroconf"]) async def test_discovery_single_instance(hass, discovery_flow_conf, source): """Test we not allow duplicates.""" flow = config_entries.HANDLERS["test"]() @@ -95,7 +95,7 @@ async def test_discovery_single_instance(hass, discovery_flow_conf, source): assert result["reason"] == "single_instance_allowed" -@pytest.mark.parametrize("source", ["discovery", "ssdp", "zeroconf"]) +@pytest.mark.parametrize("source", ["discovery", "mqtt", "ssdp", "zeroconf"]) async def test_discovery_confirmation(hass, discovery_flow_conf, source): """Test we ask for confirmation via discovery.""" flow = config_entries.HANDLERS["test"]() diff --git a/tests/test_loader.py b/tests/test_loader.py index f5ba54ff269..71a373a579d 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -181,6 +181,7 @@ def test_integration_properties(hass): }, {"manufacturer": "Signify", "modelName": "Philips hue bridge 2015"}, ], + "mqtt": ["hue/discovery"], }, ) assert integration.name == "Philips Hue" @@ -198,6 +199,7 @@ def test_integration_properties(hass): }, {"manufacturer": "Signify", "modelName": "Philips hue bridge 2015"}, ] + assert integration.mqtt == ["hue/discovery"] assert integration.dependencies == ["test-dep"] assert integration.requirements == ["test-req==1.0.0"] assert integration.is_built_in is True @@ -217,6 +219,7 @@ def test_integration_properties(hass): assert integration.homekit is None assert integration.zeroconf is None assert integration.ssdp is None + assert integration.mqtt is None integration = loader.Integration( hass, @@ -266,6 +269,7 @@ def _get_test_integration(hass, name, config_flow): "zeroconf": [f"_{name}._tcp.local."], "homekit": {"models": [name]}, "ssdp": [{"manufacturer": name, "modelName": name}], + "mqtt": [f"{name}/discovery"], }, ) @@ -371,6 +375,21 @@ async def test_get_ssdp(hass): assert ssdp["test_2"] == [{"manufacturer": "test_2", "modelName": "test_2"}] +async def test_get_mqtt(hass): + """Verify that custom components with MQTT are found.""" + test_1_integration = _get_test_integration(hass, "test_1", True) + test_2_integration = _get_test_integration(hass, "test_2", True) + + with patch("homeassistant.loader.async_get_custom_components") as mock_get: + mock_get.return_value = { + "test_1": test_1_integration, + "test_2": test_2_integration, + } + mqtt = await loader.async_get_mqtt(hass) + assert mqtt["test_1"] == ["test_1/discovery"] + assert mqtt["test_2"] == ["test_2/discovery"] + + async def test_get_custom_components_safe_mode(hass): """Test that we get empty custom components in safe mode.""" hass.config.safe_mode = True diff --git a/tests/test_requirements.py b/tests/test_requirements.py index 6297da0c2d5..c0a1f0723ac 100644 --- a/tests/test_requirements.py +++ b/tests/test_requirements.py @@ -187,6 +187,23 @@ async def test_install_on_docker(hass): ) +async def test_discovery_requirements_mqtt(hass): + """Test that we load discovery requirements.""" + hass.config.skip_pip = False + mqtt = await loader.async_get_integration(hass, "mqtt") + + mock_integration( + hass, MockModule("mqtt_comp", partial_manifest={"mqtt": ["foo/discovery"]}) + ) + with patch( + "homeassistant.requirements.async_process_requirements", + ) as mock_process: + await async_get_integration_with_requirements(hass, "mqtt_comp") + + assert len(mock_process.mock_calls) == 2 # mqtt also depends on http + assert mock_process.mock_calls[0][1][2] == mqtt.requirements + + async def test_discovery_requirements_ssdp(hass): """Test that we load discovery requirements.""" hass.config.skip_pip = False