Add new helper for matching reauth/reconfigure config flows (#127565)
parent
15a1a83729
commit
2c664efb3c
|
@ -21,7 +21,6 @@ from homeassistant.config_entries import (
|
|||
)
|
||||
from homeassistant.const import CONF_PASSWORD, CONF_REGION, CONF_SOURCE, CONF_USERNAME
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.data_entry_flow import AbortFlow
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.selector import SelectSelector, SelectSelectorConfig
|
||||
|
||||
|
@ -75,7 +74,6 @@ class BMWConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
VERSION = 1
|
||||
|
||||
_existing_entry_data: Mapping[str, Any] | None = None
|
||||
_existing_entry_unique_id: str | None = None
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
|
@ -85,15 +83,12 @@ class BMWConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
|
||||
if user_input is not None:
|
||||
unique_id = f"{user_input[CONF_REGION]}-{user_input[CONF_USERNAME]}"
|
||||
await self.async_set_unique_id(unique_id)
|
||||
|
||||
if self.source not in {SOURCE_REAUTH, SOURCE_RECONFIGURE}:
|
||||
await self.async_set_unique_id(unique_id)
|
||||
if self.source in {SOURCE_REAUTH, SOURCE_RECONFIGURE}:
|
||||
self._abort_if_unique_id_mismatch(reason="account_mismatch")
|
||||
else:
|
||||
self._abort_if_unique_id_configured()
|
||||
elif (
|
||||
self.source in {SOURCE_REAUTH, SOURCE_RECONFIGURE}
|
||||
and unique_id != self._existing_entry_unique_id
|
||||
):
|
||||
raise AbortFlow("account_mismatch")
|
||||
|
||||
info = None
|
||||
try:
|
||||
|
@ -135,16 +130,13 @@ class BMWConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
) -> ConfigFlowResult:
|
||||
"""Handle configuration by re-auth."""
|
||||
self._existing_entry_data = entry_data
|
||||
self._existing_entry_unique_id = self._get_reauth_entry().unique_id
|
||||
return await self.async_step_user()
|
||||
|
||||
async def async_step_reconfigure(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle a reconfiguration flow initialized by the user."""
|
||||
reconfigure_entry = self._get_reconfigure_entry()
|
||||
self._existing_entry_data = reconfigure_entry.data
|
||||
self._existing_entry_unique_id = reconfigure_entry.unique_id
|
||||
self._existing_entry_data = self._get_reconfigure_entry().data
|
||||
return await self.async_step_user()
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -50,11 +50,9 @@ class SpotifyFlowHandler(
|
|||
await self.async_set_unique_id(current_user["id"])
|
||||
|
||||
if self.source == SOURCE_REAUTH:
|
||||
reauth_entry = self._get_reauth_entry()
|
||||
if reauth_entry.data["id"] != current_user["id"]:
|
||||
return self.async_abort(reason="reauth_account_mismatch")
|
||||
self._abort_if_unique_id_mismatch(reason="reauth_account_mismatch")
|
||||
return self.async_update_reload_and_abort(
|
||||
reauth_entry, title=name, data=data
|
||||
self._get_reauth_entry(), title=name, data=data
|
||||
)
|
||||
return self.async_create_entry(title=name, data=data)
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
import jwt
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlowResult
|
||||
from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlowResult
|
||||
from homeassistant.helpers import config_entry_oauth2_flow
|
||||
|
||||
from .const import DOMAIN, LOGGER
|
||||
|
@ -21,7 +21,6 @@ class OAuth2FlowHandler(
|
|||
"""Config flow to handle Tesla Fleet API OAuth2 authentication."""
|
||||
|
||||
DOMAIN = DOMAIN
|
||||
reauth_entry: ConfigEntry | None = None
|
||||
|
||||
@property
|
||||
def logger(self) -> logging.Logger:
|
||||
|
@ -50,32 +49,19 @@ class OAuth2FlowHandler(
|
|||
)
|
||||
uid = token["sub"]
|
||||
|
||||
if not self.reauth_entry:
|
||||
await self.async_set_unique_id(uid)
|
||||
self._abort_if_unique_id_configured()
|
||||
|
||||
return self.async_create_entry(title=uid, data=data)
|
||||
|
||||
if self.reauth_entry.unique_id == uid:
|
||||
self.hass.config_entries.async_update_entry(
|
||||
self.reauth_entry,
|
||||
data=data,
|
||||
await self.async_set_unique_id(uid)
|
||||
if self.source == SOURCE_REAUTH:
|
||||
self._abort_if_unique_id_mismatch(reason="reauth_account_mismatch")
|
||||
return self.async_update_reload_and_abort(
|
||||
self._get_reauth_entry(), data=data
|
||||
)
|
||||
await self.hass.config_entries.async_reload(self.reauth_entry.entry_id)
|
||||
return self.async_abort(reason="reauth_successful")
|
||||
|
||||
return self.async_abort(
|
||||
reason="reauth_account_mismatch",
|
||||
description_placeholders={"title": self.reauth_entry.title},
|
||||
)
|
||||
self._abort_if_unique_id_configured()
|
||||
return self.async_create_entry(title=uid, data=data)
|
||||
|
||||
async def async_step_reauth(
|
||||
self, entry_data: Mapping[str, Any]
|
||||
) -> ConfigFlowResult:
|
||||
"""Perform reauth upon an API authentication error."""
|
||||
self.reauth_entry = self.hass.config_entries.async_get_entry(
|
||||
self.context["entry_id"]
|
||||
)
|
||||
return await self.async_step_reauth_confirm()
|
||||
|
||||
async def async_step_reauth_confirm(
|
||||
|
|
|
@ -2432,6 +2432,26 @@ class ConfigFlow(ConfigEntryBaseFlow):
|
|||
self._async_current_entries(include_ignore=False), match_dict
|
||||
)
|
||||
|
||||
@callback
|
||||
def _abort_if_unique_id_mismatch(
|
||||
self,
|
||||
*,
|
||||
reason: str = "unique_id_mismatch",
|
||||
) -> None:
|
||||
"""Abort if the unique ID does not match the reauth/reconfigure context.
|
||||
|
||||
Requires strings.json entry corresponding to the `reason` parameter
|
||||
in user visible flows.
|
||||
"""
|
||||
if (
|
||||
self.source == SOURCE_REAUTH
|
||||
and self._get_reauth_entry().unique_id != self.unique_id
|
||||
) or (
|
||||
self.source == SOURCE_RECONFIGURE
|
||||
and self._get_reconfigure_entry().unique_id != self.unique_id
|
||||
):
|
||||
raise data_entry_flow.AbortFlow(reason)
|
||||
|
||||
@callback
|
||||
def _abort_if_unique_id_configured(
|
||||
self,
|
||||
|
|
|
@ -6677,6 +6677,73 @@ async def test_reauth_helper_alignment(
|
|||
assert helper_flow_init_data == reauth_flow_init_data
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("original_unique_id", "new_unique_id", "reason"),
|
||||
[
|
||||
("unique", "unique", "success"),
|
||||
(None, None, "success"),
|
||||
("unique", "new", "unique_id_mismatch"),
|
||||
("unique", None, "unique_id_mismatch"),
|
||||
(None, "new", "unique_id_mismatch"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"source",
|
||||
[config_entries.SOURCE_REAUTH, config_entries.SOURCE_RECONFIGURE],
|
||||
)
|
||||
async def test_abort_if_unique_id_mismatch(
|
||||
hass: HomeAssistant,
|
||||
source: str,
|
||||
original_unique_id: str | None,
|
||||
new_unique_id: str | None,
|
||||
reason: str,
|
||||
) -> None:
|
||||
"""Test to check if_unique_id_mismatch behavior."""
|
||||
entry = MockConfigEntry(
|
||||
title="From config flow",
|
||||
domain="test",
|
||||
entry_id="01J915Q6T9F6G5V0QJX6HBC94T",
|
||||
data={"host": "any", "port": 123},
|
||||
unique_id=original_unique_id,
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
mock_setup_entry = AsyncMock(return_value=True)
|
||||
|
||||
mock_integration(hass, MockModule("test", async_setup_entry=mock_setup_entry))
|
||||
mock_platform(hass, "test.config_flow", None)
|
||||
|
||||
class TestFlow(config_entries.ConfigFlow):
|
||||
VERSION = 1
|
||||
|
||||
async def async_step_user(self, user_input=None):
|
||||
"""Test user step."""
|
||||
return await self._async_step_confirm()
|
||||
|
||||
async def async_step_reauth(self, entry_data):
|
||||
"""Test reauth step."""
|
||||
return await self._async_step_confirm()
|
||||
|
||||
async def async_step_reconfigure(self, user_input=None):
|
||||
"""Test reauth step."""
|
||||
return await self._async_step_confirm()
|
||||
|
||||
async def _async_step_confirm(self):
|
||||
"""Confirm input."""
|
||||
await self.async_set_unique_id(new_unique_id)
|
||||
self._abort_if_unique_id_mismatch()
|
||||
return self.async_abort(reason="success")
|
||||
|
||||
with mock_config_flow("test", TestFlow):
|
||||
if source == config_entries.SOURCE_REAUTH:
|
||||
result = await entry.start_reauth_flow(hass)
|
||||
elif source == config_entries.SOURCE_RECONFIGURE:
|
||||
result = await entry.start_reconfigure_flow(hass)
|
||||
await hass.async_block_till_done()
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == reason
|
||||
|
||||
|
||||
def test_state_not_stored_in_storage() -> None:
|
||||
"""Test that state is not stored in storage.
|
||||
|
||||
|
|
Loading…
Reference in New Issue