Add new helper for matching reauth/reconfigure config flows (#127565)

pull/127923/head
epenet 2024-10-08 10:07:36 +02:00 committed by GitHub
parent 15a1a83729
commit 2c664efb3c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 102 additions and 39 deletions

View File

@ -21,7 +21,6 @@ from homeassistant.config_entries import (
)
from homeassistant.const import CONF_PASSWORD, CONF_REGION, CONF_SOURCE, CONF_USERNAME
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import AbortFlow
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.selector import SelectSelector, SelectSelectorConfig
@ -75,7 +74,6 @@ class BMWConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1
_existing_entry_data: Mapping[str, Any] | None = None
_existing_entry_unique_id: str | None = None
async def async_step_user(
self, user_input: dict[str, Any] | None = None
@ -85,15 +83,12 @@ class BMWConfigFlow(ConfigFlow, domain=DOMAIN):
if user_input is not None:
unique_id = f"{user_input[CONF_REGION]}-{user_input[CONF_USERNAME]}"
await self.async_set_unique_id(unique_id)
if self.source not in {SOURCE_REAUTH, SOURCE_RECONFIGURE}:
await self.async_set_unique_id(unique_id)
if self.source in {SOURCE_REAUTH, SOURCE_RECONFIGURE}:
self._abort_if_unique_id_mismatch(reason="account_mismatch")
else:
self._abort_if_unique_id_configured()
elif (
self.source in {SOURCE_REAUTH, SOURCE_RECONFIGURE}
and unique_id != self._existing_entry_unique_id
):
raise AbortFlow("account_mismatch")
info = None
try:
@ -135,16 +130,13 @@ class BMWConfigFlow(ConfigFlow, domain=DOMAIN):
) -> ConfigFlowResult:
"""Handle configuration by re-auth."""
self._existing_entry_data = entry_data
self._existing_entry_unique_id = self._get_reauth_entry().unique_id
return await self.async_step_user()
async def async_step_reconfigure(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle a reconfiguration flow initialized by the user."""
reconfigure_entry = self._get_reconfigure_entry()
self._existing_entry_data = reconfigure_entry.data
self._existing_entry_unique_id = reconfigure_entry.unique_id
self._existing_entry_data = self._get_reconfigure_entry().data
return await self.async_step_user()
@staticmethod

View File

@ -50,11 +50,9 @@ class SpotifyFlowHandler(
await self.async_set_unique_id(current_user["id"])
if self.source == SOURCE_REAUTH:
reauth_entry = self._get_reauth_entry()
if reauth_entry.data["id"] != current_user["id"]:
return self.async_abort(reason="reauth_account_mismatch")
self._abort_if_unique_id_mismatch(reason="reauth_account_mismatch")
return self.async_update_reload_and_abort(
reauth_entry, title=name, data=data
self._get_reauth_entry(), title=name, data=data
)
return self.async_create_entry(title=name, data=data)

View File

@ -8,7 +8,7 @@ from typing import Any
import jwt
from homeassistant.config_entries import ConfigEntry, ConfigFlowResult
from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlowResult
from homeassistant.helpers import config_entry_oauth2_flow
from .const import DOMAIN, LOGGER
@ -21,7 +21,6 @@ class OAuth2FlowHandler(
"""Config flow to handle Tesla Fleet API OAuth2 authentication."""
DOMAIN = DOMAIN
reauth_entry: ConfigEntry | None = None
@property
def logger(self) -> logging.Logger:
@ -50,32 +49,19 @@ class OAuth2FlowHandler(
)
uid = token["sub"]
if not self.reauth_entry:
await self.async_set_unique_id(uid)
self._abort_if_unique_id_configured()
return self.async_create_entry(title=uid, data=data)
if self.reauth_entry.unique_id == uid:
self.hass.config_entries.async_update_entry(
self.reauth_entry,
data=data,
await self.async_set_unique_id(uid)
if self.source == SOURCE_REAUTH:
self._abort_if_unique_id_mismatch(reason="reauth_account_mismatch")
return self.async_update_reload_and_abort(
self._get_reauth_entry(), data=data
)
await self.hass.config_entries.async_reload(self.reauth_entry.entry_id)
return self.async_abort(reason="reauth_successful")
return self.async_abort(
reason="reauth_account_mismatch",
description_placeholders={"title": self.reauth_entry.title},
)
self._abort_if_unique_id_configured()
return self.async_create_entry(title=uid, data=data)
async def async_step_reauth(
self, entry_data: Mapping[str, Any]
) -> ConfigFlowResult:
"""Perform reauth upon an API authentication error."""
self.reauth_entry = self.hass.config_entries.async_get_entry(
self.context["entry_id"]
)
return await self.async_step_reauth_confirm()
async def async_step_reauth_confirm(

View File

@ -2432,6 +2432,26 @@ class ConfigFlow(ConfigEntryBaseFlow):
self._async_current_entries(include_ignore=False), match_dict
)
@callback
def _abort_if_unique_id_mismatch(
self,
*,
reason: str = "unique_id_mismatch",
) -> None:
"""Abort if the unique ID does not match the reauth/reconfigure context.
Requires strings.json entry corresponding to the `reason` parameter
in user visible flows.
"""
if (
self.source == SOURCE_REAUTH
and self._get_reauth_entry().unique_id != self.unique_id
) or (
self.source == SOURCE_RECONFIGURE
and self._get_reconfigure_entry().unique_id != self.unique_id
):
raise data_entry_flow.AbortFlow(reason)
@callback
def _abort_if_unique_id_configured(
self,

View File

@ -6677,6 +6677,73 @@ async def test_reauth_helper_alignment(
assert helper_flow_init_data == reauth_flow_init_data
@pytest.mark.parametrize(
("original_unique_id", "new_unique_id", "reason"),
[
("unique", "unique", "success"),
(None, None, "success"),
("unique", "new", "unique_id_mismatch"),
("unique", None, "unique_id_mismatch"),
(None, "new", "unique_id_mismatch"),
],
)
@pytest.mark.parametrize(
"source",
[config_entries.SOURCE_REAUTH, config_entries.SOURCE_RECONFIGURE],
)
async def test_abort_if_unique_id_mismatch(
hass: HomeAssistant,
source: str,
original_unique_id: str | None,
new_unique_id: str | None,
reason: str,
) -> None:
"""Test to check if_unique_id_mismatch behavior."""
entry = MockConfigEntry(
title="From config flow",
domain="test",
entry_id="01J915Q6T9F6G5V0QJX6HBC94T",
data={"host": "any", "port": 123},
unique_id=original_unique_id,
)
entry.add_to_hass(hass)
mock_setup_entry = AsyncMock(return_value=True)
mock_integration(hass, MockModule("test", async_setup_entry=mock_setup_entry))
mock_platform(hass, "test.config_flow", None)
class TestFlow(config_entries.ConfigFlow):
VERSION = 1
async def async_step_user(self, user_input=None):
"""Test user step."""
return await self._async_step_confirm()
async def async_step_reauth(self, entry_data):
"""Test reauth step."""
return await self._async_step_confirm()
async def async_step_reconfigure(self, user_input=None):
"""Test reauth step."""
return await self._async_step_confirm()
async def _async_step_confirm(self):
"""Confirm input."""
await self.async_set_unique_id(new_unique_id)
self._abort_if_unique_id_mismatch()
return self.async_abort(reason="success")
with mock_config_flow("test", TestFlow):
if source == config_entries.SOURCE_REAUTH:
result = await entry.start_reauth_flow(hass)
elif source == config_entries.SOURCE_RECONFIGURE:
result = await entry.start_reconfigure_flow(hass)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == reason
def test_state_not_stored_in_storage() -> None:
"""Test that state is not stored in storage.