From 45eebf32856d521dc70741def9bfd69b4bafb6a8 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Wed, 28 Aug 2024 13:09:21 +0200 Subject: [PATCH] Use reauth_confirm in sharkiq (#124762) --- .../components/sharkiq/config_flow.py | 10 ++++- homeassistant/components/sharkiq/strings.json | 2 +- tests/components/sharkiq/test_config_flow.py | 37 +++++++++++-------- 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/homeassistant/components/sharkiq/config_flow.py b/homeassistant/components/sharkiq/config_flow.py index 492b8f2a365..87367fcf093 100644 --- a/homeassistant/components/sharkiq/config_flow.py +++ b/homeassistant/components/sharkiq/config_flow.py @@ -116,9 +116,15 @@ class SharkIqConfigFlow(ConfigFlow, domain=DOMAIN): ) async def async_step_reauth( - self, user_input: Mapping[str, Any] + self, entry_data: Mapping[str, Any] ) -> ConfigFlowResult: """Handle re-auth if login is invalid.""" + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Handle a flow initiated by reauthentication.""" errors: dict[str, str] = {} if user_input is not None: @@ -134,7 +140,7 @@ class SharkIqConfigFlow(ConfigFlow, domain=DOMAIN): return self.async_abort(reason=errors["base"]) return self.async_show_form( - step_id="reauth", + step_id="reauth_confirm", data_schema=SHARKIQ_SCHEMA, errors=errors, ) diff --git a/homeassistant/components/sharkiq/strings.json b/homeassistant/components/sharkiq/strings.json index 63d4f6af48b..40b569e13b7 100644 --- a/homeassistant/components/sharkiq/strings.json +++ b/homeassistant/components/sharkiq/strings.json @@ -13,7 +13,7 @@ "region": "Shark IQ uses different services in the EU. Select your region to connect to the correct service for your account." } }, - "reauth": { + "reauth_confirm": { "data": { "username": "[%key:common::config_flow::data::username%]", "password": "[%key:common::config_flow::data::password%]", diff --git a/tests/components/sharkiq/test_config_flow.py b/tests/components/sharkiq/test_config_flow.py index cf75bff1686..ae037834c57 100644 --- a/tests/components/sharkiq/test_config_flow.py +++ b/tests/components/sharkiq/test_config_flow.py @@ -96,18 +96,22 @@ async def test_form_error(hass: HomeAssistant, exc: Exception, base_error: str) async def test_reauth_success(hass: HomeAssistant) -> None: """Test reauth flow.""" - with patch("sharkiq.AylaApi.async_sign_in", return_value=True): - mock_config = MockConfigEntry(domain=DOMAIN, unique_id=UNIQUE_ID, data=CONFIG) - mock_config.add_to_hass(hass) + mock_config = MockConfigEntry(domain=DOMAIN, unique_id=UNIQUE_ID, data=CONFIG) + mock_config.add_to_hass(hass) - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": config_entries.SOURCE_REAUTH, "unique_id": UNIQUE_ID}, - data=CONFIG, + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_REAUTH, "unique_id": UNIQUE_ID}, + data=mock_config.data, + ) + + with patch("sharkiq.AylaApi.async_sign_in", return_value=True): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input=CONFIG ) - assert result["type"] is FlowResultType.ABORT - assert result["reason"] == "reauth_successful" + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "reauth_successful" @pytest.mark.parametrize( @@ -127,13 +131,16 @@ async def test_reauth( msg: str, ) -> None: """Test reauth failures.""" - with patch("sharkiq.AylaApi.async_sign_in", side_effect=side_effect): - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": config_entries.SOURCE_REAUTH, "unique_id": UNIQUE_ID}, - data=CONFIG, - ) + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_REAUTH, "unique_id": UNIQUE_ID}, + data=CONFIG, + ) + with patch("sharkiq.AylaApi.async_sign_in", side_effect=side_effect): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input=CONFIG + ) msg_value = result[msg_field] if msg_field == "errors": msg_value = msg_value.get("base")