diff --git a/homeassistant/components/mqtt/config_flow.py b/homeassistant/components/mqtt/config_flow.py index e23280c7f2b..bea8a900a83 100644 --- a/homeassistant/components/mqtt/config_flow.py +++ b/homeassistant/components/mqtt/config_flow.py @@ -12,9 +12,9 @@ from cryptography.hazmat.primitives.serialization import load_pem_private_key from cryptography.x509 import load_pem_x509_certificate import voluptuous as vol -from homeassistant import config_entries from homeassistant.components.file_upload import process_uploaded_file from homeassistant.components.hassio import HassioServiceInfo +from homeassistant.config_entries import ConfigEntry, ConfigFlow, OptionsFlow from homeassistant.const import ( CONF_CLIENT_ID, CONF_DISCOVERY, @@ -25,7 +25,7 @@ from homeassistant.const import ( CONF_PROTOCOL, CONF_USERNAME, ) -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import callback from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers import config_validation as cv from homeassistant.helpers.json import json_dumps @@ -154,7 +154,7 @@ CERT_UPLOAD_SELECTOR = FileSelector( KEY_UPLOAD_SELECTOR = FileSelector(FileSelectorConfig(accept=".key,application/pkcs8")) -class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): +class FlowHandler(ConfigFlow, domain=DOMAIN): """Handle a config flow.""" VERSION = 1 @@ -164,7 +164,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): @staticmethod @callback def async_get_options_flow( - config_entry: config_entries.ConfigEntry, + config_entry: ConfigEntry, ) -> MQTTOptionsFlowHandler: """Get the options flow for this handler.""" return MQTTOptionsFlowHandler(config_entry) @@ -186,7 +186,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): fields: OrderedDict[Any, Any] = OrderedDict() validated_user_input: dict[str, Any] = {} if await async_get_broker_settings( - self.hass, + self, fields, None, user_input, @@ -255,10 +255,10 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): ) -class MQTTOptionsFlowHandler(config_entries.OptionsFlow): +class MQTTOptionsFlowHandler(OptionsFlow): """Handle MQTT options.""" - def __init__(self, config_entry: config_entries.ConfigEntry) -> None: + def __init__(self, config_entry: ConfigEntry) -> None: """Initialize MQTT options flow.""" self.config_entry = config_entry self.broker_config: dict[str, str | int] = {} @@ -276,7 +276,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow): fields: OrderedDict[Any, Any] = OrderedDict() validated_user_input: dict[str, Any] = {} if await async_get_broker_settings( - self.hass, + self, fields, self.config_entry.data, user_input, @@ -448,7 +448,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow): async def async_get_broker_settings( - hass: HomeAssistant, + flow: ConfigFlow | OptionsFlow, fields: OrderedDict[Any, Any], entry_config: MappingProxyType[str, Any] | None, user_input: dict[str, Any] | None, @@ -461,6 +461,7 @@ async def async_get_broker_settings( or when the advanced_broker_options checkbox was selected. Returns True when settings are collected successfully. """ + hass = flow.hass advanced_broker_options: bool = False user_input_basic: dict[str, Any] = {} current_config: dict[str, Any] = ( @@ -639,9 +640,12 @@ async def async_get_broker_settings( description={"suggested_value": current_pass}, ) ] = PASSWORD_SELECTOR - # show advanced options checkbox if requested + # show advanced options checkbox if requested and + # advanced options are enabled # or when the defaults of advanced options are overridden if not advanced_broker_options: + if not flow.show_advanced_options: + return False fields[ vol.Optional( ADVANCED_OPTIONS, diff --git a/tests/components/mqtt/test_config_flow.py b/tests/components/mqtt/test_config_flow.py index 8f3846a7376..2ebc4a50ef0 100644 --- a/tests/components/mqtt/test_config_flow.py +++ b/tests/components/mqtt/test_config_flow.py @@ -208,7 +208,8 @@ async def test_user_v5_connection_works( mock_try_connection.return_value = True result = await hass.config_entries.flow.async_init( - "mqtt", context={"source": config_entries.SOURCE_USER} + "mqtt", + context={"source": config_entries.SOURCE_USER, "show_advanced_options": True}, ) assert result["type"] == "form" @@ -1015,7 +1016,9 @@ async def test_skipping_advanced_options( mqtt_mock.async_connect.reset_mock() - result = await hass.config_entries.options.async_init(config_entry.entry_id) + result = await hass.config_entries.options.async_init( + config_entry.entry_id, context={"show_advanced_options": True} + ) assert result["type"] == data_entry_flow.FlowResultType.FORM assert result["step_id"] == "broker" @@ -1268,7 +1271,9 @@ async def test_setup_with_advanced_settings( mock_try_connection.return_value = True - result = await hass.config_entries.options.async_init(config_entry.entry_id) + result = await hass.config_entries.options.async_init( + config_entry.entry_id, context={"show_advanced_options": True} + ) assert result["type"] == "form" assert result["step_id"] == "broker" assert result["data_schema"].schema["advanced_options"]