diff --git a/homeassistant/components/schlage/__init__.py b/homeassistant/components/schlage/__init__.py index feaa95864d5..96ff32d3e85 100644 --- a/homeassistant/components/schlage/__init__.py +++ b/homeassistant/components/schlage/__init__.py @@ -7,8 +7,9 @@ import pyschlage from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_PASSWORD, CONF_USERNAME, Platform from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryAuthFailed -from .const import DOMAIN, LOGGER +from .const import DOMAIN from .coordinator import SchlageDataUpdateCoordinator PLATFORMS: list[Platform] = [ @@ -26,8 +27,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: try: auth = await hass.async_add_executor_job(pyschlage.Auth, username, password) except WarrantException as ex: - LOGGER.error("Schlage authentication failed: %s", ex) - return False + raise ConfigEntryAuthFailed from ex coordinator = SchlageDataUpdateCoordinator(hass, username, pyschlage.Schlage(auth)) hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator diff --git a/homeassistant/components/schlage/config_flow.py b/homeassistant/components/schlage/config_flow.py index 7e095466087..84bc3ef8ef6 100644 --- a/homeassistant/components/schlage/config_flow.py +++ b/homeassistant/components/schlage/config_flow.py @@ -1,6 +1,7 @@ """Config flow for Schlage integration.""" from __future__ import annotations +from collections.abc import Mapping from typing import Any import pyschlage @@ -8,6 +9,7 @@ from pyschlage.exceptions import NotAuthorizedError import voluptuous as vol from homeassistant import config_entries +from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_PASSWORD, CONF_USERNAME from homeassistant.data_entry_flow import FlowResult @@ -16,6 +18,7 @@ from .const import DOMAIN, LOGGER STEP_USER_DATA_SCHEMA = vol.Schema( {vol.Required(CONF_USERNAME): str, vol.Required(CONF_PASSWORD): str} ) +STEP_REAUTH_DATA_SCHEMA = vol.Schema({vol.Required(CONF_PASSWORD): str}) class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): @@ -23,36 +26,88 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): VERSION = 1 + reauth_entry: ConfigEntry | None = None + async def async_step_user( self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Handle the initial step.""" - errors: dict[str, str] = {} - if user_input is not None: - username = user_input[CONF_USERNAME] - password = user_input[CONF_PASSWORD] - try: - user_id = await self.hass.async_add_executor_job( - _authenticate, username, password - ) - except NotAuthorizedError: - errors["base"] = "invalid_auth" - except Exception: # pylint: disable=broad-except - LOGGER.exception("Unknown error") - errors["base"] = "unknown" - else: - await self.async_set_unique_id(user_id) - return self.async_create_entry(title=username, data=user_input) + if user_input is None: + return self._show_user_form({}) + username = user_input[CONF_USERNAME] + password = user_input[CONF_PASSWORD] + user_id, errors = await self.hass.async_add_executor_job( + _authenticate, username, password + ) + if user_id is None: + return self._show_user_form(errors) + await self.async_set_unique_id(user_id) + return self.async_create_entry(title=username, data=user_input) + + def _show_user_form(self, errors: dict[str, str]) -> FlowResult: + """Show the user form.""" return self.async_show_form( step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors ) + async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult: + """Handle reauth upon an API authentication error.""" + self.reauth_entry = self.hass.config_entries.async_get_entry( + self.context["entry_id"] + ) + return await self.async_step_reauth_confirm() -def _authenticate(username: str, password: str) -> str: + async def async_step_reauth_confirm( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Dialog that informs the user that reauth is required.""" + assert self.reauth_entry is not None + if user_input is None: + return self._show_reauth_form({}) + + username = self.reauth_entry.data[CONF_USERNAME] + password = user_input[CONF_PASSWORD] + user_id, errors = await self.hass.async_add_executor_job( + _authenticate, username, password + ) + if user_id is None: + return self._show_reauth_form(errors) + + if self.reauth_entry.unique_id != user_id: + return self.async_abort(reason="wrong_account") + + data = { + CONF_USERNAME: username, + CONF_PASSWORD: user_input[CONF_PASSWORD], + } + self.hass.config_entries.async_update_entry(self.reauth_entry, data=data) + await self.hass.config_entries.async_reload(self.reauth_entry.entry_id) + return self.async_abort(reason="reauth_successful") + + def _show_reauth_form(self, errors: dict[str, str]) -> FlowResult: + """Show the reauth form.""" + return self.async_show_form( + step_id="reauth_confirm", + data_schema=STEP_REAUTH_DATA_SCHEMA, + errors=errors, + ) + + +def _authenticate(username: str, password: str) -> tuple[str | None, dict[str, str]]: """Authenticate with the Schlage API.""" - auth = pyschlage.Auth(username, password) - auth.authenticate() - # The user_id property will make a blocking call if it's not already - # cached. To avoid blocking the event loop, we read it here. - return auth.user_id + user_id = None + errors: dict[str, str] = {} + try: + auth = pyschlage.Auth(username, password) + auth.authenticate() + except NotAuthorizedError: + errors["base"] = "invalid_auth" + except Exception: # pylint: disable=broad-except + LOGGER.exception("Unknown error") + errors["base"] = "unknown" + else: + # The user_id property will make a blocking call if it's not already + # cached. To avoid blocking the event loop, we read it here. + user_id = auth.user_id + return user_id, errors diff --git a/homeassistant/components/schlage/coordinator.py b/homeassistant/components/schlage/coordinator.py index 2b1e8460af2..3d736306d91 100644 --- a/homeassistant/components/schlage/coordinator.py +++ b/homeassistant/components/schlage/coordinator.py @@ -5,10 +5,11 @@ import asyncio from dataclasses import dataclass from pyschlage import Lock, Schlage -from pyschlage.exceptions import Error as SchlageError +from pyschlage.exceptions import Error as SchlageError, NotAuthorizedError from pyschlage.log import LockLog from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryAuthFailed from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from .const import DOMAIN, LOGGER, UPDATE_INTERVAL @@ -43,6 +44,8 @@ class SchlageDataUpdateCoordinator(DataUpdateCoordinator[SchlageData]): """Fetch the latest data from the Schlage API.""" try: locks = await self.hass.async_add_executor_job(self.api.locks) + except NotAuthorizedError as ex: + raise ConfigEntryAuthFailed from ex except SchlageError as ex: raise UpdateFailed("Failed to refresh Schlage data") from ex lock_data = await asyncio.gather( @@ -64,6 +67,8 @@ class SchlageDataUpdateCoordinator(DataUpdateCoordinator[SchlageData]): logs = previous_lock_data.logs try: logs = lock.logs() + except NotAuthorizedError as ex: + raise ConfigEntryAuthFailed from ex except SchlageError as ex: LOGGER.debug('Failed to read logs for lock "%s": %s', lock.name, ex) diff --git a/homeassistant/components/schlage/strings.json b/homeassistant/components/schlage/strings.json index 076ed97e298..721d9e80286 100644 --- a/homeassistant/components/schlage/strings.json +++ b/homeassistant/components/schlage/strings.json @@ -6,6 +6,13 @@ "username": "[%key:common::config_flow::data::username%]", "password": "[%key:common::config_flow::data::password%]" } + }, + "reauth_confirm": { + "title": "[%key:common::config_flow::title::reauth%]", + "description": "The Schlage integration needs to re-authenticate your account", + "data": { + "password": "[%key:common::config_flow::data::password%]" + } } }, "error": { @@ -13,7 +20,9 @@ "unknown": "[%key:common::config_flow::error::unknown%]" }, "abort": { - "already_configured": "[%key:common::config_flow::abort::already_configured_device%]" + "already_configured": "[%key:common::config_flow::abort::already_configured_device%]", + "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]", + "wrong_account": "The user credentials provided do not match this Schlage account." } }, "entity": { diff --git a/tests/components/schlage/conftest.py b/tests/components/schlage/conftest.py index 7b610a6b4da..5f9676b7d09 100644 --- a/tests/components/schlage/conftest.py +++ b/tests/components/schlage/conftest.py @@ -54,14 +54,14 @@ def mock_setup_entry() -> Generator[AsyncMock, None, None]: @pytest.fixture -def mock_schlage(): +def mock_schlage() -> Mock: """Mock pyschlage.Schlage.""" with patch("pyschlage.Schlage", autospec=True) as mock_schlage: yield mock_schlage.return_value @pytest.fixture -def mock_pyschlage_auth(): +def mock_pyschlage_auth() -> Mock: """Mock pyschlage.Auth.""" with patch("pyschlage.Auth", autospec=True) as mock_auth: mock_auth.return_value.user_id = "abc123" @@ -69,7 +69,7 @@ def mock_pyschlage_auth(): @pytest.fixture -def mock_lock(): +def mock_lock() -> Mock: """Mock Lock fixture.""" mock_lock = create_autospec(Lock) mock_lock.configure_mock( diff --git a/tests/components/schlage/test_config_flow.py b/tests/components/schlage/test_config_flow.py index b256e8950ed..14121f5d9ca 100644 --- a/tests/components/schlage/test_config_flow.py +++ b/tests/components/schlage/test_config_flow.py @@ -9,6 +9,8 @@ from homeassistant.components.schlage.const import DOMAIN from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType +from tests.common import MockConfigEntry + pytestmark = pytest.mark.usefixtures("mock_setup_entry") @@ -78,3 +80,94 @@ async def test_form_unknown(hass: HomeAssistant, mock_pyschlage_auth: Mock) -> N assert result2["type"] == FlowResultType.FORM assert result2["errors"] == {"base": "unknown"} + + +async def test_reauth( + hass: HomeAssistant, + mock_added_config_entry: MockConfigEntry, + mock_setup_entry: AsyncMock, + mock_pyschlage_auth: Mock, +) -> None: + """Test reauth flow.""" + mock_added_config_entry.async_start_reauth(hass) + await hass.async_block_till_done() + + flows = hass.config_entries.flow.async_progress() + assert len(flows) == 1 + [result] = flows + assert result["step_id"] == "reauth_confirm" + + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + {"password": "new-password"}, + ) + await hass.async_block_till_done() + + mock_pyschlage_auth.authenticate.assert_called_once_with() + assert result2["type"] == FlowResultType.ABORT + assert result2["reason"] == "reauth_successful" + assert mock_added_config_entry.data == { + "username": "asdf@asdf.com", + "password": "new-password", + } + assert len(mock_setup_entry.mock_calls) == 1 + + +async def test_reauth_invalid_auth( + hass: HomeAssistant, + mock_added_config_entry: MockConfigEntry, + mock_setup_entry: AsyncMock, + mock_pyschlage_auth: Mock, +) -> None: + """Test reauth flow.""" + mock_added_config_entry.async_start_reauth(hass) + await hass.async_block_till_done() + + flows = hass.config_entries.flow.async_progress() + assert len(flows) == 1 + [result] = flows + assert result["step_id"] == "reauth_confirm" + + mock_pyschlage_auth.authenticate.reset_mock() + mock_pyschlage_auth.authenticate.side_effect = NotAuthorizedError + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + {"password": "new-password"}, + ) + await hass.async_block_till_done() + + mock_pyschlage_auth.authenticate.assert_called_once_with() + assert result2["type"] == FlowResultType.FORM + assert result2["errors"] == {"base": "invalid_auth"} + + +async def test_reauth_wrong_account( + hass: HomeAssistant, + mock_added_config_entry: MockConfigEntry, + mock_setup_entry: AsyncMock, + mock_pyschlage_auth: Mock, +) -> None: + """Test reauth flow.""" + mock_pyschlage_auth.user_id = "bad-user-id" + mock_added_config_entry.async_start_reauth(hass) + await hass.async_block_till_done() + + flows = hass.config_entries.flow.async_progress() + assert len(flows) == 1 + [result] = flows + assert result["step_id"] == "reauth_confirm" + + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + {"password": "new-password"}, + ) + await hass.async_block_till_done() + + mock_pyschlage_auth.authenticate.assert_called_once_with() + assert result2["type"] == FlowResultType.ABORT + assert result2["reason"] == "wrong_account" + assert mock_added_config_entry.data == { + "username": "asdf@asdf.com", + "password": "hunter2", + } + assert len(mock_setup_entry.mock_calls) == 1 diff --git a/tests/components/schlage/test_init.py b/tests/components/schlage/test_init.py index 0811d87ec80..0fe7af1982b 100644 --- a/tests/components/schlage/test_init.py +++ b/tests/components/schlage/test_init.py @@ -3,7 +3,7 @@ from unittest.mock import Mock, patch from pycognito.exceptions import WarrantException -from pyschlage.exceptions import Error +from pyschlage.exceptions import Error, NotAuthorizedError from homeassistant.config_entries import ConfigEntryState from homeassistant.core import HomeAssistant @@ -43,6 +43,41 @@ async def test_update_data_fails( assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY +async def test_update_data_auth_error( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_pyschlage_auth: Mock, + mock_schlage: Mock, +) -> None: + """Test that we properly handle API errors.""" + mock_schlage.locks.side_effect = NotAuthorizedError + mock_config_entry.add_to_hass(hass) + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + assert mock_schlage.locks.call_count == 1 + assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR + + +async def test_update_data_get_logs_auth_error( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_pyschlage_auth: Mock, + mock_schlage: Mock, + mock_lock: Mock, +) -> None: + """Test that we properly handle API errors.""" + mock_schlage.locks.return_value = [mock_lock] + mock_lock.logs.reset_mock() + mock_lock.logs.side_effect = NotAuthorizedError + mock_config_entry.add_to_hass(hass) + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + assert mock_schlage.locks.call_count == 1 + assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR + + async def test_load_unload_config_entry( hass: HomeAssistant, mock_config_entry: MockConfigEntry,