From b61218f90ebccfc5c5258218c09d1705c0cbc755 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20S=C3=B8rensen?= Date: Thu, 24 Oct 2019 21:31:58 +0200 Subject: [PATCH] Tradfri config flow enhancements (#28179) --- homeassistant/components/tradfri/__init__.py | 5 +++-- homeassistant/components/tradfri/config_flow.py | 8 ++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/homeassistant/components/tradfri/__init__.py b/homeassistant/components/tradfri/__init__.py index bdfabb4b00a..9d1a43b240f 100644 --- a/homeassistant/components/tradfri/__init__.py +++ b/homeassistant/components/tradfri/__init__.py @@ -8,6 +8,7 @@ from pytradfri.api.aiocoap_api import APIFactory import homeassistant.helpers.config_validation as cv from homeassistant import config_entries from homeassistant.const import EVENT_HOMEASSISTANT_STOP +from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.util.json import load_json from . import config_flow # noqa pylint_disable=unused-import from .const import ( @@ -113,8 +114,8 @@ async def async_setup_entry(hass, entry): try: gateway_info = await api(gateway.get_gateway_info()) except RequestError: - _LOGGER.error("Tradfri setup failed.") - return False + await factory.shutdown() + raise ConfigEntryNotReady hass.data.setdefault(KEY_API, {})[entry.entry_id] = api hass.data.setdefault(KEY_GATEWAY, {})[entry.entry_id] = gateway diff --git a/homeassistant/components/tradfri/config_flow.py b/homeassistant/components/tradfri/config_flow.py index bdb195cf53f..24c3fbc1876 100644 --- a/homeassistant/components/tradfri/config_flow.py +++ b/homeassistant/components/tradfri/config_flow.py @@ -64,13 +64,17 @@ class FlowHandler(config_entries.ConfigFlow): errors[KEY_SECURITY_CODE] = err.code else: errors["base"] = err.code + else: + user_input = {} fields = OrderedDict() if self._host is None: - fields[vol.Required(CONF_HOST)] = str + fields[vol.Required(CONF_HOST, default=user_input.get(CONF_HOST))] = str - fields[vol.Required(KEY_SECURITY_CODE)] = str + fields[ + vol.Required(KEY_SECURITY_CODE, default=user_input.get(KEY_SECURITY_CODE)) + ] = str return self.async_show_form( step_id="auth", data_schema=vol.Schema(fields), errors=errors