Reauth flow for Risco cloud (#81264)

* Risco reauth flow

* Address code review comments

* Remove redundant log
pull/81365/head
On Freund 2022-11-01 00:01:22 +02:00 committed by GitHub
parent 4a9859bf54
commit f8de4c3931
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 101 additions and 9 deletions

View File

@ -26,7 +26,7 @@ from homeassistant.const import (
Platform, Platform,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
@ -127,10 +127,9 @@ async def _async_setup_cloud_entry(hass: HomeAssistant, entry: ConfigEntry) -> b
try: try:
await risco.login(async_get_clientsession(hass)) await risco.login(async_get_clientsession(hass))
except CannotConnectError as error: except CannotConnectError as error:
raise ConfigEntryNotReady() from error raise ConfigEntryNotReady from error
except UnauthorizedError: except UnauthorizedError as error:
_LOGGER.exception("Failed to login to Risco cloud") raise ConfigEntryAuthFailed from error
return False
scan_interval = entry.options.get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL) scan_interval = entry.options.get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL)
coordinator = RiscoDataUpdateCoordinator(hass, risco, scan_interval) coordinator = RiscoDataUpdateCoordinator(hass, risco, scan_interval)

View File

@ -3,6 +3,7 @@ from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
import logging import logging
from typing import Any
from pyrisco import CannotConnectError, RiscoCloud, RiscoLocal, UnauthorizedError from pyrisco import CannotConnectError, RiscoCloud, RiscoLocal, UnauthorizedError
import voluptuous as vol import voluptuous as vol
@ -21,6 +22,7 @@ from homeassistant.const import (
STATE_ALARM_ARMED_HOME, STATE_ALARM_ARMED_HOME,
STATE_ALARM_ARMED_NIGHT, STATE_ALARM_ARMED_NIGHT,
) )
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import ( from .const import (
@ -93,6 +95,10 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
def __init__(self) -> None:
"""Init the config flow."""
self._reauth_entry: config_entries.ConfigEntry | None = None
@staticmethod @staticmethod
@core.callback @core.callback
def async_get_options_flow( def async_get_options_flow(
@ -112,8 +118,9 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Configure a cloud based alarm.""" """Configure a cloud based alarm."""
errors = {} errors = {}
if user_input is not None: if user_input is not None:
await self.async_set_unique_id(user_input[CONF_USERNAME]) if not self._reauth_entry:
self._abort_if_unique_id_configured() await self.async_set_unique_id(user_input[CONF_USERNAME])
self._abort_if_unique_id_configured()
try: try:
info = await validate_cloud_input(self.hass, user_input) info = await validate_cloud_input(self.hass, user_input)
@ -125,12 +132,25 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
_LOGGER.exception("Unexpected exception") _LOGGER.exception("Unexpected exception")
errors["base"] = "unknown" errors["base"] = "unknown"
else: else:
return self.async_create_entry(title=info["title"], data=user_input) if not self._reauth_entry:
return self.async_create_entry(title=info["title"], data=user_input)
self.hass.config_entries.async_update_entry(
self._reauth_entry,
data=user_input,
unique_id=user_input[CONF_USERNAME],
)
await self.hass.config_entries.async_reload(self._reauth_entry.entry_id)
return self.async_abort(reason="reauth_successful")
return self.async_show_form( return self.async_show_form(
step_id="cloud", data_schema=CLOUD_SCHEMA, errors=errors step_id="cloud", data_schema=CLOUD_SCHEMA, errors=errors
) )
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
"""Handle configuration by re-auth."""
self._reauth_entry = await self.async_set_unique_id(entry_data[CONF_USERNAME])
return await self.async_step_cloud()
async def async_step_local(self, user_input=None): async def async_step_local(self, user_input=None):
"""Configure a local based alarm.""" """Configure a local based alarm."""
errors = {} errors = {}

View File

@ -70,7 +70,10 @@ def events():
def cloud_config_entry(hass, options): def cloud_config_entry(hass, options):
"""Fixture for a cloud config entry.""" """Fixture for a cloud config entry."""
config_entry = MockConfigEntry( config_entry = MockConfigEntry(
domain=DOMAIN, data=TEST_CLOUD_CONFIG, options=options domain=DOMAIN,
data=TEST_CLOUD_CONFIG,
options=options,
unique_id=TEST_CLOUD_CONFIG[CONF_USERNAME],
) )
config_entry.add_to_hass(hass) config_entry.add_to_hass(hass)
return config_entry return config_entry

View File

@ -10,6 +10,7 @@ from homeassistant.components.risco.config_flow import (
UnauthorizedError, UnauthorizedError,
) )
from homeassistant.components.risco.const import DOMAIN from homeassistant.components.risco.const import DOMAIN
from homeassistant.const import CONF_PASSWORD, CONF_USERNAME
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -142,6 +143,75 @@ async def test_form_cloud_already_exists(hass):
assert result3["reason"] == "already_configured" assert result3["reason"] == "already_configured"
async def test_form_reauth(hass, cloud_config_entry):
"""Test reauthenticate."""
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_REAUTH},
data=cloud_config_entry.data,
)
assert result["type"] == "form"
assert result["errors"] == {}
with patch(
"homeassistant.components.risco.config_flow.RiscoCloud.login",
return_value=True,
), patch(
"homeassistant.components.risco.config_flow.RiscoCloud.site_name",
new_callable=PropertyMock(return_value=TEST_SITE_NAME),
), patch(
"homeassistant.components.risco.config_flow.RiscoCloud.close"
), patch(
"homeassistant.components.risco.async_setup_entry",
return_value=True,
) as mock_setup_entry:
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], {**TEST_CLOUD_DATA, CONF_PASSWORD: "new_password"}
)
await hass.async_block_till_done()
assert result2["type"] == "abort"
assert result2["reason"] == "reauth_successful"
assert cloud_config_entry.data[CONF_PASSWORD] == "new_password"
assert len(mock_setup_entry.mock_calls) == 1
async def test_form_reauth_with_new_username(hass, cloud_config_entry):
"""Test reauthenticate with new username."""
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_REAUTH},
data=cloud_config_entry.data,
)
assert result["type"] == "form"
assert result["errors"] == {}
with patch(
"homeassistant.components.risco.config_flow.RiscoCloud.login",
return_value=True,
), patch(
"homeassistant.components.risco.config_flow.RiscoCloud.site_name",
new_callable=PropertyMock(return_value=TEST_SITE_NAME),
), patch(
"homeassistant.components.risco.config_flow.RiscoCloud.close"
), patch(
"homeassistant.components.risco.async_setup_entry",
return_value=True,
) as mock_setup_entry:
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], {**TEST_CLOUD_DATA, CONF_USERNAME: "new_user"}
)
await hass.async_block_till_done()
assert result2["type"] == "abort"
assert result2["reason"] == "reauth_successful"
assert cloud_config_entry.data[CONF_USERNAME] == "new_user"
assert cloud_config_entry.unique_id == "new_user"
assert len(mock_setup_entry.mock_calls) == 1
async def test_local_form(hass): async def test_local_form(hass):
"""Test we get the local form.""" """Test we get the local form."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(