From 4bad88b42c093f3992e241326ec363bc0ba5538e Mon Sep 17 00:00:00 2001 From: Robert Resch Date: Wed, 31 Jan 2024 13:17:00 +0100 Subject: [PATCH] Update Ecovacs config_flow to support self-hosted instances (#108944) * Update Ecovacs config_flow to support self-hosted instances * Selfhosted should add their instance urls * Improve config flow * Improve and adapt to version bump * Add test for self-hosted * Make ruff happy * Update homeassistant/components/ecovacs/strings.json Co-authored-by: Joost Lekkerkerker * Implement suggestions * Apply suggestions from code review Co-authored-by: Martin Hjelmare * Implement suggestions * Remove , --------- Co-authored-by: Joost Lekkerkerker Co-authored-by: Martin Hjelmare --- .../components/ecovacs/config_flow.py | 165 +++++++-- homeassistant/components/ecovacs/const.py | 12 + .../components/ecovacs/controller.py | 17 + .../components/ecovacs/diagnostics.py | 10 +- homeassistant/components/ecovacs/strings.json | 30 +- tests/components/ecovacs/conftest.py | 25 +- tests/components/ecovacs/const.py | 23 +- .../ecovacs/snapshots/test_diagnostics.ambr | 53 ++- tests/components/ecovacs/test_config_flow.py | 322 ++++++++++++++++-- tests/components/ecovacs/test_diagnostics.py | 9 + tests/components/ecovacs/test_init.py | 2 + 11 files changed, 596 insertions(+), 72 deletions(-) diff --git a/homeassistant/components/ecovacs/config_flow.py b/homeassistant/components/ecovacs/config_flow.py index 7b56417f93e..39c61b3ce23 100644 --- a/homeassistant/components/ecovacs/config_flow.py +++ b/homeassistant/components/ecovacs/config_flow.py @@ -2,39 +2,81 @@ from __future__ import annotations import logging +import ssl from typing import Any, cast +from urllib.parse import urlparse from aiohttp import ClientError from deebot_client.authentication import Authenticator, create_rest_config -from deebot_client.exceptions import InvalidAuthenticationError +from deebot_client.const import UNDEFINED, UndefinedType +from deebot_client.exceptions import InvalidAuthenticationError, MqttError +from deebot_client.mqtt_client import MqttClient, create_mqtt_config from deebot_client.util import md5 from deebot_client.util.continents import COUNTRIES_TO_CONTINENTS, get_continent import voluptuous as vol from homeassistant.config_entries import ConfigFlow -from homeassistant.const import CONF_COUNTRY, CONF_PASSWORD, CONF_USERNAME +from homeassistant.const import CONF_COUNTRY, CONF_MODE, CONF_PASSWORD, CONF_USERNAME from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant from homeassistant.data_entry_flow import AbortFlow, FlowResult from homeassistant.helpers import aiohttp_client, selector from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue from homeassistant.loader import async_get_issue_tracker +from homeassistant.util.ssl import get_default_no_verify_context -from .const import CONF_CONTINENT, DOMAIN +from .const import ( + CONF_CONTINENT, + CONF_OVERRIDE_MQTT_URL, + CONF_OVERRIDE_REST_URL, + CONF_VERIFY_MQTT_CERTIFICATE, + DOMAIN, + InstanceMode, +) from .util import get_client_device_id _LOGGER = logging.getLogger(__name__) +def _validate_url( + value: str, + field_name: str, + schema_list: set[str], +) -> dict[str, str]: + """Validate an URL and return error dictionary.""" + if urlparse(value).scheme not in schema_list: + return {field_name: f"invalid_url_schema_{field_name}"} + try: + vol.Schema(vol.Url())(value) + except vol.Invalid: + return {field_name: "invalid_url"} + return {} + + async def _validate_input( hass: HomeAssistant, user_input: dict[str, Any] ) -> dict[str, str]: """Validate user input.""" errors: dict[str, str] = {} + if rest_url := user_input.get(CONF_OVERRIDE_REST_URL): + errors.update( + _validate_url(rest_url, CONF_OVERRIDE_REST_URL, {"http", "https"}) + ) + if mqtt_url := user_input.get(CONF_OVERRIDE_MQTT_URL): + errors.update( + _validate_url(mqtt_url, CONF_OVERRIDE_MQTT_URL, {"mqtt", "mqtts"}) + ) + + if errors: + return errors + + device_id = get_client_device_id() + country = user_input[CONF_COUNTRY] rest_config = create_rest_config( aiohttp_client.async_get_clientsession(hass), - device_id=get_client_device_id(), - country=user_input[CONF_COUNTRY], + device_id=device_id, + country=country, + override_rest_url=rest_url, ) authenticator = Authenticator( @@ -54,6 +96,34 @@ async def _validate_input( _LOGGER.exception("Unexpected exception during login") errors["base"] = "unknown" + if errors: + return errors + + ssl_context: UndefinedType | ssl.SSLContext = UNDEFINED + if not user_input.get(CONF_VERIFY_MQTT_CERTIFICATE, True) and mqtt_url: + ssl_context = get_default_no_verify_context() + + mqtt_config = create_mqtt_config( + device_id=device_id, + country=country, + override_mqtt_url=mqtt_url, + ssl_context=ssl_context, + ) + + client = MqttClient(mqtt_config, authenticator) + cannot_connect_field = CONF_OVERRIDE_MQTT_URL if mqtt_url else "base" + + try: + await client.verify_config() + except MqttError: + _LOGGER.debug("Cannot connect", exc_info=True) + errors[cannot_connect_field] = "cannot_connect" + except InvalidAuthenticationError: + errors["base"] = "invalid_auth" + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Unexpected exception during mqtt connection verification") + errors["base"] = "unknown" + return errors @@ -62,10 +132,42 @@ class EcovacsConfigFlow(ConfigFlow, domain=DOMAIN): VERSION = 1 + _mode: InstanceMode = InstanceMode.CLOUD + async def async_step_user( self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Handle the initial step.""" + + if not self.show_advanced_options: + return await self.async_step_auth() + + if user_input: + self._mode = user_input[CONF_MODE] + return await self.async_step_auth() + + return self.async_show_form( + step_id="user", + data_schema=vol.Schema( + { + vol.Required( + CONF_MODE, default=InstanceMode.CLOUD + ): selector.SelectSelector( + selector.SelectSelectorConfig( + options=list(InstanceMode), + translation_key="installation_mode", + mode=selector.SelectSelectorMode.DROPDOWN, + ) + ) + } + ), + last_step=False, + ) + + async def async_step_auth( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Handle the auth step.""" errors = {} if user_input: @@ -78,30 +180,41 @@ class EcovacsConfigFlow(ConfigFlow, domain=DOMAIN): title=user_input[CONF_USERNAME], data=user_input ) + schema = { + vol.Required(CONF_USERNAME): selector.TextSelector( + selector.TextSelectorConfig(type=selector.TextSelectorType.TEXT) + ), + vol.Required(CONF_PASSWORD): selector.TextSelector( + selector.TextSelectorConfig(type=selector.TextSelectorType.PASSWORD) + ), + vol.Required(CONF_COUNTRY): selector.CountrySelector(), + } + if self._mode == InstanceMode.SELF_HOSTED: + schema.update( + { + vol.Required(CONF_OVERRIDE_REST_URL): selector.TextSelector( + selector.TextSelectorConfig(type=selector.TextSelectorType.URL) + ), + vol.Required(CONF_OVERRIDE_MQTT_URL): selector.TextSelector( + selector.TextSelectorConfig(type=selector.TextSelectorType.URL) + ), + } + ) + if errors: + schema[vol.Optional(CONF_VERIFY_MQTT_CERTIFICATE, default=True)] = bool + + if not user_input: + user_input = { + CONF_COUNTRY: self.hass.config.country, + } + return self.async_show_form( - step_id="user", + step_id="auth", data_schema=self.add_suggested_values_to_schema( - data_schema=vol.Schema( - { - vol.Required(CONF_USERNAME): selector.TextSelector( - selector.TextSelectorConfig( - type=selector.TextSelectorType.TEXT - ) - ), - vol.Required(CONF_PASSWORD): selector.TextSelector( - selector.TextSelectorConfig( - type=selector.TextSelectorType.PASSWORD - ) - ), - vol.Required(CONF_COUNTRY): selector.CountrySelector(), - } - ), - suggested_values=user_input - or { - CONF_COUNTRY: self.hass.config.country, - }, + data_schema=vol.Schema(schema), suggested_values=user_input ), errors=errors, + last_step=True, ) async def async_step_import(self, user_input: dict[str, Any]) -> FlowResult: @@ -181,7 +294,7 @@ class EcovacsConfigFlow(ConfigFlow, domain=DOMAIN): # Remove the continent from the user input as it is not needed anymore user_input.pop(CONF_CONTINENT) try: - result = await self.async_step_user(user_input) + result = await self.async_step_auth(user_input) except AbortFlow as ex: if ex.reason == "already_configured": create_repair() diff --git a/homeassistant/components/ecovacs/const.py b/homeassistant/components/ecovacs/const.py index 5edbe11c265..dc055cee519 100644 --- a/homeassistant/components/ecovacs/const.py +++ b/homeassistant/components/ecovacs/const.py @@ -1,12 +1,24 @@ """Ecovacs constants.""" +from enum import StrEnum + from deebot_client.events import LifeSpan DOMAIN = "ecovacs" CONF_CONTINENT = "continent" +CONF_OVERRIDE_REST_URL = "override_rest_url" +CONF_OVERRIDE_MQTT_URL = "override_mqtt_url" +CONF_VERIFY_MQTT_CERTIFICATE = "verify_mqtt_certificate" SUPPORTED_LIFESPANS = ( LifeSpan.BRUSH, LifeSpan.FILTER, LifeSpan.SIDE_BRUSH, ) + + +class InstanceMode(StrEnum): + """Instance mode.""" + + CLOUD = "cloud" + SELF_HOSTED = "self_hosted" diff --git a/homeassistant/components/ecovacs/controller.py b/homeassistant/components/ecovacs/controller.py index e0c3497c178..06e3a1acccd 100644 --- a/homeassistant/components/ecovacs/controller.py +++ b/homeassistant/components/ecovacs/controller.py @@ -3,10 +3,12 @@ from __future__ import annotations from collections.abc import Mapping import logging +import ssl from typing import Any from deebot_client.api_client import ApiClient from deebot_client.authentication import Authenticator, create_rest_config +from deebot_client.const import UNDEFINED, UndefinedType from deebot_client.device import Device from deebot_client.exceptions import DeebotError, InvalidAuthenticationError from deebot_client.models import DeviceInfo @@ -19,7 +21,13 @@ from homeassistant.const import CONF_COUNTRY, CONF_PASSWORD, CONF_USERNAME from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryError, ConfigEntryNotReady from homeassistant.helpers import aiohttp_client +from homeassistant.util.ssl import get_default_no_verify_context +from .const import ( + CONF_OVERRIDE_MQTT_URL, + CONF_OVERRIDE_REST_URL, + CONF_VERIFY_MQTT_CERTIFICATE, +) from .util import get_client_device_id _LOGGER = logging.getLogger(__name__) @@ -42,15 +50,24 @@ class EcovacsController: aiohttp_client.async_get_clientsession(self._hass), device_id=self._device_id, country=country, + override_rest_url=config.get(CONF_OVERRIDE_REST_URL), ), config[CONF_USERNAME], md5(config[CONF_PASSWORD]), ) self._api_client = ApiClient(self._authenticator) + + mqtt_url = config.get(CONF_OVERRIDE_MQTT_URL) + ssl_context: UndefinedType | ssl.SSLContext = UNDEFINED + if not config.get(CONF_VERIFY_MQTT_CERTIFICATE, True) and mqtt_url: + ssl_context = get_default_no_verify_context() + self._mqtt = MqttClient( create_mqtt_config( device_id=self._device_id, country=country, + override_mqtt_url=mqtt_url, + ssl_context=ssl_context, ), self._authenticator, ) diff --git a/homeassistant/components/ecovacs/diagnostics.py b/homeassistant/components/ecovacs/diagnostics.py index fa7d85ed52a..d961e231631 100644 --- a/homeassistant/components/ecovacs/diagnostics.py +++ b/homeassistant/components/ecovacs/diagnostics.py @@ -8,10 +8,16 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_NAME, CONF_PASSWORD, CONF_USERNAME from homeassistant.core import HomeAssistant -from .const import DOMAIN +from .const import CONF_OVERRIDE_MQTT_URL, CONF_OVERRIDE_REST_URL, DOMAIN from .controller import EcovacsController -REDACT_CONFIG = {CONF_USERNAME, CONF_PASSWORD, "title"} +REDACT_CONFIG = { + CONF_USERNAME, + CONF_PASSWORD, + "title", + CONF_OVERRIDE_MQTT_URL, + CONF_OVERRIDE_REST_URL, +} REDACT_DEVICE = {"did", CONF_NAME, "homeId"} diff --git a/homeassistant/components/ecovacs/strings.json b/homeassistant/components/ecovacs/strings.json index 56e3ec1f866..d15e8a67062 100644 --- a/homeassistant/components/ecovacs/strings.json +++ b/homeassistant/components/ecovacs/strings.json @@ -6,14 +6,32 @@ "error": { "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", "invalid_auth": "[%key:common::config_flow::error::invalid_auth%]", + "invalid_url": "Invalid URL", + "invalid_url_schema_override_rest_url": "Invalid REST URL scheme.\nThe URL should start with `http://` or `https://`.", + "invalid_url_schema_override_mqtt_url": "Invalid MQTT URL scheme.\nThe URL should start with `mqtt://` or `mqtts://`.", "unknown": "[%key:common::config_flow::error::unknown%]" }, "step": { - "user": { + "auth": { "data": { "country": "Country", + "override_rest_url": "REST URL", + "override_mqtt_url": "MQTT URL", "password": "[%key:common::config_flow::data::password%]", - "username": "[%key:common::config_flow::data::username%]" + "username": "[%key:common::config_flow::data::username%]", + "verify_mqtt_certificate": "Verify MQTT SSL certificate" + }, + "data_description": { + "override_rest_url": "Enter the REST URL of your self-hosted instance including the scheme (http/https).", + "override_mqtt_url": "Enter the MQTT URL of your self-hosted instance including the scheme (mqtt/mqtts)." + } + }, + "user": { + "data": { + "mode": "[%key:common::config_flow::data::mode%]" + }, + "data_description": { + "mode": "Select the mode you want to use to connect to Ecovacs. If you are unsure, select 'Cloud'.\n\nSelect 'Self-hosted' only if you have a working self-hosted instance." } } } @@ -157,5 +175,13 @@ "title": "The Ecovacs YAML configuration import failed", "description": "Configuring Ecovacs using YAML is being removed but there is an unexpected continent specified in the YAML configuration.\n\nFrom the given country, the continent '{continent}' is expected. Change the continent and restart Home Assistant to try again or remove the Ecovacs YAML configuration from your configuration.yaml file and continue to [set up the integration]({url}) manually.\n\nIf the contintent '{continent}' is not applicable, please open an issue on [GitHub]({github_issue_url})." } + }, + "selector": { + "installation_mode": { + "options": { + "cloud": "Cloud", + "self_hosted": "Self-hosted" + } + } } } diff --git a/tests/components/ecovacs/conftest.py b/tests/components/ecovacs/conftest.py index 74e4d30a16d..d0f0668cc8c 100644 --- a/tests/components/ecovacs/conftest.py +++ b/tests/components/ecovacs/conftest.py @@ -12,10 +12,10 @@ import pytest from homeassistant.components.ecovacs import PLATFORMS from homeassistant.components.ecovacs.const import DOMAIN from homeassistant.components.ecovacs.controller import EcovacsController -from homeassistant.const import Platform +from homeassistant.const import CONF_USERNAME, Platform from homeassistant.core import HomeAssistant -from .const import VALID_ENTRY_DATA +from .const import VALID_ENTRY_DATA_CLOUD from tests.common import MockConfigEntry, load_json_object_fixture @@ -30,12 +30,18 @@ def mock_setup_entry() -> Generator[AsyncMock, None, None]: @pytest.fixture -def mock_config_entry() -> MockConfigEntry: +def mock_config_entry_data() -> dict[str, Any]: + """Return the default mocked config entry data.""" + return VALID_ENTRY_DATA_CLOUD + + +@pytest.fixture +def mock_config_entry(mock_config_entry_data: dict[str, Any]) -> MockConfigEntry: """Return the default mocked config entry.""" return MockConfigEntry( - title="username", + title=mock_config_entry_data[CONF_USERNAME], domain=DOMAIN, - data=VALID_ENTRY_DATA, + data=mock_config_entry_data, ) @@ -62,7 +68,7 @@ def mock_authenticator(device_fixture: str) -> Generator[Mock, None, None]: load_json_object_fixture(f"devices/{device_fixture}/device.json", DOMAIN) ] - def post_authenticated( + async def post_authenticated( path: str, json: dict[str, Any], *, @@ -89,8 +95,11 @@ def mock_mqtt_client(mock_authenticator: Mock) -> Mock: with patch( "homeassistant.components.ecovacs.controller.MqttClient", autospec=True, - ) as mock_mqtt_client: - client = mock_mqtt_client.return_value + ) as mock, patch( + "homeassistant.components.ecovacs.config_flow.MqttClient", + new=mock, + ): + client = mock.return_value client._authenticator = mock_authenticator client.subscribe.return_value = lambda: None yield client diff --git a/tests/components/ecovacs/const.py b/tests/components/ecovacs/const.py index f5100e69ee2..237c7fa5c85 100644 --- a/tests/components/ecovacs/const.py +++ b/tests/components/ecovacs/const.py @@ -1,13 +1,28 @@ """Test ecovacs constants.""" -from homeassistant.components.ecovacs.const import CONF_CONTINENT +from homeassistant.components.ecovacs.const import ( + CONF_CONTINENT, + CONF_OVERRIDE_MQTT_URL, + CONF_OVERRIDE_REST_URL, + CONF_VERIFY_MQTT_CERTIFICATE, +) from homeassistant.const import CONF_COUNTRY, CONF_PASSWORD, CONF_USERNAME -VALID_ENTRY_DATA = { - CONF_USERNAME: "username", +VALID_ENTRY_DATA_CLOUD = { + CONF_USERNAME: "username@cloud", CONF_PASSWORD: "password", CONF_COUNTRY: "IT", } -IMPORT_DATA = VALID_ENTRY_DATA | {CONF_CONTINENT: "EU"} +VALID_ENTRY_DATA_SELF_HOSTED = VALID_ENTRY_DATA_CLOUD | { + CONF_USERNAME: "username@self-hosted", + CONF_OVERRIDE_REST_URL: "http://localhost:8000", + CONF_OVERRIDE_MQTT_URL: "mqtt://localhost:1883", +} + +VALID_ENTRY_DATA_SELF_HOSTED_WITH_VALIDATE_CERT = VALID_ENTRY_DATA_SELF_HOSTED | { + CONF_VERIFY_MQTT_CERTIFICATE: True, +} + +IMPORT_DATA = VALID_ENTRY_DATA_CLOUD | {CONF_CONTINENT: "EU"} diff --git a/tests/components/ecovacs/snapshots/test_diagnostics.ambr b/tests/components/ecovacs/snapshots/test_diagnostics.ambr index 9b27883745b..a4291f9fe25 100644 --- a/tests/components/ecovacs/snapshots/test_diagnostics.ambr +++ b/tests/components/ecovacs/snapshots/test_diagnostics.ambr @@ -1,5 +1,5 @@ # serializer version: 1 -# name: test_diagnostics +# name: test_diagnostics[username@cloud] dict({ 'config': dict({ 'data': dict({ @@ -48,3 +48,54 @@ ]), }) # --- +# name: test_diagnostics[username@self-hosted] + dict({ + 'config': dict({ + 'data': dict({ + 'country': 'IT', + 'override_mqtt_url': '**REDACTED**', + 'override_rest_url': '**REDACTED**', + 'password': '**REDACTED**', + 'username': '**REDACTED**', + }), + 'disabled_by': None, + 'domain': 'ecovacs', + 'minor_version': 1, + 'options': dict({ + }), + 'pref_disable_new_entities': False, + 'pref_disable_polling': False, + 'source': 'user', + 'title': '**REDACTED**', + 'unique_id': None, + 'version': 1, + }), + 'devices': list([ + dict({ + 'UILogicId': 'DX_9G', + 'class': 'yna5xi', + 'company': 'eco-ng', + 'deviceName': 'DEEBOT OZMO 950 Series', + 'did': '**REDACTED**', + 'homeSort': 9999, + 'icon': 'https://portal-ww.ecouser.net/api/pim/file/get/606278df4a84d700082b39f1', + 'materialNo': '110-1820-0101', + 'model': 'DX9G', + 'name': '**REDACTED**', + 'nick': 'Ozmo 950', + 'otaUpgrade': dict({ + }), + 'pid': '5c19a91ca1e6ee000178224a', + 'product_category': 'DEEBOT', + 'resource': 'upQ6', + 'service': dict({ + 'jmq': 'jmq-ngiot-eu.dc.ww.ecouser.net', + 'mqs': 'api-ngiot.dc-as.ww.ecouser.net', + }), + 'status': 1, + }), + ]), + 'legacy_devices': list([ + ]), + }) +# --- diff --git a/tests/components/ecovacs/test_config_flow.py b/tests/components/ecovacs/test_config_flow.py index 64f0758dc1f..5e02ec7dede 100644 --- a/tests/components/ecovacs/test_config_flow.py +++ b/tests/components/ecovacs/test_config_flow.py @@ -1,86 +1,307 @@ """Test Ecovacs config flow.""" +from collections.abc import Awaitable, Callable +import ssl from typing import Any -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock, patch from aiohttp import ClientError -from deebot_client.exceptions import InvalidAuthenticationError +from deebot_client.exceptions import InvalidAuthenticationError, MqttError +from deebot_client.mqtt_client import create_mqtt_config import pytest -from homeassistant.components.ecovacs.const import DOMAIN +from homeassistant.components.ecovacs.const import ( + CONF_CONTINENT, + CONF_OVERRIDE_MQTT_URL, + CONF_OVERRIDE_REST_URL, + CONF_VERIFY_MQTT_CERTIFICATE, + DOMAIN, + InstanceMode, +) from homeassistant.config_entries import SOURCE_IMPORT, SOURCE_USER -from homeassistant.const import CONF_USERNAME +from homeassistant.const import CONF_COUNTRY, CONF_MODE, CONF_USERNAME from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant from homeassistant.data_entry_flow import FlowResultType from homeassistant.helpers import issue_registry as ir -from .const import IMPORT_DATA, VALID_ENTRY_DATA +from .const import ( + IMPORT_DATA, + VALID_ENTRY_DATA_CLOUD, + VALID_ENTRY_DATA_SELF_HOSTED, + VALID_ENTRY_DATA_SELF_HOSTED_WITH_VALIDATE_CERT, +) from tests.common import MockConfigEntry +_USER_STEP_SELF_HOSTED = {CONF_MODE: InstanceMode.SELF_HOSTED} -async def _test_user_flow(hass: HomeAssistant) -> dict[str, Any]: +_TEST_FN_AUTH_ARG = "user_input_auth" +_TEST_FN_USER_ARG = "user_input_user" + + +async def _test_user_flow( + hass: HomeAssistant, + user_input_auth: dict[str, Any], +) -> dict[str, Any]: """Test config flow.""" result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_USER}, ) + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "auth" + assert not result["errors"] + return await hass.config_entries.flow.async_configure( result["flow_id"], - user_input=VALID_ENTRY_DATA, + user_input=user_input_auth, ) +async def _test_user_flow_show_advanced_options( + hass: HomeAssistant, + *, + user_input_auth: dict[str, Any], + user_input_user: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Test config flow.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_USER, "show_advanced_options": True}, + ) + + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "user" + assert not result["errors"] + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=user_input_user or {}, + ) + + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "auth" + assert not result["errors"] + + return await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=user_input_auth, + ) + + +@pytest.mark.parametrize( + ("test_fn", "test_fn_args", "entry_data"), + [ + ( + _test_user_flow_show_advanced_options, + {_TEST_FN_AUTH_ARG: VALID_ENTRY_DATA_CLOUD}, + VALID_ENTRY_DATA_CLOUD, + ), + ( + _test_user_flow_show_advanced_options, + { + _TEST_FN_AUTH_ARG: VALID_ENTRY_DATA_SELF_HOSTED, + _TEST_FN_USER_ARG: _USER_STEP_SELF_HOSTED, + }, + VALID_ENTRY_DATA_SELF_HOSTED, + ), + ( + _test_user_flow, + {_TEST_FN_AUTH_ARG: VALID_ENTRY_DATA_CLOUD}, + VALID_ENTRY_DATA_CLOUD, + ), + ], + ids=["advanced_cloud", "advanced_self_hosted", "cloud"], +) async def test_user_flow( hass: HomeAssistant, mock_setup_entry: AsyncMock, mock_authenticator_authenticate: AsyncMock, + mock_mqtt_client: Mock, + test_fn: Callable[[HomeAssistant, dict[str, Any]], Awaitable[dict[str, Any]]] + | Callable[ + [HomeAssistant, dict[str, Any], dict[str, Any]], Awaitable[dict[str, Any]] + ], + test_fn_args: dict[str, Any], + entry_data: dict[str, Any], ) -> None: """Test the user config flow.""" - result = await _test_user_flow(hass) + result = await test_fn( + hass, + **test_fn_args, + ) assert result["type"] == FlowResultType.CREATE_ENTRY - assert result["title"] == VALID_ENTRY_DATA[CONF_USERNAME] - assert result["data"] == VALID_ENTRY_DATA + assert result["title"] == entry_data[CONF_USERNAME] + assert result["data"] == entry_data mock_setup_entry.assert_called() mock_authenticator_authenticate.assert_called() + mock_mqtt_client.verify_config.assert_called() + + +def _cannot_connect_error(user_input: dict[str, Any]) -> str: + field = "base" + if CONF_OVERRIDE_MQTT_URL in user_input: + field = CONF_OVERRIDE_MQTT_URL + + return {field: "cannot_connect"} @pytest.mark.parametrize( - ("side_effect", "reason"), + ("side_effect_mqtt", "errors_mqtt"), + [ + (MqttError, _cannot_connect_error), + (InvalidAuthenticationError, lambda _: {"base": "invalid_auth"}), + (Exception, lambda _: {"base": "unknown"}), + ], + ids=["cannot_connect", "invalid_auth", "unknown"], +) +@pytest.mark.parametrize( + ("side_effect_rest", "reason_rest"), [ (ClientError, "cannot_connect"), (InvalidAuthenticationError, "invalid_auth"), (Exception, "unknown"), ], + ids=["cannot_connect", "invalid_auth", "unknown"], ) -async def test_user_flow_error( +@pytest.mark.parametrize( + ("test_fn", "test_fn_args", "entry_data"), + [ + ( + _test_user_flow_show_advanced_options, + {_TEST_FN_AUTH_ARG: VALID_ENTRY_DATA_CLOUD}, + VALID_ENTRY_DATA_CLOUD, + ), + ( + _test_user_flow_show_advanced_options, + { + _TEST_FN_AUTH_ARG: VALID_ENTRY_DATA_SELF_HOSTED, + _TEST_FN_USER_ARG: _USER_STEP_SELF_HOSTED, + }, + VALID_ENTRY_DATA_SELF_HOSTED_WITH_VALIDATE_CERT, + ), + ( + _test_user_flow, + {_TEST_FN_AUTH_ARG: VALID_ENTRY_DATA_CLOUD}, + VALID_ENTRY_DATA_CLOUD, + ), + ], + ids=["advanced_cloud", "advanced_self_hosted", "cloud"], +) +async def test_user_flow_raise_error( hass: HomeAssistant, - side_effect: Exception, - reason: str, mock_setup_entry: AsyncMock, mock_authenticator_authenticate: AsyncMock, + mock_mqtt_client: Mock, + side_effect_rest: Exception, + reason_rest: str, + side_effect_mqtt: Exception, + errors_mqtt: Callable[[dict[str, Any]], str], + test_fn: Callable[[HomeAssistant, dict[str, Any]], Awaitable[dict[str, Any]]] + | Callable[ + [HomeAssistant, dict[str, Any], dict[str, Any]], Awaitable[dict[str, Any]] + ], + test_fn_args: dict[str, Any], + entry_data: dict[str, Any], ) -> None: - """Test handling invalid connection.""" + """Test handling error on library calls.""" + user_input_auth = test_fn_args[_TEST_FN_AUTH_ARG] - mock_authenticator_authenticate.side_effect = side_effect - - result = await _test_user_flow(hass) + # Authenticator raises error + mock_authenticator_authenticate.side_effect = side_effect_rest + result = await test_fn( + hass, + **test_fn_args, + ) assert result["type"] == FlowResultType.FORM - assert result["step_id"] == "user" - assert result["errors"] == {"base": reason} + assert result["step_id"] == "auth" + assert result["errors"] == {"base": reason_rest} mock_authenticator_authenticate.assert_called() + mock_mqtt_client.verify_config.assert_not_called() mock_setup_entry.assert_not_called() mock_authenticator_authenticate.reset_mock(side_effect=True) + + # MQTT raises error + mock_mqtt_client.verify_config.side_effect = side_effect_mqtt result = await hass.config_entries.flow.async_configure( result["flow_id"], - user_input=VALID_ENTRY_DATA, + user_input=user_input_auth, + ) + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "auth" + assert result["errors"] == errors_mqtt(user_input_auth) + mock_authenticator_authenticate.assert_called() + mock_mqtt_client.verify_config.assert_called() + mock_setup_entry.assert_not_called() + + mock_authenticator_authenticate.reset_mock(side_effect=True) + mock_mqtt_client.verify_config.reset_mock(side_effect=True) + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=user_input_auth, ) assert result["type"] == FlowResultType.CREATE_ENTRY - assert result["title"] == VALID_ENTRY_DATA[CONF_USERNAME] - assert result["data"] == VALID_ENTRY_DATA + assert result["title"] == entry_data[CONF_USERNAME] + assert result["data"] == entry_data mock_setup_entry.assert_called() mock_authenticator_authenticate.assert_called() + mock_mqtt_client.verify_config.assert_called() + + +async def test_user_flow_self_hosted_error( + hass: HomeAssistant, + mock_setup_entry: AsyncMock, + mock_authenticator_authenticate: AsyncMock, + mock_mqtt_client: Mock, +) -> None: + """Test handling selfhosted errors and custom ssl context.""" + + result = await _test_user_flow_show_advanced_options( + hass, + user_input_auth=VALID_ENTRY_DATA_SELF_HOSTED + | { + CONF_OVERRIDE_REST_URL: "bla://localhost:8000", + CONF_OVERRIDE_MQTT_URL: "mqtt://", + }, + user_input_user=_USER_STEP_SELF_HOSTED, + ) + + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "auth" + assert result["errors"] == { + CONF_OVERRIDE_REST_URL: "invalid_url_schema_override_rest_url", + CONF_OVERRIDE_MQTT_URL: "invalid_url", + } + mock_authenticator_authenticate.assert_not_called() + mock_mqtt_client.verify_config.assert_not_called() + mock_setup_entry.assert_not_called() + + # Check that the schema includes select box to disable ssl verification of mqtt + assert CONF_VERIFY_MQTT_CERTIFICATE in result["data_schema"].schema + + data = VALID_ENTRY_DATA_SELF_HOSTED | {CONF_VERIFY_MQTT_CERTIFICATE: False} + with patch( + "homeassistant.components.ecovacs.config_flow.create_mqtt_config", + wraps=create_mqtt_config, + ) as mock_create_mqtt_config: + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=data, + ) + mock_create_mqtt_config.assert_called_once() + ssl_context = mock_create_mqtt_config.call_args[1]["ssl_context"] + assert isinstance(ssl_context, ssl.SSLContext) + assert ssl_context.verify_mode == ssl.CERT_NONE + assert ssl_context.check_hostname is False + + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["title"] == data[CONF_USERNAME] + assert result["data"] == data + mock_setup_entry.assert_called() + mock_authenticator_authenticate.assert_called() + mock_mqtt_client.verify_config.assert_called() async def test_import_flow( @@ -88,6 +309,7 @@ async def test_import_flow( issue_registry: ir.IssueRegistry, mock_setup_entry: AsyncMock, mock_authenticator_authenticate: AsyncMock, + mock_mqtt_client: Mock, ) -> None: """Test importing yaml config.""" result = await hass.config_entries.flow.async_init( @@ -98,17 +320,18 @@ async def test_import_flow( mock_authenticator_authenticate.assert_called() assert result["type"] == FlowResultType.CREATE_ENTRY - assert result["title"] == VALID_ENTRY_DATA[CONF_USERNAME] - assert result["data"] == VALID_ENTRY_DATA + assert result["title"] == VALID_ENTRY_DATA_CLOUD[CONF_USERNAME] + assert result["data"] == VALID_ENTRY_DATA_CLOUD assert (HOMEASSISTANT_DOMAIN, f"deprecated_yaml_{DOMAIN}") in issue_registry.issues mock_setup_entry.assert_called() + mock_mqtt_client.verify_config.assert_called() async def test_import_flow_already_configured( hass: HomeAssistant, issue_registry: ir.IssueRegistry ) -> None: """Test importing yaml config where entry already configured.""" - entry = MockConfigEntry(domain=DOMAIN, data=VALID_ENTRY_DATA) + entry = MockConfigEntry(domain=DOMAIN, data=VALID_ENTRY_DATA_CLOUD) entry.add_to_hass(hass) result = await hass.config_entries.flow.async_init( @@ -121,6 +344,7 @@ async def test_import_flow_already_configured( assert (HOMEASSISTANT_DOMAIN, f"deprecated_yaml_{DOMAIN}") in issue_registry.issues +@pytest.mark.parametrize("show_advanced_options", [True, False]) @pytest.mark.parametrize( ("side_effect", "reason"), [ @@ -131,17 +355,22 @@ async def test_import_flow_already_configured( ) async def test_import_flow_error( hass: HomeAssistant, - side_effect: Exception, - reason: str, issue_registry: ir.IssueRegistry, mock_authenticator_authenticate: AsyncMock, + mock_mqtt_client: Mock, + side_effect: Exception, + reason: str, + show_advanced_options: bool, ) -> None: """Test handling invalid connection.""" mock_authenticator_authenticate.side_effect = side_effect result = await hass.config_entries.flow.async_init( DOMAIN, - context={"source": SOURCE_IMPORT}, + context={ + "source": SOURCE_IMPORT, + "show_advanced_options": show_advanced_options, + }, data=IMPORT_DATA.copy(), ) assert result["type"] == FlowResultType.ABORT @@ -151,3 +380,38 @@ async def test_import_flow_error( f"deprecated_yaml_import_issue_{reason}", ) in issue_registry.issues mock_authenticator_authenticate.assert_called() + + +@pytest.mark.parametrize("show_advanced_options", [True, False]) +@pytest.mark.parametrize( + ("reason", "user_input"), + [ + ("invalid_country_length", IMPORT_DATA | {CONF_COUNTRY: "too_long"}), + ("invalid_country_length", IMPORT_DATA | {CONF_COUNTRY: "a"}), # too short + ("invalid_continent_length", IMPORT_DATA | {CONF_CONTINENT: "too_long"}), + ("invalid_continent_length", IMPORT_DATA | {CONF_CONTINENT: "a"}), # too short + ("continent_not_match", IMPORT_DATA | {CONF_CONTINENT: "AA"}), + ], +) +async def test_import_flow_invalid_data( + hass: HomeAssistant, + issue_registry: ir.IssueRegistry, + reason: str, + user_input: dict[str, Any], + show_advanced_options: bool, +) -> None: + """Test handling invalid connection.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={ + "source": SOURCE_IMPORT, + "show_advanced_options": show_advanced_options, + }, + data=user_input, + ) + assert result["type"] == FlowResultType.ABORT + assert result["reason"] == reason + assert ( + DOMAIN, + f"deprecated_yaml_import_issue_{reason}", + ) in issue_registry.issues diff --git a/tests/components/ecovacs/test_diagnostics.py b/tests/components/ecovacs/test_diagnostics.py index 8244efd7fec..b025db43cc0 100644 --- a/tests/components/ecovacs/test_diagnostics.py +++ b/tests/components/ecovacs/test_diagnostics.py @@ -1,15 +1,24 @@ """Tests for diagnostics data.""" +import pytest from syrupy.assertion import SnapshotAssertion from syrupy.filters import props +from homeassistant.const import CONF_USERNAME from homeassistant.core import HomeAssistant +from .const import VALID_ENTRY_DATA_CLOUD, VALID_ENTRY_DATA_SELF_HOSTED + from tests.common import MockConfigEntry from tests.components.diagnostics import get_diagnostics_for_config_entry from tests.typing import ClientSessionGenerator +@pytest.mark.parametrize( + "mock_config_entry_data", + [VALID_ENTRY_DATA_CLOUD, VALID_ENTRY_DATA_SELF_HOSTED], + ids=lambda data: data[CONF_USERNAME], +) async def test_diagnostics( hass: HomeAssistant, hass_client: ClientSessionGenerator, diff --git a/tests/components/ecovacs/test_init.py b/tests/components/ecovacs/test_init.py index 11fe403ca9c..3a344609961 100644 --- a/tests/components/ecovacs/test_init.py +++ b/tests/components/ecovacs/test_init.py @@ -87,6 +87,7 @@ async def test_async_setup_import( config_entries_expected: int, mock_setup_entry: AsyncMock, mock_authenticator_authenticate: AsyncMock, + mock_mqtt_client: Mock, ) -> None: """Test async_setup config import.""" assert len(hass.config_entries.async_entries(DOMAIN)) == 0 @@ -95,6 +96,7 @@ async def test_async_setup_import( assert len(hass.config_entries.async_entries(DOMAIN)) == config_entries_expected assert mock_setup_entry.call_count == config_entries_expected assert mock_authenticator_authenticate.call_count == config_entries_expected + assert mock_mqtt_client.verify_config.call_count == config_entries_expected async def test_devices_in_dr(