diff --git a/homeassistant/helpers/config_entry_oauth2_flow.py b/homeassistant/helpers/config_entry_oauth2_flow.py index 2fdfea8673f..d29dae735f8 100644 --- a/homeassistant/helpers/config_entry_oauth2_flow.py +++ b/homeassistant/helpers/config_entry_oauth2_flow.py @@ -259,10 +259,21 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta): """ return self.async_create_entry(title=self.flow_impl.name, data=data) + async def async_step_discovery(self, user_input: dict = None) -> dict: + """Handle a flow initialized by discovery.""" + await self.async_set_unique_id(self.DOMAIN) + self._abort_if_unique_id_configured() + + assert self.hass is not None + if self.hass.config_entries.async_entries(self.DOMAIN): + return self.async_abort(reason="already_configured") + + return await self.async_step_pick_implementation() + async_step_user = async_step_pick_implementation - async_step_ssdp = async_step_pick_implementation - async_step_zeroconf = async_step_pick_implementation - async_step_homekit = async_step_pick_implementation + async_step_ssdp = async_step_discovery + async_step_zeroconf = async_step_discovery + async_step_homekit = async_step_discovery @classmethod def async_register_implementation( diff --git a/tests/helpers/test_config_entry_oauth2_flow.py b/tests/helpers/test_config_entry_oauth2_flow.py index 366c295874d..a72f3f51ee7 100644 --- a/tests/helpers/test_config_entry_oauth2_flow.py +++ b/tests/helpers/test_config_entry_oauth2_flow.py @@ -122,6 +122,64 @@ async def test_abort_if_authorization_timeout(hass, flow_handler, local_impl): assert result["reason"] == "authorize_url_timeout" +async def test_step_discovery(hass, flow_handler, local_impl): + """Check flow triggers from discovery.""" + hass.config.api.base_url = "https://example.com" + flow_handler.async_register_implementation(hass, local_impl) + config_entry_oauth2_flow.async_register_implementation( + hass, TEST_DOMAIN, MockOAuth2Implementation() + ) + + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF} + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "pick_implementation" + + +async def test_abort_discovered_multiple(hass, flow_handler, local_impl): + """Test if aborts when discovered multiple times.""" + hass.config.api.base_url = "https://example.com" + flow_handler.async_register_implementation(hass, local_impl) + config_entry_oauth2_flow.async_register_implementation( + hass, TEST_DOMAIN, MockOAuth2Implementation() + ) + + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_SSDP} + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "pick_implementation" + + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF} + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT + assert result["reason"] == "already_in_progress" + + +async def test_abort_discovered_existing_entries(hass, flow_handler, local_impl): + """Test if abort discovery when entries exists.""" + hass.config.api.base_url = "https://example.com" + flow_handler.async_register_implementation(hass, local_impl) + config_entry_oauth2_flow.async_register_implementation( + hass, TEST_DOMAIN, MockOAuth2Implementation() + ) + + entry = MockConfigEntry(domain=TEST_DOMAIN, data={},) + entry.add_to_hass(hass) + + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_SSDP} + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT + assert result["reason"] == "already_configured" + + async def test_full_flow( hass, flow_handler, local_impl, aiohttp_client, aioclient_mock ):