From 1c7c6163dd6c45b9db7e937190a4deac5cb04a02 Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Wed, 24 Feb 2021 11:31:31 +0100 Subject: [PATCH] Save mysensors gateway type in config entry (#46981) --- .../components/mysensors/config_flow.py | 30 ++++++++++++------- .../components/mysensors/test_config_flow.py | 3 ++ 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/homeassistant/components/mysensors/config_flow.py b/homeassistant/components/mysensors/config_flow.py index 058b782d208..06ead121706 100644 --- a/homeassistant/components/mysensors/config_flow.py +++ b/homeassistant/components/mysensors/config_flow.py @@ -19,6 +19,7 @@ from homeassistant.components.mysensors import ( is_persistence_file, ) from homeassistant.config_entries import ConfigEntry +from homeassistant.core import callback import homeassistant.helpers.config_validation as cv from . import CONF_RETAIN, CONF_VERSION, DEFAULT_VERSION @@ -99,6 +100,10 @@ def _is_same_device( class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): """Handle a config flow.""" + def __init__(self) -> None: + """Set up config flow.""" + self._gw_type: Optional[str] = None + async def async_step_import(self, user_input: Optional[Dict[str, str]] = None): """Import a config entry. @@ -130,7 +135,7 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): schema = vol.Schema(schema) if user_input is not None: - gw_type = user_input[CONF_GATEWAY_TYPE] + gw_type = self._gw_type = user_input[CONF_GATEWAY_TYPE] input_pass = user_input if CONF_DEVICE in user_input else None if gw_type == CONF_GATEWAY_TYPE_MQTT: return await self.async_step_gw_mqtt(input_pass) @@ -149,9 +154,7 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): await self.validate_common(CONF_GATEWAY_TYPE_SERIAL, errors, user_input) ) if not errors: - return self.async_create_entry( - title=f"{user_input[CONF_DEVICE]}", data=user_input - ) + return self._async_create_entry(user_input) schema = _get_schema_common() schema[ @@ -177,9 +180,7 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): await self.validate_common(CONF_GATEWAY_TYPE_TCP, errors, user_input) ) if not errors: - return self.async_create_entry( - title=f"{user_input[CONF_DEVICE]}", data=user_input - ) + return self._async_create_entry(user_input) schema = _get_schema_common() schema[vol.Required(CONF_DEVICE, default="127.0.0.1")] = str @@ -228,9 +229,8 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): await self.validate_common(CONF_GATEWAY_TYPE_MQTT, errors, user_input) ) if not errors: - return self.async_create_entry( - title=f"{user_input[CONF_DEVICE]}", data=user_input - ) + return self._async_create_entry(user_input) + schema = _get_schema_common() schema[vol.Required(CONF_RETAIN, default=True)] = bool schema[vol.Required(CONF_TOPIC_IN_PREFIX)] = str @@ -241,6 +241,16 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): step_id="gw_mqtt", data_schema=schema, errors=errors ) + @callback + def _async_create_entry( + self, user_input: Optional[Dict[str, str]] = None + ) -> Dict[str, Any]: + """Create the config entry.""" + return self.async_create_entry( + title=f"{user_input[CONF_DEVICE]}", + data={**user_input, CONF_GATEWAY_TYPE: self._gw_type}, + ) + def _normalize_persistence_file(self, path: str) -> str: return os.path.realpath(os.path.normcase(self.hass.config.path(path))) diff --git a/tests/components/mysensors/test_config_flow.py b/tests/components/mysensors/test_config_flow.py index 6bfec3b102e..5fd9e3e7ea1 100644 --- a/tests/components/mysensors/test_config_flow.py +++ b/tests/components/mysensors/test_config_flow.py @@ -81,6 +81,7 @@ async def test_config_mqtt(hass: HomeAssistantType): CONF_TOPIC_IN_PREFIX: "bla", CONF_TOPIC_OUT_PREFIX: "blub", CONF_VERSION: "2.4", + CONF_GATEWAY_TYPE: "MQTT", } assert len(mock_setup.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1 @@ -120,6 +121,7 @@ async def test_config_serial(hass: HomeAssistantType): CONF_DEVICE: "/dev/ttyACM0", CONF_BAUD_RATE: 115200, CONF_VERSION: "2.4", + CONF_GATEWAY_TYPE: "Serial", } assert len(mock_setup.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1 @@ -156,6 +158,7 @@ async def test_config_tcp(hass: HomeAssistantType): CONF_DEVICE: "127.0.0.1", CONF_TCP_PORT: 5003, CONF_VERSION: "2.4", + CONF_GATEWAY_TYPE: "TCP", } assert len(mock_setup.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1